Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 112 additions & 9 deletions src/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# of `TrixiBase`. However, users will want to evaluate in the global scope of `Main` or something
# similar to manage dependencies on their own.
"""
trixi_include([mapexpr::Function=identity,] [mod::Module=Main,] elixir::AbstractString; kwargs...)
trixi_include([mapexpr::Function=identity,] [mod::Module=Main,] elixir::AbstractString;
enable_assignment_validation::Bool = true,
replace_assignments_recursive::Bool = false, kwargs...)

`include` the file `elixir` and evaluate its content in the global scope of module `mod`.
You can override specific assignments in `elixir` by supplying keyword arguments.
Expand All @@ -20,6 +22,16 @@ The optional first argument `mapexpr` can be used to transform the included code
it is evaluated: for each parsed expression `expr` in `elixir`, the `include` function
actually evaluates `mapexpr(expr)`. If it is omitted, `mapexpr` defaults to `identity`.

With `replace_assignments_recursive=true`, the keyword arguments are also passed
to nested calls of `trixi_include`. This allows to override assignments in nested files as well.

The keyword argument `enable_assignment_validation`, which is enabled by default,
can be used to enable or disable validation that all passed keyword arguments exist
as assignments in `elixir`. If `enable_assignment_validation` is `true` and
an assignment for a passed keyword argument is not found in `elixir`, an error is thrown.
If `replace_assignments_recursive` is `true` and `elixir` contains calls to `trixi_include`
itself, a warning is issued instead of an error.

# Examples

```@example
Expand All @@ -34,23 +46,40 @@ julia> redirect_stdout(devnull) do
0.1
```
"""
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...)
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString;
enable_assignment_validation::Bool = true,
replace_assignments_recursive::Bool = false, kwargs...)
# Check that all kwargs exist as assignments
code = read(elixir, String)
expr = Meta.parse("begin \n$code \nend")
expr = insert_maxiters(expr)

for (key, val) in kwargs
# This will throw an error when `key` is not found
find_assignment(expr, key)
# Validate that all kwargs exist as assignments (with warning for recursive cases).
# Skip for nested calls because all kwargs are passed to all nested calls,
# some of which may not use all kwargs.
if enable_assignment_validation
validate_assignments(expr, kwargs, elixir, replace_assignments_recursive)
end

# Print information on potential wait time only in non-parallel case
if !mpi_isparallel()
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
end
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex); kwargs...)),
mod, elixir)

if replace_assignments_recursive
# Add kwarg `enable_assignment_validation` to disable validation in nested
# `trixi_include` calls.
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex),
replace_assignments_recursive;
enable_assignment_validation = false,
replace_assignments_recursive = true,
kwargs...)),
mod, elixir)
else
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex);
kwargs...)),
mod, elixir)
end
end

