@@ -50,61 +50,19 @@ function _eval_hessian_inner(
5050 @assert length (ex. hess_I) == 0
5151 return 0
5252 end
53- T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
5453 Coloring. prepare_seed_matrix! (ex. seed_matrix, ex. rinfo)
55- local_to_global_idx = ex. rinfo. local_indices
56- input_ϵ_raw, output_ϵ_raw = d. input_ϵ, d. output_ϵ
57- input_ϵ = _reinterpret_unsafe (T, input_ϵ_raw)
58- output_ϵ = _reinterpret_unsafe (T, output_ϵ_raw)
5954 # Compute hessian-vector products
6055 num_products = size (ex. seed_matrix, 2 ) # number of hessian-vector products
6156 num_chunks = div (num_products, CHUNK)
62- @assert size (ex. seed_matrix, 1 ) == length (local_to_global_idx)
63- for k in 1 : CHUNK: (CHUNK* num_chunks)
64- for r in 1 : length (local_to_global_idx)
65- # set up directional derivatives
66- @inbounds idx = local_to_global_idx[r]
67- # load up ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] into input_ϵ
68- for s in 1 : CHUNK
69- input_ϵ_raw[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, k+ s- 1 ]
70- end
71- @inbounds output_ϵ[idx] = zero (T)
72- end
73- _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, T)
74- # collect directional derivatives
75- for r in 1 : length (local_to_global_idx)
76- idx = local_to_global_idx[r]
77- # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+CHUNK-1]
78- for s in 1 : CHUNK
79- ex. seed_matrix[r, k+ s- 1 ] = output_ϵ_raw[(idx- 1 )* CHUNK+ s]
80- end
81- @inbounds input_ϵ[idx] = zero (T)
82- end
57+ @assert size (ex. seed_matrix, 1 ) == length (ex. rinfo. local_indices)
58+ for offset in 1 : CHUNK: (CHUNK* num_chunks)
59+ _eval_hessian_chunk (d, ex, offset, CHUNK, Val (CHUNK))
8360 end
8461 # leftover chunk
8562 remaining = num_products - CHUNK * num_chunks
8663 if remaining > 0
87- k = CHUNK * num_chunks + 1
88- for r in 1 : length (local_to_global_idx)
89- # set up directional derivatives
90- @inbounds idx = local_to_global_idx[r]
91- # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
92- for s in 1 : remaining
93- # leave junk in the unused components
94- input_ϵ_raw[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, k+ s- 1 ]
95- end
96- @inbounds output_ϵ[idx] = zero (T)
97- end
98- _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, T)
99- # collect directional derivatives
100- for r in 1 : length (local_to_global_idx)
101- idx = local_to_global_idx[r]
102- # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
103- for s in 1 : remaining
104- ex. seed_matrix[r, k+ s- 1 ] = output_ϵ_raw[(idx- 1 )* CHUNK+ s]
105- end
106- @inbounds input_ϵ[idx] = zero (T)
107- end
64+ offset = CHUNK * num_chunks + 1
65+ _eval_hessian_chunk (d, ex, offset, remaining, Val (CHUNK))
10866 end
10967 want, got = nzcount + length (ex. hess_I), length (H)
11068 if want > got
@@ -127,7 +85,40 @@ function _eval_hessian_inner(
12785 return length (ex. hess_I)
12886end
12987
130- function _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, :: Type{T} ) where {T}
88+ function _eval_hessian_chunk (
89+ d:: NLPEvaluator ,
90+ ex:: _FunctionStorage ,
91+ offset:: Int ,
92+ chunk:: Int ,
93+ :: Val{CHUNK} ,
94+ ) where {CHUNK}
95+ for r in eachindex (ex. rinfo. local_indices)
96+ # set up directional derivatives
97+ @inbounds idx = ex. rinfo. local_indices[r]
98+ # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
99+ for s in 1 : chunk
100+ # If `chunk < CHUNK`, leaves junk in the unused components
101+ d. input_ϵ[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, offset+ s- 1 ]
102+ end
103+ end
104+ _hessian_slice_inner (d, ex, Val (CHUNK))
105+ fill! (d. input_ϵ, 0.0 )
106+ # collect directional derivatives
107+ for r in eachindex (ex. rinfo. local_indices)
108+ @inbounds idx = ex. rinfo. local_indices[r]
109+ # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
110+ for s in 1 : chunk
111+ ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* CHUNK+ s]
112+ end
113+ end
114+ return
115+ end
116+
117+ function _hessian_slice_inner (d, ex, :: Val{CHUNK} ) where {CHUNK}
118+ T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
119+ input_ϵ = _reinterpret_unsafe (T, d. input_ϵ)
120+ fill! (d. output_ϵ, 0.0 )
121+ output_ϵ = _reinterpret_unsafe (T, d. output_ϵ)
131122 subexpr_forward_values_ϵ =
132123 _reinterpret_unsafe (T, d. subexpression_forward_values_ϵ)
133124 for i in ex. dependent_subexpressions
0 commit comments