Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4bf2fce
remove the bijectors extension and relevant tests, benchmarks
Red-Portal Nov 21, 2025
9dc43bb
fix remove use of Bijectors in tests
Red-Portal Nov 21, 2025
9479ed4
fix typo in HISTORY
Red-Portal Nov 21, 2025
0142a82
bump AdvancedVI version
Red-Portal Nov 21, 2025
bb47ced
run formatter
Red-Portal Nov 21, 2025
6fa2bc3
run formatter
Red-Portal Nov 21, 2025
d6157f1
Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j…
Red-Portal Nov 21, 2025
e201345
fix removed include to removed file
Red-Portal Nov 21, 2025
66a22a5
fix remove erroenous references to bijector
Red-Portal Nov 21, 2025
f5f7fcb
fix missing comma
Red-Portal Nov 21, 2025
7a6e902
fix remove include to removed file
Red-Portal Nov 21, 2025
2a1755a
update READMe
Red-Portal Nov 21, 2025
f23eb7a
fix move benchmark model to main file
Red-Portal Nov 21, 2025
4d8d95e
update docs to the new recommended use of Bijectors
Red-Portal Nov 22, 2025
fba3d9f
run formatter
Red-Portal Nov 22, 2025
4daaa79
run formatter
Red-Portal Nov 22, 2025
a0f13d5
run formatter
Red-Portal Nov 22, 2025
9a64db7
run formatter
Red-Portal Nov 22, 2025
048a310
add constraint tutorial
Red-Portal Nov 22, 2025
69ae57a
fix missing import in docs
Red-Portal Nov 22, 2025
af4ad18
run furmatter to constrained
Red-Portal Nov 22, 2025
0386029
Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j…
Red-Portal Nov 22, 2025
f9d7f0b
fix typo
Red-Portal Nov 22, 2025
d6dede2
fix bins in normalizing flow tutorial
Red-Portal Nov 22, 2025
c7e8c44
fix formatting
Red-Portal Nov 22, 2025
373abdd
run formatter
Red-Portal Nov 22, 2025
6a5c7ba
revert changes to documentation
Red-Portal Nov 22, 2025
c05e2f5
Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j…
Red-Portal Nov 22, 2025
8c657e0
revert changes to the README
Red-Portal Nov 22, 2025
92f5077
Revert "revert changes to the README"
Red-Portal Nov 25, 2025
f81398a
Revert "revert changes to documentation"
Red-Portal Nov 25, 2025
6bf17b3
run formatter
Red-Portal Nov 25, 2025
8c729c5
fix use more sophisticated `TransformedLogDensityProblem`
Red-Portal Nov 25, 2025
db15461
fix wrong type of markdown env
Red-Portal Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Release 0.7

## Removal of special treatment to `Bijectors.TransformedDistribution`

Previously, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, `KLMinScoreGradDescent` only required the support of the target log-density problem to match that of `q`.
This was implemented by giving a special treatment to `q <: Bijectors.TransformedDistribution` through the `Bijectors` extension.
This, however, multiplied the maintenance complexity of the relevant code.
Since this is not the only way to deal with constrained supports, the `Bijectors` extension has now been removed.
In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, `KLMinScoreGradDescent` now expect an unconstrained target log-density problem.
Instead, a tutorial has been added to the documentation on how to deal with a target log-density problem with constrained support.

# Release 0.6

## New Algorithms
Expand Down
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.6"
version = "0.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -20,21 +20,18 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"]
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]

[compat]
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13, 0.14, 0.15"
ChainRulesCore = "1"
DiffResults = "1"
DifferentiationInterface = "0.6, 0.7"
Expand All @@ -54,7 +51,6 @@ StatsBase = "0.32, 0.33, 0.34"
julia = "1.10, 1.11.2"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
77 changes: 57 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
[![Tests](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)

| AD Backend | Integration Status |
| ------------- | ------------- |
| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |
| AD Backend | Integration Status |
|:---------------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |

# AdvancedVI.jl

Expand Down Expand Up @@ -69,7 +69,7 @@ end;

Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose.
This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].
The bijector corresponding to the joint support of our model can be constructed as follows:

