Skip to content

Commit b749dca

Browse files
committed
test: add batched jacobian tests for enzyme
1 parent 579e77e commit b749dca

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

test/autodiff/batched_autodiff_tests.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,41 @@
11
@testitem "Batched Jacobian" setup=[SharedTestSetup] tags=[:autodiff] begin
2-
using ComponentArrays, ForwardDiff, Zygote
2+
using ComponentArrays, ForwardDiff, Zygote, ADTypes
33

44
rng = StableRNG(12345)
55

66
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
77
models = (
8-
Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
9-
Conv((3, 3), 4 => 2, gelu; pad=SamePad()), FlattenLayer(), Dense(18 => 2)),
10-
Chain(Dense(2, 4, gelu), Dense(4, 2)))
8+
Chain(
9+
Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
10+
Conv((3, 3), 4 => 2, gelu; pad=SamePad()),
11+
FlattenLayer(), Dense(18 => 2)
12+
),
13+
Chain(Dense(2, 4, gelu), Dense(4, 2))
14+
)
1115
Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4)))
1216

13-
for (model, X) in zip(models, Xs)
17+
for (i, (model, X)) in enumerate(zip(models, Xs))
1418
ps, st = Lux.setup(rng, model) |> dev
1519
smodel = StatefulLuxLayer{true}(model, ps, st)
1620

1721
J1 = allow_unstable() do
1822
ForwardDiff.jacobian(smodel, X)
1923
end
2024

21-
@testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff())
25+
@testset for backend in (
26+
AutoZygote(), AutoForwardDiff(),
27+
AutoEnzyme(;
28+
mode=Enzyme.Forward, function_annotation=Enzyme.Const
29+
),
30+
AutoEnzyme(;
31+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
32+
function_annotation=Enzyme.Const
33+
)
34+
)
35+
# Forward rules for Enzyme is currently not implemented for several Ops
36+
i == 1 && backend isa AutoEnzyme &&
37+
ADTypes.mode(backend) isa ADTypes.ForwardMode && continue
38+
2239
J2 = allow_unstable() do
2340
batched_jacobian(smodel, backend, X)
2441
end
@@ -40,7 +57,14 @@
4057
end
4158

4259
@testset "Issue #636 Chunksize Specialization" begin
43-
for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff())
60+
for N in (2, 4, 8, 11, 12, 50, 51),
61+
backend in (
62+
AutoZygote(), AutoForwardDiff(), AutoEnzyme(),
63+
AutoEnzyme(; mode=Enzyme.Reverse)
64+
)
65+
66+
ongpu && backend isa AutoEnzyme && continue
67+
4468
model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x
4569
@return allow_unstable() do
4670
batched_jacobian(potential, backend, x)
@@ -78,6 +102,13 @@
78102
end
79103
@test Jx_zygote Jx_true
80104

105+
if !ongpu
106+
Jx_enzyme = allow_unstable() do
107+
batched_jacobian(ftest, AutoEnzyme(), x)
108+
end
109+
@test Jx_enzyme Jx_true
110+
end
111+
81112
fincorrect(x) = x[:, 1]
82113
x = reshape(Float32.(1:6), 2, 3) |> dev
83114

test/autodiff/nested_autodiff_tests.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,8 @@ end
267267
end
268268

269269
@test_gradients(__f, x,
270-
ps;
271-
atol=1.0f-3,
272-
rtol=1.0f-3,
273-
skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
270+
ps; atol=1.0f-3,
271+
rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
274272
end
275273
end
276274
end
@@ -409,6 +407,5 @@ end
409407
end
410408

411409
@test_gradients(__f, x, ps; atol=1.0f-3,
412-
rtol=1.0f-3,
413-
skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
410+
rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
414411
end

0 commit comments

Comments
 (0)