|
1 | 1 | @testitem "Batched Jacobian" setup=[SharedTestSetup] tags=[:autodiff] begin |
2 | | - using ComponentArrays, ForwardDiff, Zygote |
| 2 | + using ComponentArrays, ForwardDiff, Zygote, ADTypes |
3 | 3 |
|
4 | 4 | rng = StableRNG(12345) |
5 | 5 |
|
6 | 6 | @testset "$mode" for (mode, aType, dev, ongpu) in MODES |
7 | 7 | models = ( |
8 | | - Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), |
9 | | - Conv((3, 3), 4 => 2, gelu; pad=SamePad()), FlattenLayer(), Dense(18 => 2)), |
10 | | - Chain(Dense(2, 4, gelu), Dense(4, 2))) |
| 8 | + Chain( |
| 9 | + Conv((3, 3), 2 => 4, gelu; pad=SamePad()), |
| 10 | + Conv((3, 3), 4 => 2, gelu; pad=SamePad()), |
| 11 | + FlattenLayer(), Dense(18 => 2) |
| 12 | + ), |
| 13 | + Chain(Dense(2, 4, gelu), Dense(4, 2)) |
| 14 | + ) |
11 | 15 | Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4))) |
12 | 16 |
|
13 | | - for (model, X) in zip(models, Xs) |
| 17 | + for (i, (model, X)) in enumerate(zip(models, Xs)) |
14 | 18 | ps, st = Lux.setup(rng, model) |> dev |
15 | 19 | smodel = StatefulLuxLayer{true}(model, ps, st) |
16 | 20 |
|
17 | 21 | J1 = allow_unstable() do |
18 | 22 | ForwardDiff.jacobian(smodel, X) |
19 | 23 | end |
20 | 24 |
|
21 | | - @testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff()) |
| 25 | + @testset for backend in ( |
| 26 | + AutoZygote(), AutoForwardDiff(), |
| 27 | + AutoEnzyme(; |
| 28 | + mode=Enzyme.Forward, function_annotation=Enzyme.Const |
| 29 | + ), |
| 30 | + AutoEnzyme(; |
| 31 | + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), |
| 32 | + function_annotation=Enzyme.Const |
| 33 | + ) |
| 34 | + ) |
| 35 | + # Forward rules for Enzyme is currently not implemented for several Ops |
| 36 | + i == 1 && backend isa AutoEnzyme && |
| 37 | + ADTypes.mode(backend) isa ADTypes.ForwardMode && continue |
| 38 | + |
22 | 39 | J2 = allow_unstable() do |
23 | 40 | batched_jacobian(smodel, backend, X) |
24 | 41 | end |
|
40 | 57 | end |
41 | 58 |
|
42 | 59 | @testset "Issue #636 Chunksize Specialization" begin |
43 | | - for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff()) |
| 60 | + for N in (2, 4, 8, 11, 12, 50, 51), |
| 61 | + backend in ( |
| 62 | + AutoZygote(), AutoForwardDiff(), AutoEnzyme(), |
| 63 | + AutoEnzyme(; mode=Enzyme.Reverse) |
| 64 | + ) |
| 65 | + |
| 66 | + ongpu && backend isa AutoEnzyme && continue |
| 67 | + |
44 | 68 | model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x |
45 | 69 | @return allow_unstable() do |
46 | 70 | batched_jacobian(potential, backend, x) |
|
78 | 102 | end |
79 | 103 | @test Jx_zygote ≈ Jx_true |
80 | 104 |
|
| 105 | + if !ongpu |
| 106 | + Jx_enzyme = allow_unstable() do |
| 107 | + batched_jacobian(ftest, AutoEnzyme(), x) |
| 108 | + end |
| 109 | + @test Jx_enzyme ≈ Jx_true |
| 110 | + end |
| 111 | + |
81 | 112 | fincorrect(x) = x[:, 1] |
82 | 113 | x = reshape(Float32.(1:6), 2, 3) |> dev |
83 | 114 |
|
|
0 commit comments