Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 82 additions & 6 deletions src/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Its basic purpose is to make it easier to modify some parameters while running s
REPL. Additionally, this is used in tests to reduce the computational burden for CI while still
providing examples with sensible default values for users.

In case of nested calls to `trixi_include` inside `elixir`, the keyword arguments are also
passed to the nested calls. This allows to override assignments in nested files as well.

Before replacing assignments in `elixir`, the keyword argument `maxiters` is inserted
into calls to `solve` with it's default value used in the SciML ecosystem
for ODEs, see the "Miscellaneous" section of the
Expand All @@ -34,22 +37,30 @@ julia> redirect_stdout(devnull) do
0.1
```
"""
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...)
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString;
_trixi_include_recursive::Bool = false, kwargs...)
# Check that all kwargs exist as assignments
code = read(elixir, String)
expr = Meta.parse("begin \n$code \nend")
expr = insert_maxiters(expr)

for (key, val) in kwargs
# This will throw an error when `key` is not found
find_assignment(expr, key)
# Validate that all kwargs exist as assignments (with warning for recursive cases)
# Skip for recursive calls because all kwargs are passed to all nested calls,
# some of which may not use all kwargs.
if !_trixi_include_recursive
validate_assignments(expr, kwargs, elixir)
end

# Print information on potential wait time only in non-parallel case
if !mpi_isparallel()
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
end
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex); kwargs...)),

# Add kwarg `_trixi_include_recursive`, which will be added to nested calls
# to `trixi_include` to avoid the validation above.
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex);
_trixi_include_recursive = true,
kwargs...)),
mod, elixir)
end

Expand Down Expand Up @@ -160,23 +171,88 @@ walkexpr(f, x) = f(x)

# Replace assignments to `key` in `expr` by `key = val` for all `(key,val)` in `kwargs`.
function replace_assignments(expr; kwargs...)
# replace explicit and keyword assignments
expr = walkexpr(expr) do x
if x isa Expr
# Replace explicit and keyword assignments
for (key, val) in kwargs
if (x.head === Symbol("=") || x.head === :kw) &&
x.args[1] === Symbol(key)
x.args[2] = :($val)
# dump(x)
end
end

# Handle `trixi_include` calls - add kwargs to them as well
if (!isempty(kwargs) && x.head === :call && length(x.args) >= 2 &&
(x.args[1] === :trixi_include ||
x.args[1] === :trixi_include_changeprecision))

# Check for existing kwargs (both direct :kw and bare symbols in :parameters)
existing_kwargs = Set{Symbol}()
for arg in x.args[2:end] # Skip function name
if arg isa Expr && arg.head === :kw
# Direct keyword argument like `x=5` in `f(x=5)`
push!(existing_kwargs, arg.args[1])
elseif arg isa Expr && arg.head === :parameters
# Keyword arguments grouped in `parameters`
# like `f(; x=5)` or `f(; x)`.
for nested_arg in arg.args
if nested_arg isa Symbol
# Bare symbol like `x` in `f(; x)`
push!(existing_kwargs, nested_arg)
elseif nested_arg isa Expr && nested_arg.head === :kw
# Keyword argument like `x=5` in `f(; x=5)`
push!(existing_kwargs, nested_arg.args[1])
end
end
end
end

# Add kwargs that don't already exist
for (key, val) in kwargs
if !(Symbol(key) in existing_kwargs)
push!(x.args, Expr(:kw, Symbol(key), val))
end
end
end
end
return x
end

return expr
end

# Validate that assignments exist as assignments, with a warning for recursive calls
function validate_assignments(expr, assignments, filename)
isempty(assignments) && return

found_assignments = Set{Symbol}()
has_nested_calls = false

walkexpr(expr) do x
if x isa Expr
if (x.head === Symbol("=") || x.head === :kw) && x.args[1] isa Symbol
push!(found_assignments, x.args[1])
elseif (x.head === :call && length(x.args) >= 2 &&
(x.args[1] === :trixi_include ||
x.args[1] === :trixi_include_changeprecision))
has_nested_calls = true
end
end
return x
end

missing_assignments = setdiff(Symbol.(keys(assignments)), found_assignments)
if !isempty(missing_assignments)
if has_nested_calls
@warn "assignments $missing_assignments not found in $filename, " *
"but nested trixi_include calls detected. They may be used in nested files."
else
throw(ArgumentError("assignments $missing_assignments not found in $filename"))
end
end
end

# Find a (keyword or common) assignment to `destination` in `expr`
# and return the assigned value.
function find_assignment(expr, destination)
Expand Down
138 changes: 135 additions & 3 deletions test/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
@trixi_test_nowarn trixi_include(path, x = 11)
@test Main.x == 11

@test_throws "assignment `y` not found in expression" trixi_include(@__MODULE__,
path,
y = 3)
@test_throws "assignments [:y] not found" trixi_include(@__MODULE__,
path, y = 3)
end
end

Expand Down Expand Up @@ -115,6 +114,139 @@
end
end
end

@trixi_testset "Recursive assignment overwriting" begin
# Test basic recursive kwargs passing
example1 = """
x = 1
y = 2
"""

example2 = """
z = 3
trixi_include(@__MODULE__, nested_path)
"""

mktemp() do path1, io1
write(io1, example1)
close(io1)

mktemp() do path2, io2
# Use raw string to allow backslashes in Windows paths
nested_code = replace(example2, "nested_path" => "raw\"$path1\"")
write(io2, nested_code)
close(io2)

# Test that kwargs are passed recursively
# Should warn about x,y not being in top file but allow due to nested calls
@test_warn "assignments" trixi_include(@__MODULE__, path2; x = 10, y = 20,
z = 30)
@test @isdefined x
@test @isdefined y
@test @isdefined z
@test x == 10 # Overridden from nested file
@test y == 20 # Overridden from nested file
@test z == 30 # Overridden from top file
end
end

# Test with existing kwargs in nested calls
example3 = """
a = 100
trixi_include(@__MODULE__, nested_path; a = 200)
"""

example4 = """
a = 1
b = 2
"""

mktemp() do path3, io3
write(io3, example4)
close(io3)

mktemp() do path4, io4
nested_code = replace(example3, "nested_path" => "raw\"$path3\"")
write(io4, nested_code)
close(io4)

# Test that top-level kwargs override existing nested kwargs
trixi_include(@__MODULE__, path4; a = 500, b = 600)
@test @isdefined a
@test @isdefined b
@test a == 500 # Top-level override wins over nested explicit kwarg
@test b == 600 # Passed through to nested file
end
end

# Test bare symbol syntax with recursion
example5 = """
x = 42
trixi_include(@__MODULE__, nested_path; x)
"""

example6 = """
x = 1
"""

mktemp() do path5, io5
write(io5, example6)
close(io5)

mktemp() do path6, io6
nested_code = replace(example5, "nested_path" => "raw\"$path5\"")
write(io6, nested_code)
close(io6)

# Test bare symbol with recursive override
@test_nowarn_mod trixi_include(@__MODULE__, path6; x = 999)
@test @isdefined x
@test x == 999 # Top-level override
end
end

# Test deep nesting (3 levels)
example7 = """
level1 = 1
"""

example8 = """
level2 = 2
trixi_include(@__MODULE__, level1_path)
"""

example9 = """
level3 = 3
trixi_include(@__MODULE__, level2_path; level2 = 22)
"""

mktemp() do path7, io7
write(io7, example7)
close(io7)

mktemp() do path8, io8
level2_code = replace(example8, "level1_path" => "raw\"$path7\"")
write(io8, level2_code)
close(io8)

mktemp() do path9, io9
level3_code = replace(example9, "level2_path" => "raw\"$path8\"")
write(io9, level3_code)
close(io9)

# Test 3-level deep recursive override
trixi_include(@__MODULE__, path9; level1 = 111,
level2 = 222, level3 = 333)
@test @isdefined level1
@test @isdefined level2
@test @isdefined level3
@test level1 == 111 # Passed through 3 levels
@test level2 == 222 # Top-level override wins over level3 explicit kwarg
@test level3 == 333 # Direct override
end
end
end
end
end

@trixi_testset "`trixi_include_changeprecision`" begin
Expand Down
Loading