Closed
Commits (111 total; changes shown from 104)
a09db2e
refactor move files to algorithms/
Red-Portal Jun 6, 2025
08577cb
refactor move elbo-specific files into paramspacesgd/elbo/
Red-Portal Jun 6, 2025
7c53e63
refactor move elbo-specific exports and interfaces to elbo.jl
Red-Portal Jun 6, 2025
5c71bf4
add `step` for `ParamSpaceSGD`
Red-Portal Jun 10, 2025
f9560e7
increment version
Red-Portal Jun 10, 2025
b887cd2
fix subtyping error on ClipScale
Red-Portal Jun 18, 2025
64848fa
fix signature of `callback` in `ParamSpaceSGD`
Red-Portal Jun 18, 2025
116e22c
fix tests to match new interface
Red-Portal Jun 18, 2025
705748b
refactor restructure `test/` to match new structure in `src/`
Red-Portal Jun 18, 2025
29f9109
run formatter
Red-Portal Jun 18, 2025
db6b694
fix wrong path for test file
Red-Portal Jun 26, 2025
3609bc6
re-organized project
Red-Portal Jun 26, 2025
8751e31
run formatter
Red-Portal Jun 26, 2025
c53048a
fix tests to use update interface
Red-Portal Jun 26, 2025
95cbbda
bump AdvancedVI version in git submodules
Red-Portal Jun 26, 2025
8d728c4
add missing file
Red-Portal Jun 27, 2025
40648a4
update benchmarks to new interface
Red-Portal Jun 27, 2025
afd421e
fix wrong dollar sign usage
Red-Portal Jun 27, 2025
34a64ba
fix wrong interface
Red-Portal Jun 27, 2025
f03f6af
fix to new interface
Red-Portal Jun 27, 2025
2dc2e69
fix missing square in docstring of `ProximaLocationScaleEntropy`
Red-Portal Jun 27, 2025
218819b
fix typo
Red-Portal Jun 27, 2025
504f16e
add docstring for `AbstractAlgorithm`
Red-Portal Jun 27, 2025
557fd3d
move files in docs
Red-Portal Jun 27, 2025
056818e
fix docstring
Red-Portal Jun 27, 2025
0f3329f
fix docstrings
Red-Portal Jun 29, 2025
8670e5d
update documentation
Red-Portal Jun 29, 2025
e7f1885
add note to docstring of `ParamSpaceSGD`
Red-Portal Jun 29, 2025
183fb12
update docs for `RepGradELBO`
Red-Portal Jun 29, 2025
25be8d0
apply formatter
Red-Portal Jun 29, 2025
27f3634
apply formatter
Red-Portal Jun 29, 2025
5be9cca
apply formatter
Red-Portal Jun 29, 2025
5309ec7
add scoregradelbo to the list of objectives
Red-Portal Jun 29, 2025
f0eef5f
move `prob` argument to `optimize` from the constructor of `alg`
Red-Portal Jul 4, 2025
35f2d3a
run formatter
Red-Portal Jul 4, 2025
9e36ee3
fix remove unused import
Red-Portal Jul 4, 2025
a58c921
fix benchmark
Red-Portal Jul 4, 2025
9b5a893
fix docs
Red-Portal Jul 4, 2025
23ae1f8
fix docs
Red-Portal Jul 4, 2025
f0fd86b
add dependencies
Red-Portal Jul 7, 2025
744c9c5
add mixed ad log-density problem wrapper
Red-Portal Jul 7, 2025
1224652
update benchmarks
Red-Portal Jul 7, 2025
c6380cb
add Enzyme extension
Red-Portal Jul 7, 2025
54b1fff
fix type constraints in `RepGradELBO`
Red-Portal Jul 7, 2025
e8a672e
update tests, remove `DistributionsAD`
Red-Portal Jul 7, 2025
cd9d778
fix docs
Red-Portal Jul 7, 2025
5c02cb2
run formatter
Red-Portal Jul 7, 2025
d6db9af
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into mixed_ad
Red-Portal Jul 9, 2025
cbd4ed8
revert docs dependencies, add missing dep
Red-Portal Jul 9, 2025
00f5fc4
add deps to extensions
Red-Portal Jul 9, 2025
8612b1e
add missing deps in docs
Red-Portal Jul 9, 2025
070dc20
fix MooncakeExt
Red-Portal Jul 12, 2025
fd473ea
restructure CI for AD integration tests
Red-Portal Jul 13, 2025
d81ba5e
run formatter
Red-Portal Jul 13, 2025
29a24b2
modify CI
Red-Portal Jul 13, 2025
8f74ebf
fix tests
Red-Portal Jul 13, 2025
acf6c22
fix tests
Red-Portal Jul 13, 2025
050f363
fix name
Red-Portal Jul 13, 2025
a9c7dec
fix Enzyme
Red-Portal Jul 13, 2025
bab5e44
add missing dep for benchmarks
Red-Portal Jul 13, 2025
8837995
fix only optionally load ReverseDiff
Red-Portal Jul 13, 2025
839762d
try fixing Enzyme
Red-Portal Jul 13, 2025
86f0af7
run formatter
Red-Portal Jul 13, 2025
82f9307
restructure move AD integration tests into separate workflow
Red-Portal Jul 13, 2025
7757a9b
fix try fixing AD integration with Mooncake
Red-Portal Jul 13, 2025
75633f6
fix remove unused code
Red-Portal Jul 13, 2025
c56e6ad
change name for source of MixedADLogDensity
Red-Portal Jul 13, 2025
4fecea5
add tests for MixedADLogDensityProblem
Red-Portal Jul 13, 2025
499203c
fix renamed jobs in integration tests
Red-Portal Jul 13, 2025
9608db9
fix
Red-Portal Jul 13, 2025
43a6625
add test for without mixed ad
Red-Portal Jul 13, 2025
8cc6058
fix test for MixedADLogDensityProblem
Red-Portal Jul 13, 2025
f841eed
remove test
Red-Portal Jul 13, 2025
fd6a0ac
update docs
Red-Portal Jul 13, 2025
1accbf9
refactor test
Red-Portal Jul 13, 2025
30c8e15
add missing test
Red-Portal Jul 14, 2025
b17d661
revert interface changes to paramspacesgd
Red-Portal Jul 29, 2025
408f41d
remove dependency on LogDensityProblemsAD
Red-Portal Jul 29, 2025
5f9cac9
revert calls to `ADgradient` in tests
Red-Portal Jul 29, 2025
46b1f91
move Zygote import
Red-Portal Jul 29, 2025
907a8d4
revert remaining changes in tests
Red-Portal Jul 29, 2025
68c35bb
Revert "restructure move AD integration tests into separate workflow"
Red-Portal Jul 29, 2025
d84b273
fix revert renaming of CI.yml
Red-Portal Jul 29, 2025
52dd805
Revert "fix name"
Red-Portal Jul 29, 2025
8b911f4
fix revert Enzyme.yml
Red-Portal Jul 29, 2025
a2177ea
fix add back Enzyme.yml
Red-Portal Jul 29, 2025
019aa47
apply formatter
Red-Portal Jul 29, 2025
3a0b093
apply formatter
Red-Portal Jul 29, 2025
6776d55
revert non-essential changes to tests
Red-Portal Jul 29, 2025
2adf75c
fix remaining changes in tests
Red-Portal Jul 29, 2025
19547ab
fix revert necessary change to paramspacesgd interface
Red-Portal Jul 29, 2025
2e4396d
fix errors in tests
Red-Portal Jul 29, 2025
b8298a3
remove LogDensityProblemsAD dependency in benchmark
Red-Portal Jul 29, 2025
b7d9822
fix remove unused import in benchmark
Red-Portal Jul 29, 2025
27a2b83
run formatter
Red-Portal Jul 29, 2025
679b973
Merge branch 'mixed_ad' of github.com:TuringLang/AdvancedVI.jl into m…
Red-Portal Jul 29, 2025
e73b678
add missing compats
Red-Portal Jul 29, 2025
45408b1
fix revert changes to README
Red-Portal Jul 29, 2025
1fc1c49
revert change of the test order
Red-Portal Jul 30, 2025
7129532
use ReverseDiff for general/optimize.jl tests
Red-Portal Jul 30, 2025
ddb703c
fix Mooncake errors by removing wrong ReverseDiff specialization
Red-Portal Jul 30, 2025
2c7031b
fix remove call to `ADgradient`
Red-Portal Jul 30, 2025
623657c
fix AD test order
Red-Portal Jul 30, 2025
862f593
fix test order in `paramspacesgd/repgradelbo.jl`
Red-Portal Jul 30, 2025
fe1fa84
fix typo in warning message of capability check in repgradelbo
Red-Portal Jul 30, 2025
aaa39f9
fix capability of unconstrdist in benchmark
Red-Portal Jul 30, 2025
6cb9d7e
fix don't run benchmark on Enzyme
Red-Portal Jul 30, 2025
27b9f22
fix docs don't use LogDensityProblemsAD
Red-Portal Jul 30, 2025
d160060
fix order of AD benchmarks
Red-Portal Jul 30, 2025
37aefe3
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into mixed_ad
Red-Portal Jul 30, 2025
3bb03f9
apply formatter
Red-Portal Jul 30, 2025
16 changes: 15 additions & 1 deletion Project.toml
@@ -5,6 +5,7 @@ version = "0.5.0"
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -20,31 +21,44 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [weakdeps]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 
 [extensions]
-AdvancedVIBijectorsExt = "Bijectors"
+AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"]
+AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
+AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
+AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]
 
 [compat]
 ADTypes = "1"
 Accessors = "0.1"
 Bijectors = "0.13, 0.14, 0.15"
+ChainRulesCore = "1"
 DiffResults = "1"
 DifferentiationInterface = "0.6, 0.7"
 Distributions = "0.25.111"
 DocStringExtensions = "0.8, 0.9"
+Enzyme = "0.13"
 FillArrays = "1.3"
 Functors = "0.4, 0.5"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
+Mooncake = "0.4"
 Optimisers = "0.2.16, 0.3, 0.4"
 ProgressMeter = "1.6"
 Random = "1"
+ReverseDiff = "1"
 StatsBase = "0.32, 0.33, 0.34"
 julia = "1.10, 1.11.2"
 
 [extras]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
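A note on the `[weakdeps]`/`[extensions]` changes above: each extension now loads only once its trigger packages are present in the session (`ChainRulesCore` is a hard dependency, so it is not a trigger). A minimal sketch of what activation looks like from the user side; the `get_extension` check is only illustrative:

```julia
using AdvancedVI
using ReverseDiff            # loading the trigger activates AdvancedVIReverseDiffExt
using Bijectors, Optimisers  # together these activate AdvancedVIBijectorsExt

# Julia >= 1.9: confirm that the conditional extension module was loaded.
@assert Base.get_extension(AdvancedVI, :AdvancedVIReverseDiffExt) !== nothing
```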
11 changes: 8 additions & 3 deletions bench/benchmarks.jl
@@ -47,10 +47,15 @@ begin
     ],
     (adname, adtype) in [
         ("Zygote", AutoZygote()),
-        ("ForwardDiff", AutoForwardDiff()),
-        ("ReverseDiff", AutoReverseDiff()),
         ("Mooncake", AutoMooncake(; config=Mooncake.Config())),
-        # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)),
+        ("ReverseDiff", AutoReverseDiff()),
+        (
+            "Enzyme",
+            AutoEnzyme(;
+                mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+                function_annotation=Enzyme.Const,
+            ),
+        ),
     ],
     (familyname, family) in [
         ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))),
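For reference, the previously commented-out Enzyme entry is now enabled with the same configuration, shown here on its own. Reverse mode with runtime activity is used because Enzyme cannot always prove the activity of memory statically, and `function_annotation=Enzyme.Const` tells Enzyme not to differentiate with respect to the objective closure itself:

```julia
using ADTypes, Enzyme

# The Enzyme backend exactly as configured in the benchmark diff above.
adtype = AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const,
)
```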
9 changes: 8 additions & 1 deletion bench/normallognormal.jl
@@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
     return log_density_x + log_density_y
 end
 
+function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
+    return (
+        LogDensityProblems.logdensity(model, θ),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
+    )
+end
+
 function LogDensityProblems.dimension(model::NormalLogNormal)
     return length(model.μ_y) + 1
 end
 
 function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    return LogDensityProblems.LogDensityOrder{0}()
+    return LogDensityProblems.LogDensityOrder{1}()
 end
 
 function Bijectors.bijector(model::NormalLogNormal)
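With the capability bumped from order 0 to order 1, downstream code can ask the benchmark model for gradients directly instead of re-differentiating it. A quick sketch of what the declaration enables, reusing the `NormalLogNormal` model defined above:

```julia
using LogDensityProblems

# LogDensityOrder{1}() advertises that `logdensity_and_gradient` is implemented.
LogDensityProblems.capabilities(typeof(model))  # LogDensityProblems.LogDensityOrder{1}()

# Value and gradient in one call; the gradient comes from the ForwardDiff
# implementation added in the diff above.
θ = randn(LogDensityProblems.dimension(model))
logp, grad = LogDensityProblems.logdensity_and_gradient(model, θ)
```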
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -7,6 +7,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
52 changes: 32 additions & 20 deletions docs/src/examples.md
@@ -51,41 +51,51 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
 nothing
 ```
 
-Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
-Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
+Some of the VI algorithms require gradients of the target log-density.
+In this example, we will use `KLMinRepGradDescent`, which requires first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities).
+For this, we can rely on `LogDensityProblemsAD`:
 
 ```@example elboexample
-using Bijectors
-
-function Bijectors.bijector(model::NormalLogNormal)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
-    return Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:(1 + length(μ_y))],
-    )
-end
+using LogDensityProblemsAD
+using ADTypes, ReverseDiff
 
-b = Bijectors.bijector(model);
-binv = inverse(b)
+model_ad = ADgradient(AutoReverseDiff(), model)
 nothing
 ```
 
-Let's now load `AdvancedVI`.
-Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
-Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
+In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation.
+Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`.
+(This does not need to be the same as the backend used by `model_ad`.)
+The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
+Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.
 
 ```@example elboexample
 using Optimisers
 using ADTypes, ForwardDiff
 using AdvancedVI
+
+alg = KLMinRepGradDescent(AutoReverseDiff());
+nothing
 ```
 
-We now need to select 1. a variational objective, and 2. a variational family.
-Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
+Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support.
+Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
+Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
 
 ```@example elboexample
-alg = KLMinRepGradDescent(AutoForwardDiff())
+using Bijectors
+
+function Bijectors.bijector(model::NormalLogNormal)
+    (; μ_x, σ_x, μ_y, Σ_y) = model
+    return Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:(1 + length(μ_y))],
+    )
+end
+
+b = Bijectors.bijector(model);
+binv = inverse(b)
 nothing
 ```
 
 For the variational family, we will use the classic mean-field Gaussian family.
@@ -109,7 +119,9 @@ Passing `objective` and the initial variational approximation `q` to `optimize`
 
 ```@example elboexample
 n_max_iter = 10^4
-q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model, q0_trans; show_progress=false);
+q_out, info, _ = AdvancedVI.optimize(
+    alg, n_max_iter, model_ad, q0_trans; show_progress=false
+);
 nothing
 ```
 
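As a usage note beyond the diff: once `optimize` returns, `q_out` is the fitted transformed distribution, so samples drawn from it already live in the original, constrained space. A small hedged sketch:

```julia
using Statistics

# Hypothetical follow-up, not part of the documented example: draw samples
# from the fitted approximation and summarize them. Each column is one sample.
samples = rand(q_out, 1_000)
post_mean = mean(samples; dims=2)
```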
11 changes: 6 additions & 5 deletions docs/src/families.md
@@ -135,10 +135,10 @@ using ADTypes
 using AdvancedVI
 using Distributions
 using LinearAlgebra
-using LogDensityProblems
+using LogDensityProblems, LogDensityProblemsAD
 using Optimisers
 using Plots
-using ReverseDiff
+using ForwardDiff, ReverseDiff
 
 struct Target{D}
     dist::D
@@ -163,6 +163,7 @@ D_true = Diagonal(log.(1 .+ exp.(randn(n_dims))))
 Σsqrt_true = sqrt(Σ_true)
 μ_true = randn(n_dims)
 model = Target(MvNormal(μ_true, Σ_true));
+model_ad = ADgradient(AutoForwardDiff(), model)
 
 d = LogDensityProblems.dimension(model);
 μ = zeros(d);
@@ -188,19 +189,19 @@ function callback(; params, averaged_params, restructure, kwargs...)
 end
 
 _, info_fr, _ = AdvancedVI.optimize(
-    alg, max_iter, model, q0_fr;
+    alg, max_iter, model_ad, q0_fr;
     show_progress = false,
     callback = callback,
 );
 
 _, info_mf, _ = AdvancedVI.optimize(
-    alg, max_iter, model, q0_mf;
+    alg, max_iter, model_ad, q0_mf;
     show_progress = false,
     callback = callback,
 );
 
 _, info_lr, _ = AdvancedVI.optimize(
-    alg, max_iter, model, q0_lr;
+    alg, max_iter, model_ad, q0_lr;
     show_progress = false,
     callback = callback,
 );
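The callback body on this docs page is elided by the diff view. For orientation, here is a hedged sketch of a callback compatible with the keyword signature shown above; the metric is illustrative, not the one the page actually computes:

```julia
using LinearAlgebra, Statistics

# Hedged sketch: whatever NamedTuple the callback returns is appended to the
# `info` trace returned by `optimize`.
function callback(; params, averaged_params, restructure, kwargs...)
    q_avg = restructure(averaged_params)  # iterate-averaged approximation
    return (dist_to_mean = norm(mean(q_avg) - μ_true),)
end
```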
38 changes: 24 additions & 14 deletions docs/src/paramspacesgd/repgradelbo.md
@@ -122,12 +122,12 @@ For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a bac
 using Bijectors
 using FillArrays
 using LinearAlgebra
-using LogDensityProblems
+using LogDensityProblems, LogDensityProblemsAD
 using Plots
 using Random
 
 using Optimisers
-using ADTypes, ForwardDiff
+using ADTypes, ForwardDiff, ReverseDiff
 using AdvancedVI
 
 struct NormalLogNormal{MX,SX,MY,SY}
@@ -150,12 +150,13 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
     LogDensityProblems.LogDensityOrder{0}()
 end
 
-n_dims = 10
-μ_x = 2.0
-σ_x = 0.3
-μ_y = Fill(2.0, n_dims)
-σ_y = Fill(1.0, n_dims)
-model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+n_dims = 10
+μ_x = 2.0
+σ_x = 0.3
+μ_y = Fill(2.0, n_dims)
+σ_y = Fill(1.0, n_dims)
+model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+model_ad = ADgradient(AutoForwardDiff(), model)
 
 d = LogDensityProblems.dimension(model);
 μ = zeros(d);
@@ -185,7 +186,7 @@ binv = inverse(b)
 q0_trans = Bijectors.TransformedDistribution(q0, binv)
 
 cfe = KLMinRepGradDescent(
-    AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
+    AutoReverseDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
 )
 nothing
 ```
@@ -194,7 +195,7 @@ The repgradelbo estimator can instead be created as follows:
 
 ```@example repgradelbo
 stl = KLMinRepGradDescent(
-    AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
+    AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
 )
 nothing
 ```
@@ -212,7 +213,7 @@ end
 _, info_cfe, _ = AdvancedVI.optimize(
     cfe,
     max_iter,
-    model,
+    model_ad,
     q0_trans;
     show_progress = false,
     callback = callback,
@@ -221,7 +222,16 @@ _, info_cfe, _ = AdvancedVI.optimize(
 _, info_stl, _ = AdvancedVI.optimize(
     stl,
     max_iter,
-    model,
+    model_ad,
     q0_trans;
     show_progress = false,
     callback = callback,
+);
+
+_, info_stl, _ = AdvancedVI.optimize(
+    stl,
+    max_iter,
+    model_ad,
+    q0_trans;
+    show_progress = false,
+    callback = callback,
@@ -302,9 +312,9 @@ nothing
 
 ```@setup repgradelbo
 _, info_qmc, _ = AdvancedVI.optimize(
-    KLMinRepGradDescent(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
+    KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
     max_iter,
-    model,
+    model_ad,
     q0_trans;
     show_progress = false,
     callback = callback,
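Beyond the diff, the two entropy estimators are typically compared by plotting the recorded traces. A hedged sketch, assuming the page's (elided) callback stores an `elbo` estimate in each `info` entry; the actual field name is not visible in this diff:

```julia
using Plots

# Hedged sketch: compare convergence of the closed-form-entropy and
# sticking-the-landing runs from the calls above.
plot([i.elbo for i in info_cfe]; label="ClosedFormEntropy", xlabel="iteration", ylabel="ELBO")
plot!([i.elbo for i in info_stl]; label="StickingTheLandingEntropy")
```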
13 changes: 13 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,13 @@
+module AdvancedVIEnzymeExt
+
+using AdvancedVI
+using LogDensityProblems
+using Enzyme
+
+Enzyme.@import_rrule(
+    typeof(LogDensityProblems.logdensity),
+    AdvancedVI.MixedADLogDensityProblem,
+    AbstractVector
+)
+
+end
14 changes: 14 additions & 0 deletions ext/AdvancedVIMooncakeExt.jl
@@ -0,0 +1,14 @@
+module AdvancedVIMooncakeExt
+
+using AdvancedVI
+using Base: IEEEFloat
+using LogDensityProblems
+using Mooncake
+
+Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
+    typeof(LogDensityProblems.logdensity),
+    AdvancedVI.MixedADLogDensityProblem,
+    Array{<:IEEEFloat,1},
+}
+
+end
11 changes: 11 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,11 @@
+module AdvancedVIReverseDiffExt
+
+using AdvancedVI
+using LogDensityProblems
+using ReverseDiff
+
+ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
+    prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray
+)
+
+end
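All three extensions above do the same job for different backends: they forward a `ChainRulesCore.rrule` for `LogDensityProblems.logdensity` on the new `MixedADLogDensityProblem` wrapper, so the outer AD backend consumes the wrapper's own gradient instead of tracing into the model. A hedged sketch of the kind of rule presumably defined in `src/mixedad_logdensity.jl`; this is not the package source, and the delegation shown is an assumption:

```julia
using ChainRulesCore, LogDensityProblems

# Hedged sketch: `prob` stands in for AdvancedVI.MixedADLogDensityProblem.
# The pullback reuses the gradient supplied by the wrapped problem's inner AD
# backend rather than differentiating the log-density itself.
function ChainRulesCore.rrule(
    ::typeof(LogDensityProblems.logdensity), prob, x::AbstractVector
)
    logp, grad = LogDensityProblems.logdensity_and_gradient(prob, x)
    logdensity_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ .* grad)
    return logp, logdensity_pullback
end
```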
3 changes: 3 additions & 0 deletions src/AdvancedVI.jl
@@ -18,6 +18,7 @@ using LogDensityProblems
 using ADTypes
 using DiffResults
 using DifferentiationInterface
+using ChainRulesCore
 
 using FillArrays
 
@@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some
 """
 restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)
 
+include("mixedad_logdensity.jl")
+
 # Variational Families
 export MvLocationScale, MeanFieldGaussian, FullRankGaussian
 
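The new `include("mixedad_logdensity.jl")` wires in the wrapper that the backend extensions refer to. A hedged usage sketch; the constructor's exact signature is an assumption for illustration, with `model` being any problem with first-order capability such as the benchmark models above:

```julia
using LogDensityProblems

# Hedged sketch: wrap a problem that supplies its own gradients so the outer
# AD backend picks them up through the forwarded rrule.
# (Constructor signature assumed; see src/mixedad_logdensity.jl for the real one.)
prob_mixed = AdvancedVI.MixedADLogDensityProblem(model)
θ = randn(LogDensityProblems.dimension(model))
logp = LogDensityProblems.logdensity(prob_mixed, θ)
```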
3 changes: 1 addition & 2 deletions src/algorithms/paramspacesgd/abstractobjective.jl
@@ -59,7 +59,7 @@ function estimate_objective end
 export estimate_objective
 
 """
-    estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)
+    estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state)
 
 Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
 
@@ -68,7 +68,6 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
 - `obj::AbstractVariationalObjective`: Variational objective.
 - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
 - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
-- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `params`: Variational parameters to evaluate the gradient on.
 - `restructure`: Function that reconstructs the variational approximation from `params`.
 - `obj_state`: Previous state of the objective.
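For custom objectives, the revised signature means the target is no longer threaded through `estimate_gradient!`; the objective (or its state) must carry it. A hedged sketch of a conforming method, where the placeholder loss and the returned tuple convention are illustrative assumptions rather than the package's code:

```julia
using ADTypes, DiffResults, DifferentiationInterface, Random

# Hypothetical objective type; a real one would carry the target problem and
# whatever sampler state it needs.
struct MyObjective <: AdvancedVI.AbstractVariationalObjective end

function AdvancedVI.estimate_gradient!(
    rng::Random.AbstractRNG, obj::MyObjective, adtype::ADTypes.AbstractADType,
    out::DiffResults.MutableDiffResult, params, restructure, obj_state,
)
    loss(λ) = sum(abs2, λ)  # placeholder; a real objective estimates, e.g., -ELBO
    val, grad = DifferentiationInterface.value_and_gradient(loss, adtype, params)
    DiffResults.value!(out, val)
    copyto!(DiffResults.gradient(out), grad)
    return out, obj_state, (loss=val,)  # return convention assumed for illustration
end
```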