Skip to content

Commit 7b19432

Browse files
committed
Remove arguments of _forward_eval_ϵ
1 parent 2bd236f commit 7b19432

File tree

2 files changed

+10
-36
lines changed

2 files changed

+10
-36
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,10 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
126126
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
127127
d,
128128
subexpr,
129-
_reinterpret_unsafe(T, d.storage_ϵ),
130129
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
131-
input_ϵ,
132-
subexpr_forward_values_ϵ,
133-
d.data.operators,
134130
)
135131
end
136-
_forward_eval_ϵ(
137-
d,
138-
ex,
139-
_reinterpret_unsafe(T, d.storage_ϵ),
140-
_reinterpret_unsafe(T, d.partials_storage_ϵ),
141-
input_ϵ,
142-
subexpr_forward_values_ϵ,
143-
d.data.operators,
144-
)
132+
_forward_eval_ϵ(d, ex, _reinterpret_unsafe(T, d.partials_storage_ϵ))
145133
# do a reverse pass
146134
subexpr_reverse_values_ϵ =
147135
_reinterpret_unsafe(T, d.subexpression_reverse_values_ϵ)
@@ -198,15 +186,17 @@ This assumes that `_reverse_model(d, x)` has already been called.
198186
function _forward_eval_ϵ(
199187
d::NLPEvaluator,
200188
ex::Union{_FunctionStorage,_SubexpressionStorage},
201-
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
202189
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
203-
x_values_ϵ,
204-
subexpression_values_ϵ,
205-
user_operators::Nonlinear.OperatorRegistry,
206190
) where {N,T}
191+
P = ForwardDiff.Partials{N,T}
192+
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
193+
x_values_ϵ = reinterpret(P, d.input_ϵ)
194+
subexpression_values_ϵ =
195+
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
196+
user_operators = d.data.operators
207197
@assert length(storage_ϵ) >= length(ex.nodes)
208198
@assert length(partials_storage_ϵ) >= length(ex.nodes)
209-
zero_ϵ = zero(ForwardDiff.Partials{N,T})
199+
zero_ϵ = zero(P)
210200
# ex.nodes is already in order such that parents always appear before children
211201
# so a backwards pass through ex.nodes is a forward pass through the tree
212202
children_arr = SparseArrays.rowvals(ex.adj)
@@ -348,7 +338,7 @@ function _forward_eval_ϵ(
348338
# multivariate functions.
349339
@assert has_hessian
350340
for col in 1:n_children
351-
dual = zero(ForwardDiff.Partials{N,T})
341+
dual = zero(P)
352342
for row in 1:n_children
353343
# Make sure we get the lower-triangular component.
354344
h = row >= col ? H[row, col] : H[col, row]

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
348348
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
349349
d,
350350
subexpr,
351-
reinterpret(T, d.storage_ϵ),
352351
reinterpret(T, subexpr.partials_storage_ϵ),
353-
input_ϵ,
354-
subexpr_forward_values_ϵ,
355-
d.data.operators,
356352
)
357353
end
358354
# we only need to do one reverse pass through the subexpressions as well
@@ -365,11 +361,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
365361
_forward_eval_ϵ(
366362
d,
367363
something(d.objective),
368-
reinterpret(T, d.storage_ϵ),
369364
reinterpret(T, d.partials_storage_ϵ),
370-
input_ϵ,
371-
subexpr_forward_values_ϵ,
372-
d.data.operators,
373365
)
374366
_reverse_eval_ϵ(
375367
output_ϵ,
@@ -383,15 +375,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
383375
)
384376
end
385377
for (i, con) in enumerate(d.constraints)
386-
_forward_eval_ϵ(
387-
d,
388-
con,
389-
reinterpret(T, d.storage_ϵ),
390-
reinterpret(T, d.partials_storage_ϵ),
391-
input_ϵ,
392-
subexpr_forward_values_ϵ,
393-
d.data.operators,
394-
)
378+
_forward_eval_ϵ(d, con, reinterpret(T, d.partials_storage_ϵ))
395379
_reverse_eval_ϵ(
396380
output_ϵ,
397381
con,

0 commit comments

Comments
 (0)