diff --git a/docs/src/api.md b/docs/src/api.md index 885d587ea..93a486a9e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -122,9 +122,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au | Exported symbol | Documentation | Description | |:----------------- |:------------------------------------ |:---------------------- | +| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | -| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | +| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index 58a58eb2a..d931fb592 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -23,7 +23,7 @@ using Printf: Printf using Random: Random using LinearAlgebra: I -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff() @@ -124,6 +124,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoMooncake, + AutoEnzyme, # Debugging - Turing setprogress!, # Distributions diff --git a/test/ad.jl b/test/ad.jl index 9524199dc..9f80fbd6a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -10,21 +10,28 @@ using Test using ..Models: gdemo_default import ForwardDiff, ReverseDiff -# Detect if prerelease version, if so, we skip some tests -const IS_PRERELEASE = !isempty(VERSION.prerelease) -const INCLUDE_MOONCAKE = !IS_PRERELEASE +# Detect if 1.12, if so, we skip some tests +const IS_112 = VERSION >= v"1.12.0" +const INCLUDE_MOONCAKE = !IS_112 if INCLUDE_MOONCAKE import Pkg Pkg.add("Mooncake") using Mooncake: Mooncake end +const INCLUDE_ENZYME = !IS_112 +if INCLUDE_ENZYME + import Pkg + Pkg.add("Enzyme") + using Enzyme: Enzyme +end + """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) """A dictionary mapping ADTypes to the element types they use.""" -eltypes_by_adtype = Dict( +eltypes_by_adtype = Dict{Type,Tuple}( AutoForwardDiff => (ForwardDiff.Dual,), AutoReverseDiff => ( ReverseDiff.TrackedArray, @@ -39,6 +46,9 @@ eltypes_by_adtype = Dict( if INCLUDE_MOONCAKE eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,) end +if INCLUDE_ENZYME + eltypes_by_adtype[AutoEnzyme] = () +end """ AbstractWrongADBackendError @@ -183,6 +193,22 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)] if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end +if INCLUDE_ENZYME + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation=Enzyme.Const, + ), + ) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +end # Check that ADTypeCheckContext itself works as expected. @testset "ADTypeCheckContext" begin