```julia
using Bijectors: Bijectors
Expand All @@ -85,6 +85,41 @@ end;

A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.

Since most VI algorithms assume that the posterior is unconstrained, we will apply a change-of-variable to our model to make it unconstrained.
This amounts to wrapping it into a `LogDensityProblem` that applies the transformation and the corresponding Jacobian adjustment.

```julia
# Wraps a target log-density problem together with the inverse bijector that
# maps unconstrained (transformed) values back onto the problem's support.
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob # original (constrained) log-density problem
    binv::BInv # inverse bijector: unconstrained space -> original support
end
Comment on lines +91 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this code is always the same, it seems like it might be worth putting in a library somewhere... although AdvancedVI isn't the right place for it... maybe Bijectors?

Copy link
Member

@penelopeysm penelopeysm Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, you might already be aware, but there is an infamously pathological DynamicPPL edge case with something like this:

@model function f()
    x ~ Normal()
    y ~ truncated(Normal(); lower = x)
end

(of course you don't need DynamicPPL to make a logdensityfunction like that, you could construct it with plain logpdf). The issue is that the bijector cannot be statically constructed, because the bijector for y will depend on the value of x. It's still a 1 to 1 function since the inputs fully determine the bijector which fully determines the output. But I think it might not fit nicely into the structure above. Probably you have to do the hardcoded logjac calculation way.

Copy link
Member Author

@Red-Portal Red-Portal Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this code is always the same, it seems like it might be worth putting in a library somewhere... although AdvancedVI isn't the right place for it... maybe Bijectors?

I think TransformedLogDensities.jl was supposed to serve this purpose, but not sure why it never decided to work Bijectors.

BTW, you might already be aware, but there is an infamously pathological DynamicPPL edge case with something like this:

How are the MCMC algorithms handling this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turing doesn't have a single bijector per model, instead there's one bijector per variable, which is constructed at runtime depending on the distribution on the rhs of tilde.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then I guess this is going to be a known issue for AdvancedVI.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's an unsolvable problem though. I feel like if anything it should be fairly easy to support (don't hardcode a single bijector but rather use LogDensityFunction with linking).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not in this PR of course)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(don't hardcode a single bijector but rather use LogDensityFunction with linking).

That's what I am planning to do (in fact already implemented in the open PR downstream) for the Turing interface. But the problem is that the output q needs to be wrapped with a Bijectors.TransformedDistribution. I don't see any obvious way around this.


# Build a transformed problem from `prob`: derive the problem's bijector and
# cache its inverse, which is what the log-density evaluation needs.
function TransformedLogDensityProblem(prob)
    binv = Bijectors.inverse(Bijectors.bijector(prob))
    return TransformedLogDensityProblem{typeof(prob),typeof(binv)}(prob, binv)
end

# Log-density in the unconstrained space: pull the point back to the original
# support through the cached inverse bijector and add the log-Jacobian term.
function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(prob_trans.binv, θ_trans)
    return LogDensityProblems.logdensity(prob_trans.prob, θ) + logabsdetjac
end

# Dimension of the unconstrained representation: the output size of the
# forward bijector applied to the original problem's dimension.
function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
    forward = Bijectors.inverse(prob_trans.binv)
    dim = LogDensityProblems.dimension(prob_trans.prob)
    return prod(Bijectors.output_size(forward, (dim,)))
end

# The transformed problem advertises the same LogDensityProblems capabilities
# (derivative order) as the wrapped problem type `Prob`.
function LogDensityProblems.capabilities(
    ::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
    return LogDensityProblems.capabilities(Prob)
end;
```

For the dataset, we will use the popular [sonar classification dataset](https://archive.ics.uci.edu/dataset/151/connectionist+bench+sonar+mines+vs+rocks) from the UCI repository.
This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI/OpenML.jl).
The sonar dataset corresponds to the dataset id 40.
Expand All @@ -109,7 +144,8 @@ X = hcat(X, ones(size(X, 1)));
The model can now be instantiated as follows:

```julia
model = LogReg(X, y);
prob = LogReg(X, y);
prob_trans = TransformedLogDensityProblem(prob)
```

For the VI algorithm, we will use `KLMinRepGradDescent`:
Expand All @@ -136,37 +172,38 @@ For this, it is straightforward to use `LogDensityProblemsAD`:
using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model);
prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
```

For the variational family, we will consider a `FullRankGaussian` approximation:

```julia
using LinearAlgebra

d = LogDensityProblems.dimension(model_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d)))
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
```

The bijector can now be applied to `q` to match the support of the target problem.
We can now run VI:

```julia
b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv);
max_iter = 10^3
q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
```

We can now run VI:
Recall that we applied a change-of-variable to the posterior to make it unconstrained.
This, however, is not the original constrained posterior that we wanted to approximate.
Therefore, we finally need to apply a change-of-variable to `q_opt` to make it approximate our original problem.

```julia
max_iter = 10^3
q, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed;);
b = Bijectors.bijector(prob)
binv = Bijectors.inverse(b)
q_trans = Bijectors.TransformedDistribution(q_opt, binv)
```

For more examples and details, please refer to the documentation.

[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.
[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.
[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*.
[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of machine learning research*.
4 changes: 1 addition & 3 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -20,9 +19,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
AdvancedVI = "0.6"
AdvancedVI = "0.7"
BenchmarkTools = "1"
Bijectors = "0.13, 0.14, 0.15"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
Expand Down
40 changes: 30 additions & 10 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ADTypes
using AdvancedVI
using BenchmarkTools
using Bijectors
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
Expand All @@ -17,8 +16,34 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

include("normallognormal.jl")
include("unconstrdist.jl")
struct Dist{D<:ContinuousMultivariateDistribution}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The content of unconstrdist.jl have been moved here.

dist::D
end

# The log-density of a `Dist` is simply the log-pdf of the wrapped distribution.
LogDensityProblems.logdensity(model::Dist, x) = logpdf(model.dist, x)

# Return the log-density together with its gradient, computed via ForwardDiff.
function LogDensityProblems.logdensity_and_gradient(model::Dist, θ)
    logdens = LogDensityProblems.logdensity(model, θ)
    grad = ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ)
    return (logdens, grad)
end

# The problem dimension equals the length of the wrapped distribution.
LogDensityProblems.dimension(model::Dist) = length(model.dist)

# `Dist` supports plain log-density evaluation only (derivative order 0).
LogDensityProblems.capabilities(::Type{<:Dist}) = LogDensityProblems.LogDensityOrder{0}()

# Benchmark target: an isotropic multivariate normal with mean 5 in every
# coordinate and unit variances, wrapped as a `Dist` log-density problem.
function normal(; n_dims=10, realtype=Float64)
    mean_vec = fill(realtype(5), n_dims)
    cov_mat = Diagonal(ones(realtype, n_dims))
    return Dist(MvNormal(mean_vec, cov_mat))
end

const SUITES = BenchmarkGroup()

Expand All @@ -33,10 +58,7 @@ end
begin
T = Float64

for (probname, prob) in [
("normal + bijector", normallognormal(; n_dims=10, realtype=T))
("normal", normal(; n_dims=10, realtype=T))
]
for (probname, prob) in [("normal", normal(; n_dims=10, realtype=T))]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
opt = Optimisers.Adam(T(1e-3))
Expand All @@ -59,9 +81,7 @@ begin
),
]

b = Bijectors.bijector(prob)
binv = inverse(b)
q = Bijectors.TransformedDistribution(family, binv)
q = family
alg = KLMinRepGradDescent(adtype; optimizer=opt, entropy, operator=ClipScale())

SUITES[probname][objname][familyname][adname] = begin
Expand Down
44 changes: 0 additions & 44 deletions bench/normallognormal.jl

This file was deleted.

33 changes: 0 additions & 33 deletions bench/unconstrdist.jl

This file was deleted.

2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
[compat]
ADTypes = "1"
Accessors = "0.1"
AdvancedVI = "0.6"
AdvancedVI = "0.7"
Bijectors = "0.13.6, 0.14, 0.15"
DataFrames = "1"
DifferentiationInterface = "0.7"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ makedocs(;
"Scaling to Large Datasets" => "tutorials/subsampling.md",
"Stan Models" => "tutorials/stan.md",
"Normalizing Flows" => "tutorials/flows.md",
"Dealing with Constrained Posteriors" => "tutorials/constrained.md",
],
"Algorithms" => [
"`KLMinRepGradDescent`" => "klminrepgraddescent.md",
Expand Down
Loading
Loading