Merged
Changes from 9 commits
Commits
53 commits
3506f4f
improve example in the README
Red-Portal Aug 20, 2025
ba99404
update basic example
Red-Portal Aug 20, 2025
f9ab239
add Stan models example
Red-Portal Aug 20, 2025
e39e0d6
run formatter, add missing dependency
Red-Portal Aug 20, 2025
d0c13f3
add compat for DataFrames
Red-Portal Aug 20, 2025
7a3dbbd
apply formatter to README
Red-Portal Aug 20, 2025
7ef6152
fix missing dependency
Red-Portal Aug 20, 2025
51b495b
add missing `LogDensityProblemsAD` dependency
Red-Portal Aug 20, 2025
3039008
add missing JSON dependency
Red-Portal Aug 20, 2025
70745a3
fix typo in basic tutorials
Red-Portal Aug 21, 2025
f940e2e
fix typo in docs section name
Red-Portal Aug 21, 2025
f4d8fb6
fix Stan tutorial
Red-Portal Aug 21, 2025
b9a7fb3
run formatter
Red-Portal Aug 21, 2025
b7fbc5f
tune the figures in docs
Red-Portal Aug 21, 2025
fb03469
Merge branch 'add_basic_tutorials' of github.com:TuringLang/AdvancedV…
Red-Portal Aug 21, 2025
df84f9a
update doc plot scale for stan example
Red-Portal Aug 21, 2025
e801817
update doc tweak plot in basic example
Red-Portal Aug 21, 2025
8fb3288
update docs tweak plots
Red-Portal Aug 21, 2025
1baf210
add normalizing flows dependency
Red-Portal Aug 22, 2025
49d3ffe
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into add_b…
Red-Portal Aug 22, 2025
39fad54
add normalizing flow example
Red-Portal Aug 22, 2025
7963017
runf formatter
Red-Portal Aug 22, 2025
0d204d3
fix formatting a bit
Red-Portal Aug 22, 2025
5899b9d
fix maybe improve numerical stability
Red-Portal Aug 22, 2025
090acbc
minor fix
Red-Portal Aug 23, 2025
7cadff8
fix typo in flow tutorial
Red-Portal Aug 25, 2025
0455006
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into add_b…
Red-Portal Aug 27, 2025
b88b463
update docs use MacKay's approximation for logistic predictive
Red-Portal Aug 27, 2025
d56d406
fix error in docs basic example
Red-Portal Aug 27, 2025
0ac49da
fix docs prior hyperparameter in basic example
Red-Portal Aug 27, 2025
50cd819
add doubly stochastic VI
Red-Portal Aug 27, 2025
140f3af
run formatter
Red-Portal Aug 27, 2025
36745f6
fix typos
Red-Portal Aug 27, 2025
8198abc
fix formatting
Red-Portal Aug 27, 2025
f479ddb
add missing dependency
Red-Portal Aug 27, 2025
8ca9d54
fix docs more stable initialization for stan example
Red-Portal Aug 27, 2025
c445d0d
fix docs try using julia lts for building docs
Red-Portal Aug 27, 2025
fa2985a
add additional details for bijectors in basic example
Red-Portal Aug 27, 2025
14b75bc
update tweak plot for flow example in docs
Red-Portal Aug 27, 2025
bda40c3
fix tweak plots in docs
Red-Portal Aug 27, 2025
6fb8af4
fix add missing `nothing`s in flow example in docs
Red-Portal Aug 27, 2025
adaefc0
fix docs examples to use a more stable initialization
Red-Portal Aug 28, 2025
c8de52e
update docs example for subsampling
Red-Portal Aug 28, 2025
8e873cd
fix use of flows in docs example
Red-Portal Aug 28, 2025
c50086e
run formatter
Red-Portal Aug 28, 2025
cf829f2
add labels to plotlines in subsampling docs
Red-Portal Aug 28, 2025
c6cbc69
fix tweak plot in basic example docs
Red-Portal Aug 28, 2025
22953fe
update docs example on subsampling
Red-Portal Aug 28, 2025
b151635
fix docs example flow realnvp interface
Red-Portal Aug 28, 2025
8b7b545
fix flows config in flow doc example
Red-Portal Aug 28, 2025
f9b87d1
add missing imports for `mean` in docs example
Red-Portal Aug 28, 2025
3494495
update docs example on subsampling
Red-Portal Aug 28, 2025
a6bee5c
fix init in examples to be principled s.t. P(-2 < x < 2) > 0.998
Red-Portal Aug 28, 2025
67 changes: 38 additions & 29 deletions README.md
@@ -19,7 +19,7 @@ For a dataset $(X, y)$ with the design matrix $X \in \mathbb{R}^{n \times d}$ an

$$
\begin{aligned}
\sigma &\sim \text{Student-t}_{3}(0, 1) \\
\sigma &\sim \text{LogNormal}(0, 3) \\
\beta &\sim \text{Normal}\left(0_d, \sigma \mathrm{I}_d\right) \\
y &\sim \mathrm{BernoulliLogit}\left(X \beta\right)
\end{aligned}
$$

@@ -28,7 +28,7 @@
The `LogDensityProblem` corresponding to this model can be constructed as

```julia
import LogDensityProblems
using LogDensityProblems: LogDensityProblems
using Distributions
using FillArrays

@@ -43,10 +43,10 @@ function LogDensityProblems.logdensity(model::LogReg, θ)
β, σ = θ[1:size(X, 2)], θ[end]

logprior_β = logpdf(MvNormal(Zeros(d), σ*I), β)
logprior_σ = logpdf(truncated(TDist(3.0); lower=0), σ)
logprior_σ = logpdf(LogNormal(0, 3), σ)

logit = X*β
loglike_y = sum(@. logpdf(BernoulliLogit(logit), y))
loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
return loglike_y + logprior_β + logprior_σ
end

@@ -56,23 +56,23 @@ end

function LogDensityProblems.capabilities(::Type{<:LogReg})
return LogDensityProblems.LogDensityOrder{0}()
end
end;
```

Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose.
This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].
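
In equation form, the role of the bijector is the usual change of variables (generic notation, not `Bijectors` internals): if $q$ lives on the unconstrained space and $b$ maps the constrained support onto it, the transformed density satisfies

$$
\log q_{\mathrm{trans}}(\theta) = \log q\big(b(\theta)\big) + \log \left|\det J_{b}(\theta)\right|,
$$

and it is this Jacobian term that is applied automatically when a `Bijectors.TransformedDistribution` is passed to `optimize`.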

```julia
import Bijectors
using Bijectors: Bijectors

function Bijectors.bijector(model::LogReg)
d = size(model.X, 2)
return Bijectors.Stacked(
Bijectors.bijector.([MvNormal(Zeros(d), 1.0), truncated(TDist(3.0); lower=0)]),
Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]),
[1:d, (d + 1):(d + 1)],
)
end
end;
```

A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
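
As a rough sketch of that route (hypothetical code, not part of this PR — it assumes Turing's `@model` macro and the Distributions names it re-exports):

```julia
using Turing
using FillArrays: Zeros
using LinearAlgebra: I, dot

# Hypothetical Turing version of the same logistic regression model;
# this PR instead defines it manually through LogDensityProblems.
@model function logreg(X, y)
    d = size(X, 2)
    σ ~ LogNormal(0, 3)              # prior on the scale
    β ~ MvNormal(Zeros(d), σ * I)    # prior on the coefficients
    for i in eachindex(y)
        y[i] ~ BernoulliLogit(dot(view(X, i, :), β))
    end
end
```

A model built this way would then stand in for the hand-written `LogReg` and `Bijectors.bijector` definitions above.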
@@ -82,63 +82,72 @@ This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI
The sonar dataset corresponds to the dataset id 40.

```julia
import OpenML
import DataFrames
using OpenML: OpenML
using DataFrames: DataFrames
data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine")
y = Vector{Bool}(data[:, end] .== "Mine");
```

Let's apply some basic pre-processing and add an intercept column:

```julia
using Statistics: mean, std

X = (X .- mean(X; dims=1)) ./ std(X; dims=1)
X = hcat(X, ones(size(X, 1)))
X = hcat(X, ones(size(X, 1)));
```

The model can now be instantiated as follows:

```julia
model = LogReg(X, y)
model = LogReg(X, y);
```

For the VI algorithm, we will use the following:
For the VI algorithm, we will use `KLMinRepGradDescent`:

```julia
using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())
alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff());
```

This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation with the reparametrization gradient[^TL2014][^RMW2014][^KW2014].
This approach is also commonly referred to as automatic differentiation VI, black-box VI, or stochastic gradient VI.
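
In generic notation (a textbook summary, not the package's exact estimator), the reparametrization gradient rewrites the ELBO gradient as an expectation over a fixed base distribution:

$$
\nabla_{\lambda}\, \mathbb{E}_{q_{\lambda}}\big[\log \pi(\theta) - \log q_{\lambda}(\theta)\big]
= \mathbb{E}_{\epsilon \sim \mathcal{N}(0,\mathrm{I})}\Big[\nabla_{\lambda}\big(\log \pi(T_{\lambda}(\epsilon)) - \log q_{\lambda}(T_{\lambda}(\epsilon))\big)\Big],
$$

where, for a Gaussian family, $T_{\lambda}(\epsilon) = \mu + L\epsilon$ with $\lambda = (\mu, L)$. A Monte Carlo estimate of the right-hand side is differentiable in $\lambda$ and can be fed directly to SGD.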

This `KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` has gradients.
For this, it is straightforward to use `LogDensityProblemsAD`:
`KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` is differentiable.
If the `LogDensityProblem` has a differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities) of at least first-order, we can take advantage of this.
For this example, we will use `LogDensityProblemsAD` to equip our problem with a first-order capability:

```julia
import DifferentiationInterface
import LogDensityProblemsAD
using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)
model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model);
```

For the variational family, we will consider a `MeanFieldGaussian` approximation:

```julia
using LinearAlgebra

d = LogDensityProblems.dimension(model_ad)
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
```

The bijector can now be applied to `q` to match the support of the target problem.

```julia
b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)
q_transformed = Bijectors.TransformedDistribution(q, binv);
```

We can now run VI:

```julia
max_iter = 10^3
q_avg, info, _ = AdvancedVI.optimize(
alg,
max_iter,
model_ad,
q_transformed;
)
q, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed);
```

For more examples and details, please refer to the documentation.
12 changes: 12 additions & 0 deletions docs/Project.toml
@@ -2,29 +2,41 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
OpenML = "8b6db2d4-7670-4922-a472-f9537c81ab66"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StanLogDensityProblems = "a545de4d-8dba-46db-9d34-4e41d3f07807"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "1"
AdvancedVI = "0.5, 0.4"
Bijectors = "0.13.6, 0.14, 0.15"
DataFrames = "1"
DifferentiationInterface = "0.7"
Distributions = "0.25"
Documenter = "1"
FillArrays = "1"
ForwardDiff = "0.10, 1"
JSON = "0.21"
LogDensityProblems = "2.1.1"
LogDensityProblemsAD = "1"
OpenML = "0.3"
Optimisers = "0.3, 0.4"
Plots = "1"
QuasiMonteCarlo = "0.3"
ReverseDiff = "1"
StanLogDensityProblems = "0.1"
StatsFuns = "1"
julia = "1.10, 1.11.2"
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -17,7 +17,8 @@ makedocs(;
"AdvancedVI" => "index.md",
"General Usage" => "general.md",
"Tutorials" => [
"tutorials/basic.md",
"Basic Example" => "tutorials/basic.md",
"Ussage with Stan" => "tutorials/stan.md",
],
"Algorithms" => [
"KLMinRepGradDescent" => "paramspacesgd/klminrepgraddescent.md",
3 changes: 1 addition & 2 deletions docs/src/paramspacesgd/repgradelbo.md
@@ -82,7 +82,7 @@ q_transformed = Bijectors.TransformedDistribution(q, binv)
```

By passing `q_transformed` to `optimize`, the Jacobian adjustment for the bijector `b` is automatically applied.
(See [Examples](@ref examples) for a fully working example.)
(See the [Basic Example](@ref basic) for a fully working example.)

[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*.
[^DLTBV2017]: Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv.
@@ -177,7 +177,6 @@ function Bijectors.bijector(model::NormalLogNormal)
end
```

Let us come back to the example in [Examples](@ref examples), where a `LogDensityProblem` is given as `model`.
In this example, the true posterior is contained within the variational family.
This setting is known as "perfect variational family specification."
In this case, the `RepGradELBO` estimator with `StickingTheLandingEntropy` is the only estimator known to converge exponentially fast ("linear convergence") to the true solution.