@@ -6,7 +6,10 @@ struct SMCMixedModeSparseJacobianPrep{
66 BSr<: DI.BatchSizeSettings ,
77 P<: AbstractMatrix ,
88 C<: AbstractColoringResult{:nonsymmetric,:bidirectional} ,
9- M<: AbstractMatrix{<:Number} ,
9+ Mf<: AbstractMatrix{<:Number} ,
10+ Mr<: AbstractMatrix{<:Number} ,
11+ Sfp<: NTuple ,
12+ Srp<: NTuple ,
1013 Sf<: Vector{<:NTuple} ,
1114 Sr<: Vector{<:NTuple} ,
1215 Rf<: Vector{<:NTuple} ,
@@ -19,8 +22,10 @@ struct SMCMixedModeSparseJacobianPrep{
1922 batch_size_settings_reverse:: BSr
2023 sparsity:: P
2124 coloring_result:: C
22- compressed_matrix_forward:: M
23- compressed_matrix_reverse:: M
25+ compressed_matrix_forward:: Mf
26+ compressed_matrix_reverse:: Mr
27+ batched_seed_forward_prep:: Sfp
28+ batched_seed_reverse_prep:: Srp
2429 batched_seeds_forward:: Sf
2530 batched_seeds_reverse:: Sr
2631 batched_results_forward:: Rf
@@ -111,12 +116,24 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
111116 groups_forward = column_groups (coloring_result)
112117 groups_reverse = row_groups (coloring_result)
113118
119+ seed_forward_prep = DI. multibasis (x, eachindex (x))
120+ seed_reverse_prep = DI. multibasis (y, eachindex (y))
114121 seeds_forward = [DI. multibasis (x, eachindex (x)[group]) for group in groups_forward]
115122 seeds_reverse = [DI. multibasis (y, eachindex (y)[group]) for group in groups_reverse]
116123
117- compressed_matrix_forward = stack (_ -> vec (similar (y)), groups_forward; dims= 2 )
118- compressed_matrix_reverse = stack (_ -> vec (similar (x)), groups_reverse; dims= 1 )
124+ compressed_matrix_forward = if isempty (groups_forward)
125+ similar (vec (y), length (y), 0 )
126+ else
127+ stack (_ -> vec (similar (y)), groups_forward; dims= 2 )
128+ end
129+ compressed_matrix_reverse = if isempty (groups_reverse)
130+ similar (vec (x), 0 , length (x))
131+ else
132+ stack (_ -> vec (similar (x)), groups_reverse; dims= 1 )
133+ end
119134
135+ batched_seed_forward_prep = ntuple (b -> copy (seed_forward_prep), Val (Bf))
136+ batched_seed_reverse_prep = ntuple (b -> copy (seed_reverse_prep), Val (Br))
120137 batched_seeds_forward = [
121138 ntuple (b -> seeds_forward[1 + ((a - 1 ) * Bf + (b - 1 )) % Nf], Val (Bf)) for a in 1 : Af
122139 ]
@@ -136,15 +153,15 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
136153 f_or_f!y... ,
137154 DI. forward_backend (dense_backend),
138155 x,
139- batched_seeds_forward[ 1 ] ,
156+ batched_seed_forward_prep ,
140157 contexts... ;
141158 )
142159 pullback_prep = DI. prepare_pullback_nokwarg (
143160 strict,
144161 f_or_f!y... ,
145162 DI. reverse_backend (dense_backend),
146163 x,
147- batched_seeds_reverse[ 1 ] ,
164+ batched_seed_reverse_prep ,
148165 contexts... ;
149166 )
150167
@@ -156,6 +173,8 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
156173 coloring_result,
157174 compressed_matrix_forward,
158175 compressed_matrix_reverse,
176+ batched_seed_forward_prep,
177+ batched_seed_reverse_prep,
159178 batched_seeds_forward,
160179 batched_seeds_reverse,
161180 batched_results_forward,
@@ -183,6 +202,8 @@ function _sparse_jacobian_aux!(
183202 coloring_result,
184203 compressed_matrix_forward,
185204 compressed_matrix_reverse,
205+ batched_seed_forward_prep,
206+ batched_seed_reverse_prep,
186207 batched_seeds_forward,
187208 batched_seeds_reverse,
188209 batched_results_forward,
@@ -200,15 +221,15 @@ function _sparse_jacobian_aux!(
200221 pushforward_prep,
201222 DI. forward_backend (dense_backend),
202223 x,
203- batched_seeds_forward[ 1 ] ,
224+ batched_seed_forward_prep ,
204225 contexts... ,
205226 )
206227 pullback_prep_same = DI. prepare_pullback_same_point (
207228 f_or_f!y... ,
208229 pullback_prep,
209230 DI. reverse_backend (dense_backend),
210231 x,
211- batched_seeds_reverse[ 1 ] ,
232+ batched_seed_reverse_prep ,
212233 contexts... ,
213234 )
214235
0 commit comments