2 changes: 1 addition & 1 deletion Project.toml
@@ -110,7 +110,7 @@ Optimisers = "0.4.6"
PrecompileTools = "1.2.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.170"
Reactant = "0.2.171"
ReactantCore = "0.1.16"
Reexport = "1.2.2"
ReverseDiff = "1.15"
7 changes: 0 additions & 7 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -13,13 +13,6 @@ using MLDataDevices: isleaf

Lux.is_extension_loaded(::Val{:Enzyme}) = true

normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Reverse)

annotate_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
annotate_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)

struct OOPFunctionWrapper{F}
f::F
end
10 changes: 6 additions & 4 deletions ext/LuxEnzymeExt/autodiff.jl
@@ -1,13 +1,13 @@
# VJPs

function _vector_jacobian_product_impl(f::F, ad::AutoEnzyme, x, v, extra_args...) where {F}
ad = normalize_backend(False(), ad)
ad = Utils.normalize_autoenzyme_mode(Reverse, ad)
@assert ADTypes.mode(ad) isa ADTypes.ReverseMode "VJPs are only supported in reverse \
mode"
dx = fmap(zero, x; exclude=isleaf)
Enzyme.autodiff(
ad.mode,
annotate_function(ad, OOPFunctionWrapper(f)),
Utils.annotate_enzyme_function(ad, OOPFunctionWrapper(f)),
Duplicated(fmap(similar, v; exclude=isleaf), fmap(copy, v; exclude=isleaf)),
Duplicated(x, dx),
extra_args...,
@@ -30,11 +30,13 @@ end
# JVPs

function _jacobian_vector_product_impl(f::F, ad::AutoEnzyme, x, u, extra_args...) where {F}
ad = normalize_backend(True(), ad)
ad = Utils.normalize_autoenzyme_mode(Forward, ad)
@assert ADTypes.mode(ad) isa ADTypes.ForwardMode "JVPs are only supported in forward \
mode"
return only(
Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), extra_args...)
Enzyme.autodiff(
ad.mode, Utils.annotate_enzyme_function(ad, f), Duplicated(x, u), extra_args...
),
)
end

16 changes: 6 additions & 10 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -1,19 +1,14 @@
module LuxReactantExt

using ADTypes: ADTypes, AutoEnzyme
using Enzyme: Enzyme, Const
using EnzymeCore: EnzymeCore
using LinearAlgebra: LinearAlgebra
using Preferences: load_preference
using Optimisers: Optimisers
using Reactant:
Reactant,
Profiler,
@compile,
@code_hlo,
@jit,
@opcall,
AnyTracedRArray,
TracedRArray,
TracedRNumber,
PrecisionConfig
Reactant, Profiler, AnyTracedRArray, TracedRArray, TracedRNumber, PrecisionConfig
using Reactant: @compile, @code_hlo, @jit, @opcall
using ReactantCore: ReactantCore, @trace
using Setfield: @set!
using Static: True, False
@@ -74,5 +69,6 @@ include("training.jl")
include("layers.jl")
include("tracing.jl")
include("saved_model.jl")
include("batched_jacobian.jl")

end
94 changes: 94 additions & 0 deletions ext/LuxReactantExt/batched_jacobian.jl
@@ -0,0 +1,94 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
f::F, ad::Lux.Training.ReactantBackend, x
) where {F}
ad = Utils.normalize_autoenzyme_mode(EnzymeCore.Forward, ad.ad)
if ADTypes.mode(ad) isa ADTypes.ReverseMode
return _batched_jacobian_reverse_impl(f, ad, x)
else
return _batched_jacobian_forward_impl(f, ad, x)
end
end

struct ApplyWithReshape{F,SZ}
f::F
sz::SZ
end

(f::ApplyWithReshape)(x) = reshape(f.f(reshape(x, f.sz)), :, size(x, ndims(x)))

function (f::ApplyWithReshape)(y, x)
res = f.f(reshape(x, f.sz))
copyto!(y, reshape(res, size(y)))
return nothing
end

function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
y = f(x)
@assert y isa AbstractArray
if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end

f′ = ApplyWithReshape(f, size(x))

y = Utils.contiguous(reshape(y, :, size(y, ndims(y))))
dy = repeat(
reshape(
Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(y),2}, LinearAlgebra.I(size(y, 1))
),
size(y, 1),
1,
size(y, 1),
),
1,
size(y, 2),
1,
)
dy = Utils.contiguous(dy)

x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
dx = similar(x, size(x, 1), size(x, 2), size(y, 1))
fill!(dx, false)

Enzyme.autodiff(
ad.mode,
Utils.annotate_enzyme_function(ad, f′),
Reactant.StackedBatchDuplicated(y, dy),
Reactant.StackedBatchDuplicated(x, dx),
)

return permutedims(dx, (3, 1, 2))
end

function _batched_jacobian_forward_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
f′ = ApplyWithReshape(f, size(x))
x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))

bx = repeat(
reshape(
Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(x),2}, LinearAlgebra.I(size(x, 1))
),
size(x, 1),
1,
size(x, 1),
),
1,
size(x, 2),
1,
)
bx = Utils.contiguous(bx)

return stack(
only(
Enzyme.autodiff(
ad.mode,
Utils.annotate_enzyme_function(ad, f′),
Reactant.StackedBatchDuplicated(x, bx),
),
);
dims=2,
)
end
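
A minimal usage sketch of the new code path above (not part of the diff; it assumes the usual `reactant_device` + `@jit` pattern from the Reactant test suite, the `fn` used in the tests at the end of this PR, and that `AutoEnzyme()` with no mode set falls through to the forward-mode branch):

using Lux, Reactant, Random
using ADTypes: AutoEnzyme

fn(x) = reshape(sum(abs2, x; dims=(1, 2, 3)), 1, :)   # output is 1 × batch

dev = reactant_device(; force=true)
x_ra = dev(rand(Random.default_rng(), Float32, 2, 3, 4, 5))

# Dispatches through `ReactantBackend` (see `maybe_wrap_adtype` below) and then into
# `_batched_jacobian_forward_impl` above; the result should be laid out as
# (output features, input features, batch), i.e. (1, 2*3*4, 5) here.
J = Reactant.@jit Lux.batched_jacobian(fn, AutoEnzyme(), x_ra)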
16 changes: 7 additions & 9 deletions src/autodiff/api.jl
@@ -99,6 +99,7 @@ the following properties for `y = f(x)`:

| Supported Backends | Packages Needed |
|:------------------ |:--------------- |
| `AutoEnzyme` | `Reactant.jl` |
| `AutoForwardDiff` | |
| `AutoZygote` | `Zygote.jl` |

@@ -126,16 +127,13 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
end

function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end

function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
if !is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
`$(backend)`.")
for implemented_backend in (:AutoForwardDiff, :AutoZygote, :AutoEnzyme)
@eval function batched_jacobian(
f::F, backend::$implemented_backend, x::AbstractArray
) where {F}
assert_backend_loaded(:batched_jacobian, backend)
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end

# Utils
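For reference, the `@eval` loop above stamps out one method per backend; for `AutoEnzyme` it expands to roughly the following (sketch only; `assert_backend_loaded` lives in the collapsed `# Utils` section and presumably raises the extension-not-loaded error previously inlined in the `AutoZygote` method):

function batched_jacobian(f::F, backend::AutoEnzyme, x::AbstractArray) where {F}
    assert_backend_loaded(:batched_jacobian, backend)
    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end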
4 changes: 3 additions & 1 deletion src/autodiff/batched_autodiff.jl
@@ -90,7 +90,9 @@ end
function batched_jacobian_internal(
f::F, backend::AbstractADType, x::AbstractArray
) where {F}
return batched_jacobian_impl(f, backend, x)
return batched_jacobian_impl(
f, Lux.Training.maybe_wrap_adtype(backend, get_device_type(x)), x
)
end

# ForwardDiff.jl Implementation
3 changes: 2 additions & 1 deletion src/helpers/training.jl
@@ -165,6 +165,7 @@ end
@concrete struct ReactantBackend
return_gradients <: StaticBool
sync::Bool
ad <: AutoEnzyme
end

const APPLY_GRAD_DOCSTRING = """
@@ -285,7 +286,7 @@ function maybe_wrap_adtype(
return_gradients::Utils.BoolType=True(),
sync::Bool=false,
)
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync)
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync, ad)
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end
8 changes: 8 additions & 0 deletions src/utils.jl
@@ -1,5 +1,6 @@
module Utils

using ADTypes: ADTypes, AutoEnzyme
using ArrayInterface: ArrayInterface
using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore, @non_differentiable, NoTangent
@@ -8,6 +9,7 @@ using EnzymeCore: EnzymeRules
using ForwardDiff: Dual
using Functors: Functors, fmapstructure
using Random: AbstractRNG
using Setfield: @set
using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

@@ -237,6 +239,12 @@ calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4

recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=MLDataDevices.isleaf)

normalize_autoenzyme_mode(mode, ad::AutoEnzyme) = ad
normalize_autoenzyme_mode(mode, ad::AutoEnzyme{Nothing}) = @set(ad.mode = mode)

annotate_enzyme_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
annotate_enzyme_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)

end

using .Utils:
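A brief illustration of the two helpers moved into `Utils` above (internal API; behavior is read directly off the definitions, and the `function_annotation` keyword of `AutoEnzyme` is assumed from ADTypes):

using ADTypes: AutoEnzyme
using EnzymeCore: Forward, Reverse, Const

# A backend without an explicit mode picks up the requested default; an explicit mode wins.
Lux.Utils.normalize_autoenzyme_mode(Forward, AutoEnzyme())               # mode becomes Forward
Lux.Utils.normalize_autoenzyme_mode(Forward, AutoEnzyme(; mode=Reverse)) # stays Reverse

# With a function annotation, the callable is wrapped before being handed to Enzyme.
Lux.Utils.annotate_enzyme_function(AutoEnzyme(; function_annotation=Const), sum) # Const(sum)
Lux.Utils.annotate_enzyme_function(AutoEnzyme(), sum)                            # sum, unchanged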
28 changes: 28 additions & 0 deletions test/reactant/autodiff_tests.jl
@@ -55,3 +55,31 @@
end
end
end

@testitem "AutoDiff APIs: Batched Jacobian" tags = [:reactant] setup = [SharedTestSetup] begin
using Reactant, Lux, Zygote, Random, ForwardDiff, Enzyme

fn(x) = reshape(sum(abs2, x; dims=(1, 2, 3)), 1, :)

rng = Random.default_rng()

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

dev = reactant_device(; force=true)

x = rand(rng, Float32, 2, 3, 4, 5)
x_ra = dev(x)

# TODO: ....
end
end