"""
    KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling)
    KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling)

KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton.
This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.
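
Concretely, let \$\$ m_t \$\$ and \$\$ S_t \$\$ denote the mean and precision of the current Gaussian approximation \$\$ q_t = \\mathcal{N}(m_t, S_t^{-1}) \$\$, and let \$\$ \\eta \$\$ be the stepsize. The update implemented here (without the positive-definiteness correction described below) is

\$\$ S_{t+1} = (1 - \\eta) S_t + \\eta \\, \\mathbb{E}_{q_t} \\left[ -\\nabla^2 \\log \\pi \\right], \\qquad m_{t+1} = m_t + \\eta \\, S_{t+1}^{-1} \\mathbb{E}_{q_t} \\left[ \\nabla \\log \\pi \\right] . \$\$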

If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
This, however, involves a few additional matrix-matrix products, which could be costly.
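
In the notation above, this corrected precision update (from the proof of Theorem 1 of Lin *et al.*[^LSK2020]) is

\$\$ G_t = S_t - \\mathbb{E}_{q_t} \\left[ -\\nabla^2 \\log \\pi \\right], \\qquad S_{t+1} = S_t - \\eta G_t + \\frac{\\eta^2}{2} G_t S_t^{-1} G_t , \$\$

while the mean update is unchanged.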

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we estimate the expected Hessian using Stein's identity.
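
For reference, a standard form of Stein's identity for a Gaussian \$\$ q = \\mathcal{N}(m, \\Sigma) \$\$ is

\$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi = \\Sigma^{-1} \\, \\mathbb{E}_q \\left[ (x - m) \\nabla \\log \\pi (x)^{\\top} \\right] , \$\$

which allows the expected Hessian to be estimated from gradient evaluations alone.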

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`)
- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! note
    The `subsampling` strategy is applied only to the target `LogDensityProblem`, not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

    callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.
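
For example, a simple callback that records the ELBO estimate reported in `info` at every iteration could look as follows (a sketch; unused keywords are captured by `kwargs...`, and a callback may also return a `NamedTuple`, which is merged into `info`):

    elbo_history = Float64[]
    record_elbo(; info, kwargs...) = (push!(elbo_history, info.elbo); nothing)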

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    stepsize::Float64
    n_samples::Int = 1
    ensure_posdef::Bool = true
    subsampling::Sub = nothing
end

struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov,GradBuf,HessBuf}
    q::Q
    prob::P
    prec::Prec
    qcov::QCov
    iteration::Int
    sub_st::S
    grad_buf::GradBuf
    hess_buf::HessBuf
end

function init(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q_init::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    sub = alg.subsampling
    n_dims = LogDensityProblems.dimension(prob)
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        throw(
            ArgumentError(
                "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
            ),
        )
    end
    sub_st = isnothing(sub) ? nothing : init(rng, sub)
    grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
    hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
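    # Initialize the precision S = Σ⁻¹ and covariance Σ = L*L' from the
    # lower-triangular scale factor L of the initial variational approximation.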
    scale = q_init.scale
    qcov = Hermitian(scale*scale')
    scale_inv = inv(scale)
    prec_chol = scale_inv'
    prec = Hermitian(prec_chol*prec_chol')
    return KLMinNaturalGradDescentState(
        q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf
    )
end

output(::KLMinNaturalGradDescent, state) = state.q

function step(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; ensure_posdef, n_samples, stepsize, subsampling) = alg
    (; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state

    m = mean(q)
    S = prec
    η = convert(eltype(m), stepsize)
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

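    # Monte Carlo estimates of E_q[log π], E_q[∇ log π] (grad_buf), and E_q[∇² log π] (hess_buf)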
    logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!(
        rng, q, n_samples, grad_buf, hess_buf, prob_sub
    )

    S′ = if ensure_posdef
        # Update rule guaranteeing positive definiteness in the proof of Theorem 1 of:
        # Lin, W., Schmidt, M., & Khan, M. E.
        # Handling the positive-definite constraint in the Bayesian learning rule.
        # In ICML 2020.
        G_hat = S - Symmetric(-hess_buf)
        Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat)
    else
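        # Standard natural gradient (VON) update of the precision:
        # S′ = (1 - η) S + η E_q[-∇² log π]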
        Hermitian((1 - η) * S + η * Symmetric(-hess_buf))
    end
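    # Mean update using the new precision: m′ = m + η S′⁻¹ E_q[∇ log π]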
    m′ = m - η * (S′ \ (-grad_buf))

    # Recover the new scale factor and covariance from the updated precision S′
    prec_chol = cholesky(S′).L
    prec_chol_inv = inv(prec_chol)
    scale = prec_chol_inv'
    qcov = Hermitian(scale*scale')
    q′ = MvLocationScale(m′, scale, q.dist)

    state = KLMinNaturalGradDescentState(
        q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
    )
    elbo = logπ_avg + entropy(q′)
    info = merge((elbo=elbo,), sub_inf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q=q′, info)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end

"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::KLMinNaturalGradDescent`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy())
    if isnothing(alg.subsampling)
        return estimate_objective(rng, obj, q, prob)
    else
        sub = alg.subsampling
        sub_st = init(rng, sub)
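        # Average the per-batch ELBO estimates over one full pass of the subsampling schedule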
        return mapreduce(+, 1:length(sub)) do _
            batch, sub_st, _ = step(rng, sub, sub_st)
            prob_sub = subsample(prob, batch)
            estimate_objective(rng, obj, q, prob_sub) / length(sub)
        end
    end
end