diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index d18d66c6..297438d7 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -24,6 +24,8 @@ jobs: steps: - name: Build and deploy Documenter.jl docs uses: TuringLang/actions/DocsDocumenter@main + with: + julia-version: 'lts' - name: Run doctests shell: julia --project=docs --color=yes {0} diff --git a/README.md b/README.md index e5cf99ce..c53de30f 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ For a dataset $(X, y)$ with the design matrix $X \in \mathbb{R}^{n \times d}$ an $$ \begin{aligned} -\sigma &\sim \text{Student-t}_{3}(0, 1) \\ -\beta &\sim \text{Normal}\left(0_d, \sigma \mathrm{I}_d\right) \\ +\sigma &\sim \text{LogNormal}(0, 3) \\ +\beta &\sim \text{Normal}\left(0_d, \sigma^2 \mathrm{I}_d\right) \\ y &\sim \mathrm{BernoulliLogit}\left(X \beta\right) \end{aligned} $$ @@ -31,7 +31,7 @@ $$ The `LogDensityProblem` corresponding to this model can be constructed as ```julia -import LogDensityProblems +using LogDensityProblems: LogDensityProblems using Distributions using FillArrays @@ -45,11 +45,11 @@ function LogDensityProblems.logdensity(model::LogReg, θ) d = size(X, 2) β, σ = θ[1:size(X, 2)], θ[end] - logprior_β = logpdf(MvNormal(Zeros(d), σ*I), β) - logprior_σ = logpdf(truncated(TDist(3.0); lower=0), σ) + logprior_β = logpdf(MvNormal(Zeros(d), σ), β) + logprior_σ = logpdf(LogNormal(0, 3), σ) logit = X*β - loglike_y = sum(@. logpdf(BernoulliLogit(logit), y)) + loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y) return loglike_y + logprior_β + logprior_σ end @@ -59,23 +59,23 @@ end function LogDensityProblems.capabilities(::Type{<:LogReg}) return LogDensityProblems.LogDensityOrder{0}() -end +end; ``` -Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. +Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose. This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017]. ```julia -import Bijectors +using Bijectors: Bijectors function Bijectors.bijector(model::LogReg) d = size(model.X, 2) return Bijectors.Stacked( - Bijectors.bijector.([MvNormal(Zeros(d), 1.0), truncated(TDist(3.0); lower=0)]), + Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]), [1:d, (d + 1):(d + 1)], ) -end +end; ``` A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. @@ -85,63 +85,74 @@ This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI The sonar dataset corresponds to the dataset id 40. 
```julia -import OpenML -import DataFrames +using OpenML: OpenML +using DataFrames: DataFrames data = Array(DataFrames.DataFrame(OpenML.load(40))) X = Matrix{Float64}(data[:, 1:(end - 1)]) -y = Vector{Bool}(data[:, end] .== "Mine") +y = Vector{Bool}(data[:, end] .== "Mine"); ``` + Let's apply some basic pre-processing and add an intercept column: + ```julia +using Statistics + X = (X .- mean(X; dims=2)) ./ std(X; dims=2) -X = hcat(X, ones(size(X, 1))) +X = hcat(X, ones(size(X, 1))); ``` + The model can now be instantiated as follows: + ```julia -model = LogReg(X, y) +model = LogReg(X, y); ``` -For the VI algorithm, we will use the following: +For the VI algorithm, we will use `KLMinRepGradDescent`: + ```julia using ADTypes, ReverseDiff using AdvancedVI -alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff()) +alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff()); ``` + This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation with the reparametrization gradient[^TL2014][^RMW2014][^KW2014]. This is also commonly referred as automatic differentiation VI, black-box VI, stochastic gradient VI, and so on. -This `KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` has gradients. -For this, it is straightforward to use `LogDensityProblemsAD`: +`KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` is differentiable. +If the `LogDensityProblem` has a differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities) of at least first-order, we can take advantage of this. +For this example, we will use `LogDensityProblemsAD` to equip our problem with a first-order capability: + ```julia -import DifferentiationInterface -import LogDensityProblemsAD +using DifferentiationInterface: DifferentiationInterface +using LogDensityProblemsAD: LogDensityProblemsAD -model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model) +model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model); ``` For the variational family, we will consider a `FullRankGaussian` approximation: + ```julia using LinearAlgebra d = LogDensityProblems.dimension(model_ad) -q = MeanFieldGaussian(zeros(d), Diagonal(ones(d))) +q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d))) +q = MeanFieldGaussian(zeros(d), Diagonal(ones(d))); ``` + The bijector can now be applied to `q` to match the support of the target problem. + ```julia b = Bijectors.bijector(model) binv = Bijectors.inverse(b) -q_transformed = Bijectors.TransformedDistribution(q, binv) +q_transformed = Bijectors.TransformedDistribution(q, binv); ``` + We can now run VI: + ```julia max_iter = 10^3 -q_avg, info, _ = AdvancedVI.optimize( - alg, - max_iter, - model_ad, - q_transformed; -) +q, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed;); ``` For more examples and details, please refer to the documentation. 
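+
+As a quick sanity check, the ELBO of the returned approximation `q` can be re-estimated by calling the objective directly (a sketch; the number of Monte Carlo samples below is arbitrary):
+
+```julia
+# Re-estimate the ELBO of the fitted approximation with fresh Monte Carlo samples.
+estimate_objective(RepGradELBO(10^3), q, model_ad)
+```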
diff --git a/docs/Project.toml b/docs/Project.toml index 62995f5c..d8185fc8 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,30 +1,48 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" +OpenML = "8b6db2d4-7670-4922-a472-f9537c81ab66" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StanLogDensityProblems = "a545de4d-8dba-46db-9d34-4e41d3f07807" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ADTypes = "1" +Accessors = "0.1" AdvancedVI = "0.5, 0.4" Bijectors = "0.13.6, 0.14, 0.15" +DataFrames = "1" +DifferentiationInterface = "0.7" Distributions = "0.25" Documenter = "1" FillArrays = "1" ForwardDiff = "0.10, 1" +Functors = "0.5" +JSON = "0.21" LogDensityProblems = "2.1.1" +LogDensityProblemsAD = "1" +NormalizingFlows = "0.2.2" +OpenML = "0.3" Optimisers = "0.3, 0.4" Plots = "1" QuasiMonteCarlo = "0.3" ReverseDiff = "1" +StanLogDensityProblems = "0.1" StatsFuns = "1" julia = "1.10, 1.11.2" diff --git a/docs/make.jl b/docs/make.jl index 315a82c8..a3ae15bc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,7 +17,10 @@ makedocs(; "AdvancedVI" => "index.md", "General Usage" => "general.md", "Tutorials" => [ - "tutorials/basic.md", + "Basic Example" => "tutorials/basic.md", + "Scaling to Large Datasets" => "tutorials/subsampling.md", + "Stan Models" => "tutorials/stan.md", + "Normalizing Flows" => "tutorials/flows.md", ], "Algorithms" => [ "KLMinRepGradDescent" => "paramspacesgd/klminrepgraddescent.md", diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index ea9bee84..04db2ff5 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -82,7 +82,7 @@ q_transformed = Bijectors.TransformedDistribution(q, binv) ``` By passing `q_transformed` to `optimize`, the Jacobian adjustment for the bijector `b` is automatically applied. -(See [Examples](@ref examples) for a fully working example.) +(See the [Basic Example](@ref basic) for a fully working example.) [^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*. [^DLTBV2017]: Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv. @@ -177,7 +177,6 @@ function Bijectors.bijector(model::NormalLogNormal) end ``` -Let us come back to the example in [Examples](@ref examples), where a `LogDensityProblem` is given as `model`. In this example, the true posterior is contained within the variational family. 
This setting is known as "perfect variational family specification." In this case, the `RepGradELBO` estimator with `StickingTheLandingEntropy` is the only estimator known to converge exponentially fast ("linear convergence") to the true solution. diff --git a/docs/src/tutorials/basic.md b/docs/src/tutorials/basic.md index 634de8da..1c0dfa13 100644 --- a/docs/src/tutorials/basic.md +++ b/docs/src/tutorials/basic.md @@ -1,153 +1,290 @@ -## [Evidence Lower Bound Maximization](@id examples) +# [Basic Example](@id basic) -In this tutorial, we will work with a `normal-log-normal` model. +In this tutorial, we will demonstrate the basic usage of `AdvancedVI` with [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) interface. + +## Problem Setup + +Let's consider a basic logistic regression example with a hierarchical prior. +For a dataset $(X, y)$ with the design matrix $X \in \mathbb{R}^{n \times d}$ and the response variables $y \in \{0, 1\}^n$, we assume the following data generating process: ```math \begin{aligned} -x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) +\sigma &\sim \text{LogNormal}(0, 3) \\ +\beta &\sim \text{Normal}\left(0_d, \sigma^2 \mathrm{I}_d\right) \\ +y &\sim \mathrm{BernoulliLogit}\left(X \beta\right) \end{aligned} ``` -BBVI with `Bijectors.Exp` bijectors is able to infer this model exactly. +The `LogDensityProblem` corresponding to this model can be constructed as -Using the `LogDensityProblems` interface, the model can be defined as follows: +```@example basic +using LogDensityProblems: LogDensityProblems +using Distributions +using FillArrays -```@example elboexample -using LogDensityProblems -using ForwardDiff - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY +struct LogReg{XType,YType} + X::XType + y::YType end -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model - return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +function LogDensityProblems.logdensity(model::LogReg, θ) + (; X, y) = model + d = size(X, 2) + β, σ = θ[1:d], θ[end] + + logprior_β = logpdf(MvNormal(Zeros(d), σ), β) + logprior_σ = logpdf(LogNormal(0, 3), σ) + + logit = X*β + loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y) + return loglike_y + logprior_β + logprior_σ end -function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) - return ( - LogDensityProblems.logdensity(model, θ), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), - ) +function LogDensityProblems.dimension(model::LogReg) + return size(model.X, 2) + 1 end -function LogDensityProblems.dimension(model::NormalLogNormal) - return length(model.μ_y) + 1 +function LogDensityProblems.capabilities(::Type{<:LogReg}) + return LogDensityProblems.LogDensityOrder{0}() end +nothing +``` -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{1}() +Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. +We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose. +This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017]. 
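+
+Concretely, ADVI performs inference over an unconstrained reparameterization of `θ`; written out for this model (the notation below is ours, for illustration only):
+
+```math
+b(\beta, \sigma) = (\beta, \log \sigma) \in \mathbb{R}^{d+1},
+\qquad
+b^{-1}(\beta, \eta) = (\beta, \exp \eta),
+```
+
+and the Jacobian correction for $b^{-1}$ is applied automatically once the transformed variational approximation is constructed further below.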
+ +In our case, we need a bijector that applies an identity map for the first `size(X,2)` coordinates, and map the last coordinate to the support of `LogNormal(0, 3)`. +This can be done as follows: + +[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of machine learning research*. +```@example basic +using Bijectors: Bijectors + +function Bijectors.bijector(model::LogReg) + d = size(model.X, 2) + return Bijectors.Stacked( + Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]), + [1:d, (d + 1):(d + 1)], + ) end +nothing ``` -Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities). -The required order of differentiation capability will vary depending on the VI algorithm. -In this example, we will use `KLMinRepGradDescent`, which requires first-order capability. +For more details, please refer to the documentation of [`Bijectors`](https://github.com/TuringLang/Bijectors.jl). -Let's now instantiate the model +For the dataset, we will use the popular [sonar classification dataset](https://archive.ics.uci.edu/dataset/151/connectionist+bench+sonar+mines+vs+rocks) from the UCI repository. +This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI/OpenML.jl). +The sonar dataset corresponds to the dataset id 40. -```@example elboexample -using LinearAlgebra +```@example basic +using OpenML: OpenML +using DataFrames: DataFrames +data = Array(DataFrames.DataFrame(OpenML.load(40))) +X = Matrix{Float64}(data[:, 1:(end - 1)]) +y = Vector{Bool}(data[:, end] .== "Mine") +nothing +``` + +Let's apply some basic pre-processing and add an intercept column: + +```@example basic +using Statistics + +X = (X .- mean(X; dims=2)) ./ std(X; dims=2) +X = hcat(X, ones(size(X, 1))) +nothing +``` + +The model can now be instantiated as follows: -n_dims = 10 -μ_x = randn() -σ_x = exp.(randn()) -μ_y = randn(n_dims) -σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)); +```@example basic +model = LogReg(X, y) nothing ``` -Let's now load `AdvancedVI`. -In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation. -Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`. -(This does not need to be the same as the AD backend used for the first-order capability of `model`.) -The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. -Here, we will use `ReverseDiff`, which can be selected by later passing `ADTypes.AutoReverseDiff()`. +## Basic Usage -```@example elboexample +For the VI algorithm, we will use `KLMinRepGradDescent`: + +```@example basic using ADTypes, ReverseDiff using AdvancedVI -alg = KLMinRepGradDescent(AutoReverseDiff()); +alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff()) nothing ``` -Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support. -Since `x` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. -Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. 
+This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation with the reparametrization gradient[^TL2014][^RMW2014][^KW2014]. +This is also commonly referred as automatic differentiation VI, black-box VI, stochastic gradient VI, and so on. +`KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` is differentiable. +If the `LogDensityProblem` has a differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities) of at least first-order, we can take advantage of this. -```@example elboexample -using Bijectors +For this example, we will use `LogDensityProblemsAD` to equip our problem with a first-order capability: -function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model - return Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:(1 + length(μ_y))], - ) -end +[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR. +[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR. +[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*. +```@example basic +using DifferentiationInterface: DifferentiationInterface +using LogDensityProblemsAD: LogDensityProblemsAD + +model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model) +nothing +``` -b = Bijectors.bijector(model); -binv = inverse(b) +For the variational family, we will consider a `FullRankGaussian` approximation: + +```@example basic +using LinearAlgebra + +d = LogDensityProblems.dimension(model_ad) +q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d))) nothing ``` -For the variational family, we will use the classic mean-field Gaussian family. +The bijector can now be applied to `q` to match the support of the target problem. -```@example elboexample -d = LogDensityProblems.dimension(model); -μ = randn(d); -L = Diagonal(ones(d)); -q0 = AdvancedVI.MeanFieldGaussian(μ, L) +```@example basic +b = Bijectors.bijector(model) +binv = Bijectors.inverse(b) +q_transformed = Bijectors.TransformedDistribution(q, binv) nothing ``` -And then, we now apply the bijector to the variational family. +We can now run VI: -```@example elboexample -q0_trans = Bijectors.TransformedDistribution(q0, binv) +```@example basic +max_iter = 10^4 +q_out, info, _ = AdvancedVI.optimize( + alg, max_iter, model_ad, q_transformed; show_progress=false +) nothing ``` -Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. +Let's verify that the optimization procedure converged. +For this, we will visually inspect that the maximization objective of `KLMinRepGradDescent`, the "evidence lower bound" (ELBO) increased. 
+Since `KLMinRepGradDescent` stores the ELBO estimate at each iteration in `info`, we can visualize this as follows:
+
+```@example basic
+using Plots
-```@example elboexample
-n_max_iter = 10^4
-q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model, q0_trans; show_progress=false);
+plot(
+    [i.iteration for i in info],
+    [i.elbo for i in info];
+    xlabel="Iteration",
+    ylabel="ELBO",
+    label=nothing,
+)
+savefig("basic_example_elbo.svg")
 nothing
 ```
 
-`ClipScale` is a projection operator, which ensures that the variational approximation stays within a stable region of the variational family.
-For more information see [this section](@ref clipscale).
+![](basic_example_elbo.svg)
 
-`q_out` is the final output of the optimization procedure.
-If a parameter averaging strategy is used through the keyword argument `averager`, `q_out` will be the output of the averaging strategy.
+## Custom Callback
 
-The selected inference procedure stores per-iteration statistics into `stats`.
-For instance, the ELBO can be plotted as follows:
+The ELBO estimates above, however, use only a handful of Monte Carlo samples.
+Furthermore, the ELBO is evaluated on the iterates of the optimization procedure, which may not coincide with the actual output of the algorithm.
+(For instance, if parameter averaging is used.)
+Therefore, we may occasionally want to compute higher-resolution ELBO estimates.
+Also, depending on the problem, we may want to track problem-specific diagnostics to monitor progress.
 
-```@example elboexample
-using Plots
+For both use cases above, defining a custom `callback` function can be useful.
+In this example, we will compute a more accurate estimate of the ELBO and the classification accuracy every `logging_interval = 100` iterations.
+
+```@example basic
+using StatsFuns: StatsFuns
+
+"""
+    logistic_prediction(X, μ_β, Σ_β)
+
+Approximate the posterior predictive probability for a logistic link function using MacKay's approximation (Bishop, p. 220).
+"""
+function logistic_prediction(X, μ_β, Σ_β)
+    xtΣx = sum((model.X*Σ_β) .* model.X; dims=2)[:, 1]
+    κ = @. 1/sqrt(1 + π/8*xtΣx)
+    return StatsFuns.logistic.(κ .* X*μ_β)
+end
+
+logging_interval = 100
+function callback(; iteration, averaged_params, restructure, kwargs...)
+    if mod(iteration, logging_interval) == 1
+
+        # Use the averaged parameters (the eventual output of the algorithm)
+        q_avg = restructure(averaged_params)
+
+        # Compute predictions
+        μ_β = mean(q_avg.dist)[1:(end - 1)]             # posterior mean of β
+        Σ_β = cov(q_avg.dist)[1:(end - 1), 1:(end - 1)] # marginal posterior covariance of β
+        y_pred = logistic_prediction(X, μ_β, Σ_β) .> 0.5
+
+        # Prediction accuracy
+        acc = mean(y_pred .== model.y)
+
+        # Higher fidelity estimate of the ELBO on the averaged parameters
+        n_samples = 256
+        obj = AdvancedVI.RepGradELBO(n_samples)
+        elbo_callback = estimate_objective(obj, q_avg, model)
+
+        (elbo_callback=elbo_callback, accuracy=acc)
+    else
+        nothing
+    end
+end
+nothing
+```
+
+Note that the interface for the callback function will depend on the VI algorithm being used.
+Therefore, please refer to the documentation of each VI algorithm.
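+
+For instance, a hypothetical minimal callback that merely attaches a wall-clock time stamp to `info` at every iteration could look as follows (a sketch; the name `callback_time` is ours):
+
+```julia
+# Accepts and ignores all keyword arguments; returns a NamedTuple that is appended to `info`.
+callback_time(; kwargs...) = (time_stamp=time(),)
+```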
-t = [i.iteration for i in info] -y = [i.elbo for i in info] -plot(t, y; label="BBVI", xlabel="Iteration", ylabel="ELBO") -savefig("bbvi_example_elbo.svg") +The `callback` can be supplied to `optimize`: + +```@example basic +max_iter = 10^4 +q_out, info, _ = AdvancedVI.optimize( + alg, max_iter, model_ad, q_transformed; show_progress=false, callback=callback +) nothing ``` -![](bbvi_example_elbo.svg) +First, let's compare the default estimate of the ELBO, which uses a small number of samples and is evaluated in the current iterate, versus the ELBO computed in the callback, which uses a large number of samples and is evaluated on the averaged iterate. + +```@example basic +t = 1:max_iter +elbo = [i.elbo for i in info[t]] -Further information can be gathered by defining your own `callback!`. +t_callback = 1:logging_interval:max_iter +elbo_callback = [i.elbo_callback for i in info[t_callback]] -The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: +plot(t, elbo; xlabel="Iteration", ylabel="ELBO", label="Default") +plot!(t_callback, elbo_callback; label="Callback", ylims=(-300, Inf), linewidth=2) -```@example elboexample -estimate_objective(RepGradELBO(10^4), q_out, model) +savefig("basic_example_elbo_callback.svg") +nothing ``` + +![](basic_example_elbo_callback.svg) + +We can see that the default ELBO estimates are noisy compared to the higher fidelity estimates from the callback. +After a few thousands of iterations, it is difficult to judge if we are still making progress or not. +In contrast, the estimates from callback show that the objective is increasing smoothly. + +Similarly, we can monitor the evolution of the prediction accuracy. + +```@example basic +acc_callback = [i.accuracy for i in info[t_callback]] +plot( + t_callback, + acc_callback; + xlabel="Iteration", + ylabel="Prediction Accuracy", + label=nothing, +) +savefig("basic_example_acc.svg") +nothing +``` + +![](basic_example_acc.svg) + +Clearly, the accuracy is improving over time. diff --git a/docs/src/tutorials/flows.md b/docs/src/tutorials/flows.md new file mode 100644 index 00000000..c1a646d0 --- /dev/null +++ b/docs/src/tutorials/flows.md @@ -0,0 +1,257 @@ +# Normalizing Flows + +In this example, we will see how to use [`NormalizingFlows`](https://github.com/TuringLang/NormalizingFlows.jl) with `AdvancedVI`. + +## Problem Setup + +For the problem, we will look into a toy problem where `NormalizingFlows` can be benficial. +For a dataset of real valued data $y_1, \ldots, y_n$, consider the following generative model: + +```math +\begin{aligned} +\alpha &\sim \text{LogNormal}(0, 1) \\ +\beta &\sim \text{Normal}\left(0, 10\right) \\ +y_i &\sim \text{Normal}\left(\alpha \beta, 1\right) +\end{aligned} +``` + +Notice that the mean is predicted as the product $\alpha \beta$ of two unknown parameters. +This results in [multiplicative unidentifiability](https://betanalpha.github.io/assets/case_studies/identifiability.html#53_Multiplicative_Degeneracy) of $\alpha$ and $\beta$. +As such, the posterior exhibits a "banana"-shaped degeneracy. +Multiplicative degeneracy is not entirely made up and do come up in some models used in practice. +For example, in the 3-parameter (3-PL) item-response theory model and the N-mixture model used for estimating animal population. 
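+
+To see where the degeneracy comes from (a short derivation from the model above), note that the likelihood depends on $(\alpha, \beta)$ only through their product:
+
+```math
+p(y_{1:n} \mid \alpha, \beta) = \prod_{i=1}^{n} \mathcal{N}\left(y_i \mid \alpha \beta, 1\right),
+```
+
+so every pair $(\alpha, \beta)$ with the same product attains the same likelihood, and only the priors (weakly) break the tie along the curve $\alpha \beta = \text{const}$.
+The corresponding `LogDensityProblem` can be defined as follows: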
+ +```@example flow +using Bijectors: Bijectors +using Distributions +using LogDensityProblems: LogDensityProblems + +struct MultDegen{Y} + y::Y +end + +function LogDensityProblems.logdensity(model::MultDegen, θ) + α, β = θ[1], θ[2] + + logprior_α = logpdf(LogNormal(0, 1), α) + logprior_β = logpdf(Normal(0, 10), β) + + loglike_y = mapreduce(+, model.y) do yi + logpdf(Normal(α * β, 1.0), yi) + end + return logprior_α + logprior_β + loglike_y +end + +function LogDensityProblems.dimension(model::MultDegen) + return 2 +end + +function LogDensityProblems.capabilities(::Type{<:MultDegen}) + return LogDensityProblems.LogDensityOrder{0}() +end +nothing +``` + +Degenerate posteriors often indicate that there is not enough data to pin-point the right set of parameters. +Therefore, for the purpose of illustration, we will use a single data point: + +```@example flow +model = MultDegen([3.0]) +nothing +``` + +The banana-shaped degeneracy of the posterior can be readily visualized: + +```@example flow +using Plots + +contour( + range(0, 4; length=64), + range(-3, 25; length=64), + (x, y) -> LogDensityProblems.logdensity(model, [x, y]); + xlabel="α", + ylabel="β", + clims=(-8, Inf), +) + +savefig("flow_example_posterior.svg") +nothing +``` + +![](flow_example_posterior.svg) + +Notice that the two ends of the "banana" run deep both horizontally and vertically. +This sort of nonlinear correlation structure is difficult to model using only location-scale distributions. + +## Gaussian Variational Family + +As usual, let's try to fit a multivariate Gaussian to this posterior. + +```@example flow +using ADTypes: ADTypes +using ReverseDiff: ReverseDiff +using DifferentiationInterface: DifferentiationInterface +using LogDensityProblemsAD: LogDensityProblemsAD + +model_ad = LogDensityProblemsAD.ADgradient( + ADTypes.AutoReverseDiff(; compile=true), model; x=[1.0, 1.0] +) +nothing +``` + +Since $\alpha$ is constrained to the positive real half-space, we have to employ bijectors. +For this, we use [Bijectors](https://github.com/TuringLang/Bijectors.jl): + +```@example flow +using Bijectors: Bijectors + +function Bijectors.bijector(model::MultDegen) + return Bijectors.Stacked( + Bijectors.bijector.([LogNormal(0, 1), Normal(0, 10)]), [1:1, 2:2] + ) +end +nothing +``` + +For the algorithm, we will use the `KLMinRepGradProxDescent` objective. + +```@example flow +using AdvancedVI +using LinearAlgebra + +d = LogDensityProblems.dimension(model_ad) +q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d))) + +binv = Bijectors.inverse(Bijectors.bijector(model)) +q_trans = Bijectors.TransformedDistribution(q, binv) + +max_iter = 3*10^3 +alg = KLMinRepGradProxDescent(ADTypes.AutoReverseDiff(; compile=true)) +q_out, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_trans; show_progress=false) +nothing +``` + +The resulting variational posterior can be visualized as follows: + +```@example flow +samples = rand(q_out, 10000) +histogram2d( + samples[1, :], + samples[2, :]; + normalize=:pdf, + nbins=32, + xlabel="α", + ylabel="β", + xlims=(0, 4), + ylims=(-3, 25), +) +savefig("flow_example_locationscale.svg") +nothing +``` + +![](flow_example_locationscale.svg) + +We can see that the mode is closely matched, but the tails don't go as deep as the true posterior. +For this, we will need a more "expressive" variational family that is capable of representing nonlinear correlations. + +## Normalizing Flow Variational Family + +Now, let's try to optimize over a variational family formed by normalizing flows. 
+Normalizing flows, or *flows* for short, are a class of parametric models that leverage neural networks for density estimation.
+(For a detailed tutorial on flows, refer to the review by Papamakarios *et al.*[^PNRML2021].)
+Within the Julia ecosystem, the package [`NormalizingFlows`](https://github.com/TuringLang/NormalizingFlows.jl) provides a collection of popular flow models.
+In this example, we will use the popular `RealNVP`[^DSB2017].
+We will use a standard Gaussian base distribution and three coupling layers, each with hidden dimensions of `[16, 16]`.
+
+[^PNRML2021]: Papamakarios, G., Nalisnick, E., Rezende, D. J., Mohamed, S., & Lakshminarayanan, B. (2021). Normalizing flows for probabilistic modeling and inference. *Journal of Machine Learning Research*, 22(57), 1-64.
+[^DSB2017]: Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2016). Density estimation using Real NVP. In *Proceedings of the International Conference on Learning Representations*.
+
+```@example flow
+using NormalizingFlows
+using Functors
+
+@leaf MvNormal
+
+n_layers = 3
+hidden_dims = [16, 16]
+q_flow = realnvp(MvNormal(zeros(d), I), hidden_dims, n_layers; paramtype=Float64)
+nothing
+```
+
+Recall that our posterior is constrained.
+In most cases, however, flows assume an unconstrained support.
+Therefore, just as with the Gaussian variational family, we can incorporate `Bijectors` to match the supports:
+
+```@example flow
+q_flow_trans = Bijectors.TransformedDistribution(q_flow, binv)
+nothing
+```
+
+For the variational inference algorithm, we will similarly minimize the KL divergence with stochastic gradient descent, as originally proposed by Rezende and Mohamed[^RM2015].
+For this, however, we need to be mindful of the requirements of the variational algorithm.
+The default configuration of `KLMinRepGradDescent` essentially assumes that an `MvLocationScale` family is being used:
+
+  - `entropy=ClosedFormEntropy()`: The default `entropy` gradient estimator is `ClosedFormEntropy()`, which assumes that the entropy of the variational family, `entropy(q)`, is available in closed form. For flows, the entropy is (usually) not available.
+  - `operator=ClipScale()`: The `operator` applied after each gradient descent step is `ClipScale` by default. This operator only works on `MvLocationScale` and `MvLocationScaleLowRank`.
+
+Therefore, we have to customize these two keyword arguments to make the algorithm work with flows.
+
+In particular, for the `operator`, we will use `IdentityOperator()`, which is a no-op.
+For `entropy`, we can use any gradient estimator that only relies on the log-density of the variational family `logpdf(q)`, such as `StickingTheLandingEntropy()` or `MonteCarloEntropy()`.
+Here, we will use `StickingTheLandingEntropy()`[^RWD2017].
+When the variational family is "expressive," this gradient estimator has a variance reduction effect, resulting in faster convergence[^ASD2020].
+Furthermore, Agrawal *et al.*[^AD2025] claim that using a larger number of Monte Carlo samples `n_samples` is beneficial.
+
+[^RM2015]: Rezende, D., & Mohamed, S. (2015, June). Variational inference with normalizing flows. In *Proceedings of the International Conference on Machine Learning*. PMLR.
+[^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. In *Advances in Neural Information Processing Systems*, 30.
+[^ASD2020]: Agrawal, A., Sheldon, D. R., & Domke, J. (2020). Advances in black-box VI: Normalizing flows, importance weighting, and optimization.
In *Advances in Neural Information Processing Systems*, 33, 17358-17369. +[^AD2025]: Agrawal, A., & Domke, J. (2024). Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI. In *Proceedings of the International Conference on Artificial Intelligence and Statistics*. +```@example flow +alg_flow = KLMinRepGradDescent( + ADTypes.AutoReverseDiff(; compile=true); + n_samples=8, + operator=IdentityOperator(), + entropy=StickingTheLandingEntropy(), +) +nothing +``` + +Without further due, let's now run VI: + +```@example flow +q_flow_out, info_flow, _ = AdvancedVI.optimize( + alg_flow, max_iter, model_ad, q_flow_trans; show_progress=false +) +nothing +``` + +We can do a quick visual diagnostic of whether the optimization went smoothly: + +```@example flow +plot([i.elbo for i in info_flow]; xlabel="Iteration", ylabel="ELBO", ylims=(-10, Inf)) +savefig("flow_example_flow_elbo.svg") +nothing +``` + +![](flow_example_flow_elbo.svg) + +Finally, let's visualize the variational posterior: + +```@example flow +samples = rand(q_flow_out, 10000) +histogram2d( + samples[1, :], + samples[2, :]; + normalize=:pdf, + nbins=64, + xlabel="α", + ylabel="β", + xlims=(0, 4), + ylims=(-3, 25), +) +savefig("flow_example_flow.svg") +nothing +``` + +![](flow_example_flow.svg) + +Compared to the Gaussian approximation, we can see that the tails go much deeper into vertical direction. +This shows that, for this example with extreme nonlinear correlations, normalizing flows enable more accurate approximation. diff --git a/docs/src/tutorials/stan.md b/docs/src/tutorials/stan.md new file mode 100644 index 00000000..5c32eac8 --- /dev/null +++ b/docs/src/tutorials/stan.md @@ -0,0 +1,105 @@ +# Stan Models + +Since `AdvancedVI` supports the [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) interface, it can also be used with Stan models through [`StanLogDensityProblems`](https://github.com/sethaxen/StanLogDensityProblems.jl) interface. +Specifically, `StanLogDensityProblems` wraps any Stan model into a `LogDensityProblem` using [`BridgeStan`](https://github.com/roualdes/bridgestan). + +## Problem Setup + +Recall the hierarchical logistic regression example in the [Basic Example](@ref basic). +Here, we will define the same model in Stan. + +```@example stan +model_src = """ +data { + int N; + int D; + matrix[N,D] X; + array[N] int y; +} +parameters { + vector[D] beta; + real sigma; +} +model { + sigma ~ lognormal(0, 1); + beta ~ normal(0, sigma); + y ~ bernoulli_logit(X * beta); +} +""" +nothing +``` + +We also need to prepare the data. + +```@example stan +using DataFrames: DataFrames +using OpenML: OpenML +using Statistics + +data = Array(DataFrames.DataFrame(OpenML.load(40))) + +X = Matrix{Float64}(data[:, 1:(end - 1)]) +X = (X .- mean(X; dims=2)) ./ std(X; dims=2) +X = hcat(X, ones(size(X, 1))) +y = Vector{Int}(data[:, end] .== "Mine") + +stan_data = (X=transpose(X), y=y, N=size(X, 1), D=size(X, 2)) +nothing +``` + +Since `StanLogDensityProblems` expects files for both the model and the data, we need to store both on the file system. + +```@example stan +using JSON: JSON + +open("logistic_model.stan", "w") do io + println(io, model_src) +end +open("logistic_data.json", "w") do io + println(io, JSON.json(stan_data)) +end +nothing +``` + +## Inference via AdvancedVI + +We can now call `StanLogDensityProblems` to recieve a `LogDensityProblem`. 
+
+```@example stan
+using StanLogDensityProblems: StanLogDensityProblems
+
+model = StanLogDensityProblems.StanProblem("logistic_model.stan", "logistic_data.json")
+nothing
+```
+
+The rest is the same as for any other `LogDensityProblem`, with the exception of how constrained variables are handled: since `StanLogDensityProblems` automatically transforms the support of the target problem to be unconstrained, we do not need to involve `Bijectors`.
+
+```@example stan
+using ADTypes, ReverseDiff
+using AdvancedVI
+using LinearAlgebra
+using LogDensityProblems
+using Plots
+
+alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())
+
+d = LogDensityProblems.dimension(model)
+q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
+
+max_iter = 10^4
+q_out, info, _ = AdvancedVI.optimize(alg, max_iter, model, q; show_progress=false)
+
+plot(
+    [i.iteration for i in info],
+    [i.elbo for i in info];
+    xlabel="Iteration",
+    ylabel="ELBO",
+    label=nothing,
+)
+savefig("stan_example_elbo.svg")
+nothing
+```
+
+![](stan_example_elbo.svg)
+
+From the variational posterior `q_out`, we can draw samples from the unconstrained support of the model.
+To convert the samples back to the original (constrained) support of the model, it suffices to call [BridgeStan.param_constrain](https://roualdes.us/bridgestan/latest/languages/julia.html#BridgeStan.param_constrain).
diff --git a/docs/src/tutorials/subsampling.md b/docs/src/tutorials/subsampling.md
new file mode 100644
index 00000000..9baa7906
--- /dev/null
+++ b/docs/src/tutorials/subsampling.md
@@ -0,0 +1,310 @@
+# Scaling to Large Datasets with Subsampling
+
+In this tutorial, we will show how to use `AdvancedVI` on problems with large datasets.
+Variational inference (VI) has a long and successful history[^HBWP2013][^TL2014][^HBB2010] in large-scale inference using (minibatch) subsampling.
+In this tutorial, we will see how to perform subsampling with `KLMinRepGradProxDescent`, which was originally described by Titsias and Lázaro-Gredilla[^TL2014] and Kucukelbir *et al.*[^KTRGB2017].
+
+[^HBB2010]: Hoffman, M., Bach, F., & Blei, D. (2010). Online learning for latent Dirichlet allocation. In *Advances in Neural Information Processing Systems*, 23.
+[^HBWP2013]: Hoffman, M. D., Blei, D. M., Wang, C., & Paisley, J. (2013). Stochastic variational inference. *Journal of Machine Learning Research*, 14(1), 1303-1347.
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *Proceedings of the International Conference on Machine Learning* (pp. 1971-1979). PMLR.
+[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*, 18(14), 1-45.
+
+## Setting Up Subsampling
+
+We will consider the same hierarchical logistic regression example used in the [Basic Example](@ref basic).
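+
+The main change relative to the basic example is that the log-likelihood of a minibatch is rescaled so that it remains an unbiased estimate of the full-data log-likelihood (a sketch of the estimator; the notation is ours):
+
+```math
+\log \hat{p}(y, \theta) = \frac{n_{\text{data}}}{\lvert \mathcal{B} \rvert} \sum_{i \in \mathcal{B}} \log p(y_i \mid \theta) + \log p(\theta),
+```
+
+where $\mathcal{B}$ is the current minibatch of indices.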
+
+```@example subsampling
+using LogDensityProblems: LogDensityProblems
+using Distributions
+using FillArrays
+
+struct LogReg{XType,YType}
+    X::XType
+    y::YType
+    n_data::Int
+end
+
+function LogDensityProblems.logdensity(model::LogReg, θ)
+    (; X, y, n_data) = model
+    n, d = size(X)
+    β, σ = θ[1:d], θ[end]
+
+    logprior_β = logpdf(MvNormal(Zeros(d), σ), β)
+    logprior_σ = logpdf(LogNormal(0, 3), σ)
+
+    logit = X*β
+    loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
+    return n_data/n*loglike_y + logprior_β + logprior_σ
+end
+
+function LogDensityProblems.dimension(model::LogReg)
+    return size(model.X, 2) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:LogReg})
+    return LogDensityProblems.LogDensityOrder{0}()
+end
+nothing
+```
+
+Notice that, to use subsampling, we need to be able to rescale the likelihood strength.
+That is, for the gradient of the log-density with a batch of data points of size `n` to be an unbiased estimate of the gradient using the full dataset of size `n_data`, we need to scale the likelihood by `n_data/n`.
+This part is critical to ensure that the algorithm correctly approximates the posterior with the full dataset.
+
+As usual, we will set up a bijector:
+
+```@example subsampling
+using Bijectors: Bijectors
+
+function Bijectors.bijector(model::LogReg)
+    d = size(model.X, 2)
+    return Bijectors.Stacked(
+        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]),
+        [1:d, (d + 1):(d + 1)],
+    )
+end
+nothing
+```
+
+For the dataset, we will use one that is larger than that used in the [Basic Example](@ref basic).
+This is to properly assess the advantage of subsampling.
+In particular, we will use the "Phishing" dataset[^Tan2018], which consists of 10000 data points, each with 48 features.
+The goal is to predict whether the features of a specific website indicate that it is a phishing website or a legitimate one.
+The [dataset](https://www.openml.org/search?type=data&status=active&id=46722) id on the [`OpenML`](https://github.com/JuliaAI/OpenML.jl) repository is 46722.
+
+[^Tan2018]: Tan, Choon Lin (2018), "Phishing Dataset for Machine Learning: Feature Evaluation", Mendeley Data, V1, doi: 10.17632/h3cgnj8hft.1
+
+```@example subsampling
+using OpenML: OpenML
+using DataFrames: DataFrames
+
+data = Array(DataFrames.DataFrame(OpenML.load(46722)))
+X = Matrix{Float64}(data[:, 2:(end - 1)])
+y = Vector{Bool}(data[:, end])
+nothing
+```
+
+The features start from the second column, while the last column contains the class labels.
+
+Let's also apply some basic pre-processing.
+
+```@example subsampling
+using Statistics
+
+X = (X .- mean(X; dims=2)) ./ std(X; dims=2)
+X = hcat(X, ones(size(X, 1)))
+nothing
+```
+
+Let's now instantiate the model and set up automatic differentiation using [`LogDensityProblemsAD`](https://github.com/tpapp/LogDensityProblemsAD.jl?tab=readme-ov-file).
+
+```@example subsampling
+using ADTypes, ReverseDiff
+using LogDensityProblemsAD
+
+model = LogReg(X, y, size(X, 1))
+model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)
+nothing
+```
+
+To enable subsampling, `LogReg` has to implement the method `AdvancedVI.subsample`.
+For our model, this is fairly simple: we only need to select the rows of `X` and the elements of `y` corresponding to the batch of data points.
+A subtle point here is that we wrapped `model` with `LogDensityProblemsAD.ADgradient` into `model_ad`.
+Therefore, `AdvancedVI` sees `model_ad` and not `model`.
+This means we have to specialize `AdvancedVI.subsample` for `typeof(model_ad)` and not `LogReg`.
+
+```@example subsampling
+using Accessors
+using AdvancedVI
+
+function AdvancedVI.subsample(model::typeof(model_ad), idx)
+    (; X, y, n_data) = parent(model)
+    model′ = @set model.ℓ.X = X[idx, :]
+    model′′ = @set model′.ℓ.y = y[idx]
+    return model′′
+end
+nothing
+```
+
+!!! info
+
+    The default implementation of `AdvancedVI.subsample` is `AdvancedVI.subsample(model, idx) = model`.
+    Therefore, if the specialization of `AdvancedVI.subsample` is not set up properly, `AdvancedVI` will silently use full-batch gradients instead of subsampling.
+    It is thus useful to check whether the right specialization of `AdvancedVI.subsample` is being called.
+
+## Scalable Inference via AdvancedVI
+
+In this example, we will compare the convergence speed of `KLMinRepGradProxDescent` with and without subsampling.
+Subsampling can be turned on by supplying a subsampling strategy.
+Here, we will use `ReshufflingBatchSubsampling`, which implements random reshuffling.
+We will use a batch size of 32, which results in `313 = length(subsampling) = ceil(Int, size(X, 1)/32)` steps per epoch.
+
+```@example subsampling
+dataset = 1:size(model.X, 1)
+batchsize = 32
+subsampling = ReshufflingBatchSubsampling(dataset, batchsize)
+alg_sub = KLMinRepGradProxDescent(ADTypes.AutoReverseDiff(; compile=true); subsampling)
+nothing
+```
+
+Recall that each epoch is 313 steps.
+When using `ReshufflingBatchSubsampling`, it is best to choose the number of iterations to be a multiple of the number of steps `length(subsampling)` in an epoch.
+This is due to a peculiar property of `ReshufflingBatchSubsampling`: the objective value tends to *increase* during an epoch and come back down near the end.
+(Theoretically, this is due to the conditionally *biased* nature of random reshuffling[^MKR2020].)
+Therefore, the objective value is minimized exactly after the last step of each epoch.
+
+[^MKR2020]: Mishchenko, K., Khaled, A., & Richtárik, P. (2020). Random reshuffling: Simple analysis with vast improvements. In *Advances in Neural Information Processing Systems*, 33, 17309-17320.
+
+```@example subsampling
+num_epochs = 10
+max_iter = num_epochs * length(subsampling)
+nothing
+```
+
+If we don't supply a subsampling strategy to `KLMinRepGradProxDescent`, subsampling will not be used.
+
+```@example subsampling
+alg_full = KLMinRepGradProxDescent(ADTypes.AutoReverseDiff(; compile=true))
+nothing
+```
+
+The variational family will be set up as follows:
+
+```@example subsampling
+using LinearAlgebra
+
+d = LogDensityProblems.dimension(model_ad)
+q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
+b = Bijectors.bijector(model)
+binv = Bijectors.inverse(b)
+q_transformed = Bijectors.TransformedDistribution(q, binv)
+nothing
+```
+
+It now remains to run VI.
+For comparison, we will record both the ELBO (with a large number of Monte Carlo samples) and the prediction accuracy.
+
+```@example subsampling
+using StatsFuns: StatsFuns
+
+logging_interval = 100
+time_begin = nothing
+
+"""
+    logistic_prediction(X, μ_β, Σ_β)
+
+Approximate the posterior predictive probability for a logistic link function using MacKay's approximation (Bishop, p. 220).
+"""
+function logistic_prediction(X, μ_β, Σ_β)
+    xtΣx = sum((model.X*Σ_β) .* model.X; dims=2)[:, 1]
+    κ = @. 1/sqrt(1 + π/8*xtΣx)
+    return StatsFuns.logistic.(κ .* X*μ_β)
+end
+
+function callback(; iteration, averaged_params, restructure, kwargs...)
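+    # The callback is called with keyword arguments (here `iteration`, `averaged_params`,
+    # and `restructure`); any remaining ones are absorbed by `kwargs...`. It returns either
+    # a NamedTuple of statistics to be appended to `info`, or `nothing`.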
+ if mod(iteration, logging_interval) == 1 + + # Use the averaged parameters (the eventual output of the algorithm) + q_avg = restructure(averaged_params) + + # Compute predictions using + μ_β = mean(q_avg.dist)[1:(end - 1)] # posterior mean of β + Σ_β = cov(q_avg.dist)[1:(end - 1), end - 1] # marginal posterior covariance of β + y_pred = logistic_prediction(X, μ_β, Σ_β) .> 0.5 + + # Prediction accuracy + acc = mean(y_pred .== model.y) + + # Higher fidelity estimate of the ELBO on the averaged parameters + n_samples = 256 + obj = AdvancedVI.RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + elbo_callback = estimate_objective(obj, q_avg, model) + + (elbo_callback=elbo_callback, accuracy=acc, time_elapsed=time() - time_begin) + else + nothing + end +end + +time_begin = time() +_, info_full, _ = AdvancedVI.optimize( + alg_full, max_iter, model_ad, q_transformed; show_progress=false, callback +); + +time_begin = time() +_, info_sub, _ = AdvancedVI.optimize( + alg_sub, max_iter, model_ad, q_transformed; show_progress=false, callback +); +nothing +``` + +Let's visualize the evolution of the ELBO. + +```@example subsampling +using Plots + +t = 1:logging_interval:max_iter +plot( + [i.iteration for i in info_full[t]], + [i.elbo_callback for i in info_full[t]]; + xlabel="Iteration", + ylabel="ELBO", + label="Full Batch", +) +plot!( + [i.iteration for i in info_sub[t]], + [i.elbo_callback for i in info_sub[t]]; + label="Subsampling", +) +savefig("subsampling_example_iteration_elbo.svg") +nothing +``` + +![](subsampling_example_iteration_elbo.svg) + +According to this plot, it might seem like subsampling has no benefit (if not detrimental). +This is, however, because we are plotting against the number of iterations. +Subsampling generally converges slower (asymptotically) in terms of iterations. +But in return, it reduces the time spent at each iteration. +Therefore, we need to plot against the elapsed time: + +```@example subsampling +plot( + [i.time_elapsed for i in info_full[t]], + [i.elbo_callback for i in info_full[t]]; + xlabel="Wallclock Time (sec)", + ylabel="ELBO", + label="Full Batch", +) +plot!( + [i.time_elapsed for i in info_sub[t]], + [i.elbo_callback for i in info_sub[t]]; + label="Subsampling", +) +savefig("subsampling_example_time_elbo.svg") +nothing +``` + +![](subsampling_example_time_elbo.svg) + +We can now see the dramatic effect of subsampling. +The picture is similar if we visualize the prediction accuracy over time. + +```@example subsampling +plot( + [i.time_elapsed for i in info_full[t]], + [i.accuracy for i in info_full[t]]; + xlabel="Wallclock Time (sec)", + ylabel="Prediction Accuracy", + label="Full Batch", +) +plot!( + [i.time_elapsed for i in info_sub[t]], + [i.accuracy for i in info_sub[t]]; + label="Subsampling", +) +savefig("subsampling_example_time_accuracy.svg") +nothing +``` + +![](subsampling_example_time_accuracy.svg) + +But remember that subsampling will always be *asymptotically* slower than no subsampling. +That is, as the number of iterations increase, there will be a point where no subsampling will overtake subsampling even in terms of wallclock time. +Therefore, subsampling is most beneficial when a crude solution to the VI problem suffices.