diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 5e5f11766..44d26be44 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -52,7 +52,7 @@ begin ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), #("Mooncake", AutoMooncake(; config=Mooncake.Config())), - #("Enzyme", AutoEnzyme()), + #("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), ], (familyname, family) in [ ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index c0e3f6db9..0b9782fa5 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,6 +1,10 @@ - AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) else Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 6467681f4..e0e7476b5 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,6 +1,10 @@ - AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) else Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 7adba6398..24edf5678 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,6 +1,10 @@ - AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) else Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index cea10b621..e58d479f6 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -1,6 +1,10 @@ - AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) else Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index ed88ca086..734a7dbe7 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -1,6 +1,10 @@ - AD_scoregradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) else Dict( :ForwarDiff => AutoForwardDiff(),