1- function Lux. AutoDiffInternalImpl. batched_jacobian_impl (
2- f:: F , ad:: AutoEnzyme , x:: AbstractArray ) where {F}
1+ function Lux. AutoDiffInternalImpl. batched_jacobian_internal (
2+ f:: F , ad:: AutoEnzyme , x:: AbstractArray , args ... ) where {F}
33 backend = normalize_backend (True (), ad)
4- return batched_enzyme_jacobian_impl (f, backend, ADTypes. mode (backend), x)
4+ return batched_enzyme_jacobian_impl (f, backend, ADTypes. mode (backend), x, args ... )
55end
66
77function batched_enzyme_jacobian_impl (
8- f_orig:: G , ad:: AutoEnzyme , :: ForwardMode , x:: AbstractArray ) where {G}
8+ f_orig:: G , ad:: AutoEnzyme , :: ForwardMode , x:: AbstractArray , args ... ) where {G}
99 # We need to run the function once to get the output type. Can we use ForwardWithPrimal?
1010 y = f_orig (x)
1111 f = annotate_function (ad, f_orig)
@@ -26,7 +26,8 @@ function batched_enzyme_jacobian_impl(
2626 for i in 1 : chunk_size: (length (x) ÷ B)
2727 idxs = i: min (i + chunk_size - 1 , length (x) ÷ B)
2828 partials′ = make_onehot! (partials, idxs)
29- J_partials = only (Enzyme. autodiff (ad. mode, f, BatchDuplicated (x, partials′)))
29+ J_partials = only (Enzyme. autodiff (
30+ ad. mode, f, BatchDuplicated (x, partials′), Const .(args)... ))
3031 for (idx, J_partial) in zip (idxs, J_partials)
3132 copyto! (view (J, :, idx, :), reshape (J_partial, :, B))
3233 end
@@ -36,7 +37,7 @@ function batched_enzyme_jacobian_impl(
3637end
3738
3839function batched_enzyme_jacobian_impl (
39- f_orig:: G , ad:: AutoEnzyme , :: ReverseMode , x:: AbstractArray ) where {G}
40+ f_orig:: G , ad:: AutoEnzyme , :: ReverseMode , x:: AbstractArray , args ... ) where {G}
4041 # We need to run the function once to get the output type. Can we use ReverseWithPrimal?
4142 y = f_orig (x)
4243
@@ -60,7 +61,8 @@ function batched_enzyme_jacobian_impl(
6061 partials′ = make_onehot! (partials, idxs)
6162 J_partials′ = make_zero! (J_partials, idxs)
6263 Enzyme. autodiff (
63- ad. mode, fn, BatchDuplicated (y, partials′), BatchDuplicated (x, J_partials′)
64+ ad. mode, fn, BatchDuplicated (y, partials′),
65+ BatchDuplicated (x, J_partials′), Const .(args)...
6466 )
6567 for (idx, J_partial) in zip (idxs, J_partials)
6668 copyto! (view (J, idx, :, :), reshape (J_partial, :, B))
0 commit comments