update the example in the README #193
The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration.
For example, integrating `Turing` with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`.
## Basic Example
We will describe a simple example to demonstrate the basic usage of `AdvancedVI`.
`AdvancedVI` works with differentiable models specified through the [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) interface.
Let's look at a basic logistic regression example with a hierarchical prior.
For a dataset $(X, y)$ with design matrix $X \in \mathbb{R}^{n \times d}$ and response variables $y \in \{0, 1\}^n$, we assume the following data-generating process:
$$
\begin{aligned}
\sigma &\sim \text{Student-t}_{3}(0, 1) \\
\beta &\sim \mathrm{Normal}\left(0_d, \sigma \mathrm{I}_d\right) \\
y &\sim \mathrm{Bernoulli}\left(\mathrm{logistic}\left(X \beta\right)\right)
\end{aligned}
$$
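To make the data-generating process above concrete, here is a minimal Python sketch that simulates from it (illustrative only, not part of the Julia example; for simplicity, the truncated Student-t is sampled via an absolute value, and `sigma` is treated as the prior standard deviation):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 100, 3

# sigma ~ Student-t with 3 degrees of freedom, truncated to (0, inf);
# taking the absolute value is equivalent for a symmetric distribution.
sigma = abs(rng.standard_t(df=3))
# beta ~ Normal(0_d, sigma * I_d), treating sigma as the prior scale.
beta = rng.normal(0.0, sigma, size=d)
# y ~ Bernoulli(logistic(X @ beta))
X = rng.standard_normal((n, d))
p = 1.0 / (1.0 + np.exp(-(X @ beta)))
y = rng.binomial(1, p)
```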
The `LogDensityProblem` corresponding to this model can be constructed as
```julia
import LogDensityProblems
using Distributions
using FillArrays
using LinearAlgebra

struct LogReg{XType,YType}
    X::XType
    y::YType
end

function LogDensityProblems.logdensity(model::LogReg, θ)
    (; X, y) = model
    d = size(X, 2)
    β, σ = θ[1:d], θ[end]

    logprior_β = logpdf(MvNormal(Zeros(d), σ * I), β)
    logprior_σ = logpdf(truncated(TDist(3.0); lower=0), σ)

    logits = X * β
    loglike_y = sum(@. logpdf(BernoulliLogit(logits), y))
    return loglike_y + logprior_β + logprior_σ
end

function LogDensityProblems.dimension(model::LogReg)
    return size(model.X, 2) + 1
end

function LogDensityProblems.capabilities(::Type{<:LogReg})
    return LogDensityProblems.LogDensityOrder{0}()
end
```
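For reference, the Bernoulli-logit log-likelihood used above satisfies the identity $\log p(y \mid l) = y\,l - \log(1 + e^{l})$, which can be evaluated in a numerically stable way. Here is a small Python sketch of that identity (illustrative; `bernoulli_logit_logpdf` is a hypothetical helper, not a library function):

```python
import numpy as np

def bernoulli_logit_logpdf(logits, y):
    # log p(y | logit) = y * logit - log(1 + exp(logit)),
    # with log(1 + exp(.)) computed stably via logaddexp.
    return y * logits - np.logaddexp(0.0, logits)

# At logit = 0 the success probability is 1/2, so log p(1 | 0) = log(0.5).
ll = bernoulli_logit_logpdf(np.array([0.0]), np.array([1]))
```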
Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose.
This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].
```julia
import Bijectors

function Bijectors.bijector(model::LogReg)
    d = size(model.X, 2)
    return Bijectors.Stacked(
        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), truncated(TDist(3.0); lower=0)]),
        [1:d, (d + 1):(d + 1)],
    )
end
```
A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is generated automatically.
For the dataset, we will use the popular [sonar classification dataset](https://archive.ics.uci.edu/dataset/151/connectionist+bench+sonar+mines+vs+rocks) from the UCI repository.
It can be downloaded automatically using [`OpenML`](https://github.com/JuliaAI/OpenML.jl), where the sonar dataset corresponds to dataset id 40.
```julia
import OpenML
import DataFrames

data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine")
```
Let's apply some basic pre-processing and add an intercept column:
```julia
using Statistics

# Standardize each feature column, then append an intercept column.
X = (X .- mean(X; dims=1)) ./ std(X; dims=1)
X = hcat(X, ones(size(X, 1)))
```
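As a sanity check, here is a minimal Python sketch of the same pre-processing (assuming the goal is zero-mean, unit-variance feature columns plus a constant intercept column; the array shape mimics the sonar data):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(5.0, 2.0, size=(208, 60))         # shape mimicking the sonar data

Xs = (X - X.mean(axis=0)) / X.std(axis=0)        # standardize each feature column
Xs = np.hstack([Xs, np.ones((Xs.shape[0], 1))])  # append an intercept column
```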
The model can now be instantiated as follows:
```julia
model = LogReg(X, y)
```
For the VI algorithm, we will use the following:
```julia
using Optimisers
using ADTypes, ReverseDiff
using AdvancedVI

# ELBO maximization via stochastic gradient descent with the reparameterization gradient
alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())
```
This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation using the reparameterization gradient[^TL2014][^RMW2014][^KW2014].
It is also commonly referred to as automatic differentiation VI, black-box VI, or stochastic gradient VI.
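The idea behind the reparameterization gradient can be illustrated in isolation. Here is a minimal Python sketch (not the package's implementation): for $\theta = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, the gradient of $\mathbb{E}[f(\theta)]$ with respect to $\mu$ is $\mathbb{E}[f'(\theta)]$, which a Monte Carlo average estimates directly:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5

# Reparameterize theta = mu + sigma * eps with eps ~ N(0, 1). For
# f(theta) = theta^2, f'(theta) = 2 * theta, and the exact gradient
# d/dmu E[theta^2] equals 2 * mu.
eps = rng.standard_normal(200_000)
theta = mu + sigma * eps
grad_mu = (2.0 * theta).mean()
```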
`KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` provides gradients.
For this, it is straightforward to use `LogDensityProblemsAD`:
```julia
import DifferentiationInterface
import LogDensityProblemsAD

model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)
```
For the variational family, we will use a `MeanFieldGaussian` approximation:
```julia
using LinearAlgebra

d = LogDensityProblems.dimension(model_ad)
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
```
The bijector can now be applied to `q` to match the support of the target problem.
```julia
b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)
```
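The role of the bijector here can be pictured without any library machinery. For a positive-constrained parameter such as `σ`, a typical choice is the log transform; a minimal Python sketch of the underlying idea (not the `Bijectors.jl` API):

```python
import numpy as np

# b maps the constrained support (0, inf) to R; its inverse maps back.
b = np.log
b_inv = np.exp

z = np.array([-2.0, 0.0, 1.5])   # unconstrained values, e.g. draws from q
sigma = b_inv(z)                 # back on the positive support of sigma
```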
We can now run VI:
```julia
max_iter = 10^3
q_avg, info, _ = AdvancedVI.optimize(
    alg,
    max_iter,
    model_ad,
    q_transformed,
)
```