Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/Nonlinear/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,26 @@
)
end

# We use a let block here for `expr` to create a local variable that does not
# persist in the scope of the module. All we care about is the _eval_univariate
# function that is eval'd as a result.
let
function _generate_eval_univariate()

Check warning on line 24 in src/Nonlinear/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/Nonlinear/operators.jl#L24

Added line #L24 was not covered by tests
exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
return :(return $(arg[1])(x), $(arg[2]))
end
@eval @inline function _eval_univariate(id, x::T) where {T}
$(_create_binary_switch(1:length(exprs), exprs))
return error("Invalid id for univariate operator: $id")
end
∇²f_exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)
@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}
$(_create_binary_switch(1:length(∇²f_exprs), ∇²f_exprs))
return error("Invalid id for univariate operator: $id")
end
return _create_binary_switch(1:length(exprs), exprs)
end

@eval @inline function _eval_univariate(id, x::T) where {T}
$(_generate_eval_univariate())
return error("Invalid id for univariate operator: $id")
end

function _generate_eval_univariate_2nd_deriv()

Check warning on line 36 in src/Nonlinear/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/Nonlinear/operators.jl#L36

Added line #L36 was not covered by tests
exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)
return _create_binary_switch(1:length(exprs), exprs)
end

@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}

Check warning on line 41 in src/Nonlinear/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/Nonlinear/operators.jl#L41

Added line #L41 was not covered by tests
$(_generate_eval_univariate_2nd_deriv())
return error("Invalid id for univariate operator: $id")
end

struct _UnivariateOperator{F,F′,F′′}
Expand Down
16 changes: 16 additions & 0 deletions test/Nonlinear/Nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,22 @@ function test_create_binary_switch()
),
)
@test MOI.Nonlinear._create_binary_switch(1:4, [:a, :b, :c, :d]) == target
# Just test that these functions don't error. We'll test their contents by
# evaluating the actual fuctions that are `@eval`ed.
MOI.Nonlinear._generate_eval_univariate()
MOI.Nonlinear._generate_eval_univariate_2nd_deriv()
return
end

function test_intercept_ForwardDiff_MethodError()
r = Nonlinear.OperatorRegistry()
f(x::Float64) = sin(x)^2
g(x) = x > 1 ? f(x) : zero(x)
Nonlinear.register_operator(r, :g, 1, g)
@test Nonlinear.eval_univariate_function(r, :g, 0.0) == 0.0
@test Nonlinear.eval_univariate_function(r, :g, 2.0) ≈ sin(2.0)^2
@test Nonlinear.eval_univariate_gradient(r, :g, 0.0) == 0.0
@test_throws ErrorException Nonlinear.eval_univariate_gradient(r, :g, 2.0)
return
end

Expand Down
Loading