Commit 8c657e0
revert changes to the README
1 parent c05e2f5

1 file changed: +14 −54 lines

README.md

Lines changed: 14 additions & 54 deletions
@@ -3,13 +3,13 @@
 [![Tests](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)
 [![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)
 
-| AD Backend | Integration Status |
-|:---------------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
-| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
-| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
-| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
-| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |
+| AD Backend | Integration Status |
+| ------------- | ------------- |
+| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
+| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
+| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
+| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
+| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |
 
 # AdvancedVI.jl
 
@@ -69,7 +69,7 @@ end;
 
 Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
 We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose.
-The bijector corresponding to the joint support of our model can be constructed as follows:
+This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].
 
 ```julia
 using Bijectors: Bijectors
@@ -85,36 +85,6 @@ end;
 
 A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
 
-Since most VI algorithms assume that the posterior is unconstrained, we will apply a change-of-variable to our model to make it unconstrained.
-This amounts to wrapping it into a `LogDensityProblem` that applies the transformation and the corresponding Jacobian adjustment.
-
-```julia
-struct TransformedLogDensityProblem{Prob,Trans}
-    prob::Prob
-    transform::Trans
-end
-
-function TransformedLogDensityProblem(prob, transform)
-    return TransformedLogDensityProblem{typeof(prob),typeof(transform)}(prob, transform)
-end
-
-function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
-    (; prob, transform) = prob_trans
-    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(transform, θ_trans)
-    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
-end
-
-function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
-    return LogDensityProblems.dimension(prob_trans.prob)
-end
-
-function LogDensityProblems.capabilities(
-    ::Type{TransformedLogDensityProblem{Prob,Trans}}
-) where {Prob,Trans}
-    return LogDensityProblems.capabilities(Prob)
-end;
-```
-
 For the dataset, we will use the popular [sonar classification dataset](https://archive.ics.uci.edu/dataset/151/connectionist+bench+sonar+mines+vs+rocks) from the UCI repository.
 This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI/OpenML.jl).
 The sonar dataset corresponds to the dataset id 40.
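The removed `TransformedLogDensityProblem` wrapper implements the standard change-of-variables identity: evaluate the original log-density at the transformed point and add the log-absolute-determinant of the transform's Jacobian. A minimal Python sketch of that identity (illustrative only; the README's actual code is Julia and relies on `Bijectors.with_logabsdet_jacobian`):

```python
import math

def transformed_logdensity(logdensity, transform, log_abs_det_jacobian):
    """Log-density in unconstrained space: log p(T(u)) + log|det J_T(u)|."""
    def wrapped(u):
        theta = transform(u)                     # map back to the constrained space
        return logdensity(theta) + log_abs_det_jacobian(u)
    return wrapped

# Example: sigma > 0 made unconstrained via sigma = exp(u);
# for T(u) = exp(u), log|det J_T(u)| = u.
log_p = lambda sigma: -sigma                     # Exponential(1) log-density (sigma > 0)
log_p_unc = transformed_logdensity(log_p, math.exp, lambda u: u)

print(log_p_unc(0.0))  # sigma = exp(0) = 1: log p(1) + 0 = -1.0
```

This Jacobian adjustment is exactly what makes the unconstrained problem a valid density over all of Euclidean space.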
@@ -139,10 +109,7 @@ X = hcat(X, ones(size(X, 1)));
 The model can now be instantiated as follows:
 
 ```julia
-prob = LogReg(X, y);
-b = Bijectors.bijector(prob)
-binv = Bijectors.inverse(b)
-prob_trans = TransformedLogDensityProblem(prob, binv)
+model = LogReg(X, y);
 ```
 
 For the VI algorithm, we will use `KLMinRepGradDescent`:
@@ -169,16 +136,16 @@ For this, it is straightforward to use `LogDensityProblemsAD`:
 using DifferentiationInterface: DifferentiationInterface
 using LogDensityProblemsAD: LogDensityProblemsAD
 
-prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
+model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model);
 ```
 
 For the variational family, we will consider a `FullRankGaussian` approximation:
 
 ```julia
 using LinearAlgebra
 
-d = LogDensityProblems.dimension(prob_trans_ad)
-q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d)))
+d = LogDensityProblems.dimension(model_ad)
+q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
 q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
 ```
 
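The two families in the hunk above differ only in how the scale is parameterized: `FullRankGaussian` carries a lower-triangular factor `L` (covariance `L Lᵀ`), while `MeanFieldGaussian` keeps just a diagonal scale. A rough Python sketch of reparameterized sampling from each (hypothetical helper code, not AdvancedVI's API):

```python
import random

random.seed(0)
d = 3
mu = [0.0] * d
u = [random.gauss(0.0, 1.0) for _ in range(d)]   # standard normal noise

# Full-rank Gaussian: z = mu + L u with lower-triangular L; here L = 0.37 * I,
# matching the initialization in the diff. Cost: d(d+1)/2 scale parameters.
L = [[0.37 if i == j else 0.0 for j in range(d)] for i in range(d)]
z_fullrank = [mu[i] + sum(L[i][j] * u[j] for j in range(d)) for i in range(d)]

# Mean-field Gaussian: z = mu + diag(scale) u; only d scale parameters,
# so it cannot capture posterior correlations between coordinates.
scale = [1.0] * d
z_meanfield = [mu[i] + scale[i] * u[i] for i in range(d)]
```

The full-rank factor costs `O(d²)` parameters but captures correlations; mean-field is `O(d)` and faster per iteration.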
@@ -194,19 +161,12 @@ We can now run VI:
 
 ```julia
 max_iter = 10^3
-q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
-```
-
-Recall that we applied a change-of-variable to the posterior to make it unconstrained.
-This, however, is not the original constrained posterior that we wanted to approximate.
-Therefore, we finally need to apply a change-of-variable to `q_opt` to make it approximate our original problem.
-
-```julia
-q_trans = Bijectors.TransformedDistribution(q_opt, binv)
+q, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed;);
 ```
 
 For more examples and details, please refer to the documentation.
 
 [^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.
 [^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.
 [^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In *International Conference on Learning Representations*.
+[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*.
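The papers cited in the footnotes above largely concern the reparameterization (path) gradient that drives this kind of gradient-based VI: writing `z = μ + σ·ε` with `ε ~ N(0, 1)` turns the gradient of an expectation into an expectation of a gradient. A toy Python illustration of the estimator (not AdvancedVI code):

```python
import random

random.seed(1)

def grad_mu_estimate(mu, sigma, f_prime, n=100_000):
    """Monte Carlo estimate of d/dmu E_{z ~ N(mu, sigma^2)}[f(z)]
    via the reparameterization z = mu + sigma * eps, where dz/dmu = 1."""
    total = 0.0
    for _ in range(n):
        z = mu + sigma * random.gauss(0.0, 1.0)
        total += f_prime(z)   # chain rule: f'(z) * dz/dmu
    return total / n

# For f(z) = z^2, the exact gradient of E[z^2] w.r.t. mu is 2 * mu,
# i.e. 3.0 at mu = 1.5; the estimate lands close to that.
g = grad_mu_estimate(mu=1.5, sigma=0.5, f_prime=lambda z: 2.0 * z)
```

Because the estimator differentiates through the sample itself, it typically has much lower variance than score-function (REINFORCE) gradients.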
