Skip to content

Commit 579e77e

Browse files
committed
fix: avoid closures in batched_jacobian
1 parent 827a92a commit 579e77e

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

ext/LuxEnzymeExt/batched_autodiff.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
2-
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
1+
function Lux.AutoDiffInternalImpl.batched_jacobian_internal(
2+
f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F}
33
backend = normalize_backend(True(), ad)
4-
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x)
4+
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...)
55
end
66

77
function batched_enzyme_jacobian_impl(
8-
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G}
8+
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G}
99
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
1010
y = f_orig(x)
1111
f = annotate_function(ad, f_orig)
@@ -26,7 +26,8 @@ function batched_enzyme_jacobian_impl(
2626
for i in 1:chunk_size:(length(x) ÷ B)
2727
idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
2828
partials′ = make_onehot!(partials, idxs)
29-
J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′)))
29+
J_partials = only(Enzyme.autodiff(
30+
ad.mode, f, BatchDuplicated(x, partials′), Const.(args)...))
3031
for (idx, J_partial) in zip(idxs, J_partials)
3132
copyto!(view(J, :, idx, :), reshape(J_partial, :, B))
3233
end
@@ -36,7 +37,7 @@ function batched_enzyme_jacobian_impl(
3637
end
3738

3839
function batched_enzyme_jacobian_impl(
39-
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G}
40+
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G}
4041
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
4142
y = f_orig(x)
4243

@@ -60,7 +61,8 @@ function batched_enzyme_jacobian_impl(
6061
partials′ = make_onehot!(partials, idxs)
6162
J_partials′ = make_zero!(J_partials, idxs)
6263
Enzyme.autodiff(
63-
ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
64+
ad.mode, fn, BatchDuplicated(y, partials′),
65+
BatchDuplicated(x, J_partials′), Const.(args)...
6466
)
6567
for (idx, J_partial) in zip(idxs, J_partials)
6668
copyto!(view(J, idx, :, :), reshape(J_partial, :, B))

src/autodiff/nested_autodiff.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
## Written like this to avoid dynamic dispatch from Zygote
22
# Input Gradient / Jacobian
33
function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F}
4-
(f, f.inner.ps)
4+
return f, f.inner.ps
55
end
66
function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F}
7-
(@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps)
7+
return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps
88
end
99
rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps
1010

@@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer})
2222
end
2323

2424
## Break ambiguity
25-
for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
25+
for op in [
26+
ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
2627
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer},
2728
ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}},
28-
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}]
29+
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}
30+
]
2931
@eval function rewrite_autodiff_call(::$op)
3032
error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \
3133
layers")

0 commit comments

Comments
 (0)