function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
Expand Down Expand Up @@ -159,24 +188,98 @@ walkexpr(f, expr::Expr) = f(Expr(expr.head, (walkexpr(f, arg) for arg in expr.ar
walkexpr(f, x) = f(x)

# Replace assignments to `key` in `expr` by `key = val` for all `(key,val)` in `kwargs`.
function replace_assignments(expr; kwargs...)
# replace explicit and keyword assignments
function replace_assignments(expr, recursive = false; kwargs...)
expr = walkexpr(expr) do x
if x isa Expr
# Replace explicit and keyword assignments
for (key, val) in kwargs
if (x.head === Symbol("=") || x.head === :kw) &&
x.args[1] === Symbol(key)
x.args[2] = :($val)
# dump(x)
end
end

# If `recursive` is true:
# Handle `trixi_include` calls - add kwargs to them as well.
is_trixi_include = (x.head === :call && length(x.args) >= 2 &&
(x.args[1] === :trixi_include ||
x.args[1] === :trixi_include_changeprecision))
if !isempty(kwargs) && is_trixi_include && recursive

# Check for existing kwargs (both direct :kw and bare symbols in :parameters)
existing_kwargs = Set{Symbol}()
for arg in x.args[2:end] # Skip function name
if arg isa Expr && arg.head === :kw
# Direct keyword argument like `x=5` in `f(x=5)`
push!(existing_kwargs, arg.args[1])
elseif arg isa Expr && arg.head === :parameters
# Keyword arguments grouped in `parameters`
# like `f(; x=5)` or `f(; x)`.
for nested_arg in arg.args
if nested_arg isa Symbol
# Bare symbol like `x` in `f(; x)`
push!(existing_kwargs, nested_arg)
elseif nested_arg isa Expr && nested_arg.head === :kw
# Keyword argument like `x=5` in `f(; x=5)`
push!(existing_kwargs, nested_arg.args[1])
end
end
end
end

# Add kwargs that don't already exist.
# Note that existing keywords as assignment (`x=5`) don't need to be added
# again because they are replaced in the loop
# "Replace explicit and keyword assignments" above.
# Bare symbol like `x` in `f(; x)` must have been defined in the file
# before they are passed to `trixi_include`, so there must be an assignment
# `x = ...` in the file, which will also be replaced in the loop above.
for (key, val) in kwargs
if !(Symbol(key) in existing_kwargs)
push!(x.args, Expr(:kw, Symbol(key), val))
end
end
end
end
return x
end

return expr
end

# Validate that keyword arguments passed to `trixi_include` exist as assignments
# in the expression. Throw an error if they are not found or a warning for recursive calls.
function validate_assignments(expr, assignments, filename, replace_assignments_recursive)
isempty(assignments) && return

found_assignments = Set{Symbol}()
has_nested_calls = false

walkexpr(expr) do x
if x isa Expr
if (x.head === Symbol("=") || x.head === :kw) && x.args[1] isa Symbol
push!(found_assignments, x.args[1])
elseif (x.head === :call && length(x.args) >= 2 &&
(x.args[1] === :trixi_include ||
x.args[1] === :trixi_include_changeprecision))
has_nested_calls = true
end
end
return x
end

missing_assignments = setdiff(Symbol.(keys(assignments)), found_assignments)
if !isempty(missing_assignments)
if replace_assignments_recursive && has_nested_calls
@warn "assignments $missing_assignments not found in $filename, " *
"but nested trixi_include calls detected. They may be used in nested files."
else
throw(ArgumentError("assignments $missing_assignments not found in $filename"))
end
end
end

# Find a (keyword or common) assignment to `destination` in `expr`
# and return the assigned value.
function find_assignment(expr, destination)
Expand Down
157 changes: 154 additions & 3 deletions test/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
@trixi_test_nowarn trixi_include(path, x = 11)
@test Main.x == 11

@test_throws "assignment `y` not found in expression" trixi_include(@__MODULE__,
path,
y = 3)
@test_throws "assignments [:y] not found" trixi_include(@__MODULE__,
path, y = 3)
end
end

Expand Down Expand Up @@ -115,6 +114,158 @@
end
end
end

@trixi_testset "Recursive assignment overwriting" begin
# Test basic recursive kwargs passing
example1 = """
x = 1
y = 2
"""

example2 = """
z = 3
trixi_include(@__MODULE__, nested_path)
"""

mktemp() do path1, io1
write(io1, example1)
close(io1)

mktemp() do path2, io2
# Use raw string to allow backslashes in Windows paths
nested_code = replace(example2, "nested_path" => "raw\"$path1\"")
write(io2, nested_code)
close(io2)

# Test that kwargs are passed recursively
# Should warn about x,y not being in top file but allow due to nested calls
@test_warn "assignments" trixi_include(@__MODULE__, path2;
x = 10, y = 20, z = 30,
replace_assignments_recursive = true)
@test @isdefined x
@test @isdefined y
@test @isdefined z
@test x == 10 # Overridden from nested file
@test y == 20 # Overridden from nested file
@test z == 30 # Overridden from top file

# Test that kwargs are NOT passed recursively
@trixi_test_nowarn trixi_include(@__MODULE__, path2;
x = 10, y = 20, z = 30,
replace_assignments_recursive = false,
enable_assignment_validation = false)

@test x == 1 # Not overridden from nested file
@test y == 2 # Not overridden from nested file
@test z == 30 # Overridden from top file

# Without disabling validation, this should result in an error:
@test_throws "assignments [:x, :y] not found" trixi_include(@__MODULE__,
path2; x = 10,
y = 20, z = 30)
end
end

# Test with existing kwargs in nested calls
example3 = """
a = 100
trixi_include(@__MODULE__, nested_path; a = 200)
"""

example4 = """
a = 1
b = 2
"""

mktemp() do path3, io3
write(io3, example4)
close(io3)

mktemp() do path4, io4
nested_code = replace(example3, "nested_path" => "raw\"$path3\"")
write(io4, nested_code)
close(io4)

# Test that top-level kwargs override existing nested kwargs
trixi_include(@__MODULE__, path4; a = 500, b = 600,
replace_assignments_recursive = true)
@test @isdefined a
@test @isdefined b
@test a == 500 # Top-level override wins over nested explicit kwarg
@test b == 600 # Passed through to nested file
end
end

# Test bare symbol syntax with recursion
example5 = """
x = 42
trixi_include(@__MODULE__, nested_path; x)
"""

example6 = """
x = 1
"""

mktemp() do path5, io5
write(io5, example6)
close(io5)

mktemp() do path6, io6
nested_code = replace(example5, "nested_path" => "raw\"$path5\"")
write(io6, nested_code)
close(io6)

# Test bare symbol with recursive override
@trixi_test_nowarn trixi_include(@__MODULE__, path6; x = 999,
replace_assignments_recursive = true)
@test @isdefined x
@test x == 999 # Top-level override
end
end

# Test deep nesting (3 levels)
example7 = """
level1 = 1
"""

example8 = """
level2 = 2
trixi_include(@__MODULE__, level1_path)
"""

example9 = """
level3 = 3
trixi_include(@__MODULE__, level2_path; level2 = 22)
"""

mktemp() do path7, io7
write(io7, example7)
close(io7)

mktemp() do path8, io8
level2_code = replace(example8, "level1_path" => "raw\"$path7\"")
write(io8, level2_code)
close(io8)

mktemp() do path9, io9
level3_code = replace(example9, "level2_path" => "raw\"$path8\"")
write(io9, level3_code)
close(io9)

# Test 3-level deep recursive override
trixi_include(@__MODULE__, path9; level1 = 111,
level2 = 222, level3 = 333,
replace_assignments_recursive = true)
@test @isdefined level1
@test @isdefined level2
@test @isdefined level3
@test level1 == 111 # Passed through 3 levels
@test level2 == 222 # Top-level override wins over level3 explicit kwarg
@test level3 == 333 # Direct override
end
end
end
end
end

@trixi_testset "`trixi_include_changeprecision`" begin
Expand Down
Loading