Skip to content

Commit 8b903aa

Browse files
committed
Update
1 parent 7b19432 commit 8b903aa

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ end
116116

117117
function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
118118
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
119-
input_ϵ = _reinterpret_unsafe(T, d.input_ϵ)
120119
fill!(d.output_ϵ, 0.0)
121120
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
122121
subexpr_forward_values_ϵ =
@@ -168,11 +167,7 @@ end
168167
_forward_eval_ϵ(
169168
d::NLPEvaluator,
170169
ex::Union{_FunctionStorage,_SubexpressionStorage},
171-
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
172170
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
173-
x_values_ϵ,
174-
subexpression_values_ϵ,
175-
user_operators::Nonlinear.OperatorRegistry,
176171
) where {N,T}
177172
178173
Evaluate the directional derivatives of the expression tree in `ex`.
@@ -186,14 +181,12 @@ This assumes that `_reverse_model(d, x)` has already been called.
186181
function _forward_eval_ϵ(
187182
d::NLPEvaluator,
188183
ex::Union{_FunctionStorage,_SubexpressionStorage},
189-
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
190-
) where {N,T}
191-
P = ForwardDiff.Partials{N,T}
184+
partials_storage_ϵ::AbstractVector{P}
185+
) where {N,T,P<:ForwardDiff.Partials{N,T}}
192186
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
193187
x_values_ϵ = reinterpret(P, d.input_ϵ)
194188
subexpression_values_ϵ =
195189
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
196-
user_operators = d.data.operators
197190
@assert length(storage_ϵ) >= length(ex.nodes)
198191
@assert length(partials_storage_ϵ) >= length(ex.nodes)
199192
zero_ϵ = zero(P)
@@ -329,8 +322,8 @@ function _forward_eval_ϵ(
329322
n_children,
330323
)
331324
has_hessian = Nonlinear.eval_multivariate_hessian(
332-
user_operators,
333-
user_operators.multivariate_operators[node.index],
325+
d.data.operators,
326+
d.data.operators.multivariate_operators[node.index],
334327
H,
335328
f_input,
336329
)
@@ -356,7 +349,7 @@ function _forward_eval_ϵ(
356349
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
357350
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
358351
f′′ = Nonlinear.eval_univariate_hessian(
359-
user_operators,
352+
d.data.operators,
360353
node.index,
361354
ex.forward_storage[child_idx],
362355
)

0 commit comments

Comments
 (0)