Commit f406fa6
feat: batched_jacobian for Reactant [skip ci]
1 parent 5f4a4cb

File tree: 8 files changed, +126 -32 lines

ext/LuxEnzymeExt/LuxEnzymeExt.jl
Lines changed: 0 additions & 7 deletions

@@ -13,13 +13,6 @@ using MLDataDevices: isleaf
 
 Lux.is_extension_loaded(::Val{:Enzyme}) = true
 
-normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
-normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Forward)
-normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Reverse)
-
-annotate_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
-annotate_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)
-
 struct OOPFunctionWrapper{F}
     f::F
 end

ext/LuxEnzymeExt/autodiff.jl
Lines changed: 6 additions & 4 deletions

@@ -1,13 +1,13 @@
 # VJPs
 
 function _vector_jacobian_product_impl(f::F, ad::AutoEnzyme, x, v, extra_args...) where {F}
-    ad = normalize_backend(False(), ad)
+    ad = Utils.normalize_autoenzyme_mode(Reverse, ad)
     @assert ADTypes.mode(ad) isa ADTypes.ReverseMode "VJPs are only supported in reverse \
                                                       mode"
     dx = fmap(zero, x; exclude=isleaf)
     Enzyme.autodiff(
         ad.mode,
-        annotate_function(ad, OOPFunctionWrapper(f)),
+        Utils.annotate_enzyme_function(ad, OOPFunctionWrapper(f)),
         Duplicated(fmap(similar, v; exclude=isleaf), fmap(copy, v; exclude=isleaf)),
         Duplicated(x, dx),
         extra_args...,
@@ -30,11 +30,13 @@ end
 # JVPs
 
 function _jacobian_vector_product_impl(f::F, ad::AutoEnzyme, x, u, extra_args...) where {F}
-    ad = normalize_backend(True(), ad)
+    ad = Utils.normalize_autoenzyme_mode(Forward, ad)
     @assert ADTypes.mode(ad) isa ADTypes.ForwardMode "JVPs are only supported in forward \
                                                       mode"
     return only(
-        Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), extra_args...)
+        Enzyme.autodiff(
+            ad.mode, Utils.annotate_enzyme_function(ad, f), Duplicated(x, u), extra_args...
+        ),
     )
 end
 
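As before, a bare `AutoEnzyme()` gets its mode filled in per operation (Reverse for VJPs, Forward for JVPs); only the helpers doing the normalization moved into `Utils` here. A minimal sketch through the public Lux API (the function and inputs are illustrative, not from the commit):

    using Lux, Enzyme, Random
    using ADTypes: AutoEnzyme

    f(x) = x .^ 2
    x, v = randn(Float32, 4), randn(Float32, 4)

    # The mode is filled in internally: Reverse for the VJP, Forward for the JVP.
    vjp = Lux.vector_jacobian_product(f, AutoEnzyme(), x, v)
    jvp = Lux.jacobian_vector_product(f, AutoEnzyme(), x, v)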

ext/LuxReactantExt/LuxReactantExt.jl
Lines changed: 6 additions & 10 deletions

@@ -1,19 +1,14 @@
 module LuxReactantExt
 
+using ADTypes: ADTypes, AutoEnzyme
 using Enzyme: Enzyme, Const
+using EnzymeCore: EnzymeCore
+using LinearAlgebra: LinearAlgebra
 using Preferences: load_preference
 using Optimisers: Optimisers
 using Reactant:
-    Reactant,
-    Profiler,
-    @compile,
-    @code_hlo,
-    @jit,
-    @opcall,
-    AnyTracedRArray,
-    TracedRArray,
-    TracedRNumber,
-    PrecisionConfig
+    Reactant, Profiler, AnyTracedRArray, TracedRArray, TracedRNumber, PrecisionConfig
+using Reactant: @compile, @code_hlo, @jit, @opcall
 using ReactantCore: ReactantCore, @trace
 using Setfield: @set!
 using Static: True, False
@@ -74,5 +69,6 @@ include("training.jl")
 include("layers.jl")
 include("tracing.jl")
 include("saved_model.jl")
+include("batched_jacobian.jl")
 
 end
ext/LuxReactantExt/batched_jacobian.jl (new file)
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
+    f::F, ad::Lux.Training.ReactantBackend, x
+) where {F}
+    ad = Utils.normalize_autoenzyme_mode(EnzymeCore.Forward, ad.ad)
+    if ADTypes.mode(ad) isa ADTypes.ReverseMode
+        return _batched_jacobian_reverse_impl(f, ad, x)
+    else
+        return _batched_jacobian_forward_impl(f, ad, x)
+    end
+end
+
+struct ApplyWithReshape{F,SZ}
+    f::F
+    sz::SZ
+end
+
+(f::ApplyWithReshape)(x) = f.f(reshape(x, f.sz))
+
+function (f::ApplyWithReshape)(y, x)
+    res = f.f(reshape(x, f.sz))
+    copyto!(y, reshape(res, size(y)))
+    return nothing
+end
+
+function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
+    y = f(x)
+    @assert y isa AbstractArray
+    if ndims(y) == 1 || size(y, ndims(y)) != size(x, ndims(x))
+        throw(AssertionError("`batched_jacobian` only supports batched outputs \
+                              (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
+    end
+
+    f′ = ApplyWithReshape(f, size(x))
+
+    y = Utils.contiguous(reshape(y, :, size(y, ndims(y))))
+    dy = repeat(
+        reshape(
+            Reactant.promote_to(
+                TracedRArray{Reactant.unwrapped_eltype(y),2}, LinearAlgebra.I(size(y, 1))
+            ),
+            size(y, 1),
+            1,
+            size(y, 1),
+        ),
+        1,
+        size(y, 2),
+        1,
+    )
+    dy = Utils.contiguous(dy)
+
+    x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
+    dx = similar(x, size(x, 1), size(x, 2), size(y, 1))
+    fill!(dx, false)
+
+    Enzyme.autodiff(
+        ad.mode,
+        Utils.annotate_enzyme_function(ad, f′),
+        Reactant.StackedBatchDuplicated(y, dy),
+        Reactant.StackedBatchDuplicated(x, dx),
+    )
+
+    return permutedims(dx, (3, 1, 2))
+end
+
+function _batched_jacobian_forward_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
+    f′ = ApplyWithReshape(f, size(x))
+    x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
+
+    bx = repeat(
+        reshape(
+            Reactant.promote_to(
+                TracedRArray{Reactant.unwrapped_eltype(x),2}, LinearAlgebra.I(size(x, 1))
+            ),
+            size(x, 1),
+            1,
+            size(x, 1),
+        ),
+        1,
+        size(x, 2),
+        1,
+    )
+    bx = Utils.contiguous(bx)
+
+    return stack(
+        only(
+            Enzyme.autodiff(
+                ad.mode,
+                Utils.annotate_enzyme_function(ad, f′),
+                Reactant.StackedBatchDuplicated(x, bx),
+            ),
+        );
+        dims=2,
+    )
+end
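Both branches build their batch seeds the same way: an identity matrix is reshaped to (n, 1, n) and repeated B times along the middle (batch) dimension, so shadow k carries basis vector e_k for every sample. A shape-only sketch of that pattern with plain arrays (n and B are illustrative; no Reactant needed):

    using LinearAlgebra

    n, B = 3, 5  # flattened feature size, batch size
    seeds = repeat(reshape(Matrix{Float32}(I, n, n), n, 1, n), 1, B, 1)
    @assert size(seeds) == (n, B, n)
    # Shadow k seeds the k-th basis vector for every sample in the batch:
    @assert seeds[:, 2, 3] == Float32[0, 0, 1]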
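End to end, this new file is what lets `batched_jacobian` accept `AutoEnzyme` for models running under Reactant. A hedged sketch of the intended call pattern (the model, shapes, and the `reactant_device()`/`@jit` plumbing are illustrative assumptions, not part of this commit):

    using Lux, Reactant, Enzyme, Random
    using ADTypes: AutoEnzyme

    dev = reactant_device()
    model = Dense(4 => 3, tanh)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev
    x = dev(randn(Float32, 4, 8))   # feature dim 4, batch of 8

    g(z) = first(model(z, ps, st))  # Lux models return (y, st); keep y
    J = @jit Lux.batched_jacobian(g, AutoEnzyme(), x)
    @assert size(J) == (3, 4, 8)    # one 3×4 Jacobian per sample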

src/autodiff/api.jl
Lines changed: 7 additions & 9 deletions

@@ -99,6 +99,7 @@ the following properties for `y = f(x)`:
 
 | Supported Backends | Packages Needed |
 |:------------------ |:--------------- |
+| `AutoEnzyme`       | `Reactant.jl`   |
 | `AutoForwardDiff`  |                 |
 | `AutoZygote`       | `Zygote.jl`     |
 
@@ -126,16 +127,13 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
     throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
 end
 
-function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
-    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
-end
-
-function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
-    if !is_extension_loaded(Val(:Zygote))
-        error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
-               `$(backend)`.")
+for implemented_backend in (:AutoForwardDiff, :AutoZygote, :AutoEnzyme)
+    @eval function batched_jacobian(
+        f::F, backend::$implemented_backend, x::AbstractArray
+    ) where {F}
+        assert_backend_loaded(:batched_jacobian, backend)
+        return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
     end
-    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
 end
 
 # Utils
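The `@eval` loop folds the two hand-written methods into one template and adds `AutoEnzyme` to the list; each iteration generates the equivalent of the following (sketch for the `AutoZygote` case; `assert_backend_loaded` is defined outside this hunk and subsumes the old inline `is_extension_loaded` check):

    function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
        # Errors unless the backing package (here Zygote.jl) has been loaded.
        assert_backend_loaded(:batched_jacobian, backend)
        return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
    end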

src/autodiff/batched_autodiff.jl
Lines changed: 3 additions & 1 deletion

@@ -90,7 +90,9 @@ end
 function batched_jacobian_internal(
     f::F, backend::AbstractADType, x::AbstractArray
 ) where {F}
-    return batched_jacobian_impl(f, backend, x)
+    return batched_jacobian_impl(
+        f, Lux.Training.maybe_wrap_adtype(backend, get_device_type(x)), x
+    )
 end
 
 # ForwardDiff.jl Implementation
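This is the hook that routes Reactant inputs to the new implementation: `maybe_wrap_adtype` leaves the backend untouched on ordinary devices and, on a `ReactantDevice`, wraps `AutoEnzyme` into a `ReactantBackend` (which, per the training.jl change below, now also carries the original adtype). A rough sketch using the internal helpers (not public API; shown only for orientation):

    using Lux, Enzyme
    using ADTypes: AutoEnzyme
    using MLDataDevices: CPUDevice, ReactantDevice

    # Ordinary devices: the adtype passes through unchanged.
    @assert Lux.Training.maybe_wrap_adtype(AutoEnzyme(), CPUDevice) isa AutoEnzyme

    # Reactant device: wrapped, with the user's adtype riding along in `ad`.
    b = Lux.Training.maybe_wrap_adtype(AutoEnzyme(), ReactantDevice)
    @assert b isa Lux.Training.ReactantBackend && b.ad isa AutoEnzyme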

src/helpers/training.jl
Lines changed: 2 additions & 1 deletion

@@ -107,6 +107,7 @@ end
 
 @concrete struct ReactantBackend
     return_gradients <: StaticBool
+    ad <: AutoEnzyme
 end
 
 const APPLY_GRAD_DOCSTRING = """
@@ -210,7 +211,7 @@ maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
 function maybe_wrap_adtype(
     ad::AbstractADType, ::Type{ReactantDevice}; return_gradients::Utils.BoolType=True()
 )
-    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients))
+    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), ad)
     throw(ArgumentError("Computing gradients for models on XLA is supported only with \
                          Enzyme.jl (`AutoEnzyme`)."))
 end

src/utils.jl
Lines changed: 8 additions & 0 deletions

@@ -1,5 +1,6 @@
 module Utils
 
+using ADTypes: ADTypes, AutoEnzyme
 using ArrayInterface: ArrayInterface
 using ArgCheck: @argcheck
 using ChainRulesCore: ChainRulesCore, @non_differentiable, NoTangent
@@ -8,6 +9,7 @@ using EnzymeCore: EnzymeRules
 using ForwardDiff: Dual
 using Functors: Functors, fmapstructure
 using Random: AbstractRNG
+using Setfield: @set
 using Static: Static, StaticBool, StaticInteger, StaticSymbol
 using StaticArraysCore: SMatrix, SVector
 
@@ -237,6 +239,12 @@ calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4
 
 recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=MLDataDevices.isleaf)
 
+normalize_autoenzyme_mode(mode, ad::AutoEnzyme) = ad
+normalize_autoenzyme_mode(mode, ad::AutoEnzyme{Nothing}) = @set(ad.mode = mode)
+
+annotate_enzyme_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
+annotate_enzyme_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)
+
 end
 
 using .Utils:
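Behaviorally the relocated helpers are small: `normalize_autoenzyme_mode` fills in the requested mode only when the user left `AutoEnzyme`'s mode unset, and `annotate_enzyme_function` wraps the function in the user's Enzyme annotation when one was given. A sketch (calling the internal `Lux.Utils` module directly, for illustration only):

    using Lux
    using ADTypes: AutoEnzyme
    using Enzyme: Forward, Reverse, Const

    # mode === nothing (AutoEnzyme{Nothing}): the requested mode is set via @set.
    Lux.Utils.normalize_autoenzyme_mode(Forward, AutoEnzyme())                # mode = Forward
    # A user-chosen mode wins; the generic method returns `ad` unchanged.
    Lux.Utils.normalize_autoenzyme_mode(Forward, AutoEnzyme(; mode=Reverse))  # mode = Reverse

    # Without a function annotation the function passes through; with one,
    # it is wrapped, so the second call below returns Const(sum).
    Lux.Utils.annotate_enzyme_function(AutoEnzyme(), sum)
    Lux.Utils.annotate_enzyme_function(AutoEnzyme(; function_annotation=Const), sum)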
