
Commit 1c6f1de

Authored by Red-Portal, github-actions[bot], and sunxd3
Add natural gradient variational inference algorithms (#211)
* move gaussian expectation of grad and hess to its own file
* add square-root variational newton algorithm
* apply formatter
* add natural gradient descent (variational online Newton)
* update docstrings, remove redundant comments
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* update history
* fix gauss expected grad hess, use in-place operations, add tests
* fix always wrap `hess_buf` with a `Symmetric` (not `Hermitian`)
* Apply suggestion from @sunxd3
  Co-authored-by: Xianda Sun <[email protected]>
* Apply suggestion from @sunxd3
  Co-authored-by: Xianda Sun <[email protected]>
* Apply suggestion from @github-actions[bot]
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fix bug in init of klminnaturalgraddescent
* remove unintended benchmark code
* update docs
* fix relax Hermitian to Symmetric in NGVI ensure posdef
* fix gauss expected grad hess
* fix callback argument in measure space algorithms
* fix the positive definite preserving update rule in NGVI

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
1 parent b34fa31 commit 1c6f1de

File tree

10 files changed, +826 -60 lines changed


HISTORY.md

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ This update adds new variational inference algorithms in light of the flexibility
 Specifically, the following measure-space optimization algorithms have been added:

 - `KLMinWassFwdBwd`
+- `KLMinNaturalGradDescent`
+- `KLMinSqrtNaturalGradDescent`

 In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, and `KLMinScoreGradDescent` will now throw a `RuntimeException` if the objective value estimated at each step turns out to be degenerate (`Inf` or `NaN`). Previously, the algorithms ran until `max_iter` even if the optimization run had failed.
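For context, a minimal end-to-end sketch of how the new algorithm is meant to be used, based on the docstrings in this commit. The toy `StdNormalTarget` problem, the `FullRankGaussian` constructor call, and the exact `AdvancedVI.optimize` signature below are illustrative assumptions rather than part of this diff; consult the package documentation for the released API.

    using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

    # Hypothetical toy target with first-order capability: a standard multivariate normal.
    struct StdNormalTarget
        dim::Int
    end
    LogDensityProblems.logdensity(::StdNormalTarget, x) = -sum(abs2, x) / 2
    LogDensityProblems.dimension(p::StdNormalTarget) = p.dim
    LogDensityProblems.capabilities(::Type{StdNormalTarget}) = LogDensityProblems.LogDensityOrder{1}()
    LogDensityProblems.logdensity_and_gradient(::StdNormalTarget, x) = (-sum(abs2, x) / 2, -x)

    d    = 5
    prob = StdNormalTarget(d)
    alg  = KLMinNaturalGradDescent(; stepsize=0.1, n_samples=4)
    q0   = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))

    # Record the per-iteration ELBO through the documented callback signature
    # callback(; rng, iteration, q, info); unused keywords are absorbed by kwargs...
    elbos = Float64[]
    cb(; iteration, q, info, kwargs...) = (push!(elbos, info.elbo); nothing)

    # Assumed entry point and keyword name; the algorithm's documented output is the last iterate `q`.
    result = AdvancedVI.optimize(Random.default_rng(), alg, 1_000, prob, q0; callback=cb)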

src/AdvancedVI.jl

Lines changed: 5 additions & 2 deletions
@@ -352,10 +352,13 @@ include("algorithms/common.jl")

 export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

-# Other Algorithms
+# Natural and Wasserstein gradient descent algorithms

+include("algorithms/gauss_expected_grad_hess.jl")
 include("algorithms/klminwassfwdbwd.jl")
+include("algorithms/klminsqrtnaturalgraddescent.jl")
+include("algorithms/klminnaturalgraddescent.jl")

-export KLMinWassFwdBwd
+export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent

 end
src/algorithms/gauss_expected_grad_hess.jl (new file)

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
"""
    gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob)

Estimate the expectations of the gradient and Hessian of the log-density of `prob` taken over the Gaussian `q`.
For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian.
Otherwise, it uses Stein's identity.

!!! warning
    The resulting `hess_buf` may not be perfectly symmetric due to numerical issues. It is therefore useful to wrap it in a `Symmetric` before usage.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take the expectation over.
- `n_samples::Int`: Number of samples used for estimation.
- `grad_buf::AbstractVector`: Buffer for the gradient estimate.
- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate.
- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation.
"""
function gaussian_expectation_gradient_and_hessian!(
    rng::Random.AbstractRNG,
    q::MvLocationScale{<:LinearAlgebra.AbstractTriangular,<:Normal,L},
    n_samples::Int,
    grad_buf::AbstractVector{T},
    hess_buf::AbstractMatrix{T},
    prob,
) where {T<:Real,L}
    logπ_avg = zero(T)
    fill!(grad_buf, zero(T))
    fill!(hess_buf, zero(T))

    if LogDensityProblems.capabilities(typeof(prob)) ==
        LogDensityProblems.LogDensityOrder{1}()
        # First-order-only: use Stein/Price identity for the Hessian
        #
        #   E_{z ~ N(m, CC')} ∇²log π(z)
        #     = E_{z ~ N(m, CC')} (CC')^{-1} (z - m) ∇log π(z)ᵀ
        #     = E_{u ~ N(0, I)} C' \ (u ∇log π(z)ᵀ) .
        #
        # Algorithmically, draw u ~ N(0, I), z = C u + m, where C = q.scale.
        # Accumulate A = E[ u ∇log π(z)ᵀ ], then map back: H = C' \ A.
        d = LogDensityProblems.dimension(prob)
        u = randn(rng, T, d, n_samples)
        m, C = q.location, q.scale
        z = C*u .+ m
        for b in 1:n_samples
            zb, ub = view(z, :, b), view(u, :, b)
            logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
            logπ_avg += logπ/n_samples

            rdiv!(∇logπ, n_samples)
            ∇logπ_div_nsamples = ∇logπ

            grad_buf[:] .+= ∇logπ_div_nsamples
            hess_buf[:, :] .+= ub*∇logπ_div_nsamples'
        end
        hess_buf[:, :] .= C' \ hess_buf
        return logπ_avg, grad_buf, hess_buf
    else
        # Second-order: use naive sample average
        z = rand(rng, q, n_samples)
        for b in 1:n_samples
            zb = view(z, :, b)
            logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
                prob, zb
            )

            rdiv!(∇logπ, n_samples)
            ∇logπ_div_nsamples = ∇logπ

            rdiv!(∇2logπ, n_samples)
            ∇2logπ_div_nsamples = ∇2logπ

            logπ_avg += logπ/n_samples
            grad_buf[:] .+= ∇logπ_div_nsamples
            hess_buf[:, :] .+= ∇2logπ_div_nsamples
        end
        return logπ_avg, grad_buf, hess_buf
    end
end
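As a quick sanity check of the Stein-identity branch above, the expectations have a closed form when the target itself is Gaussian: for π = N(0, I), E_q[∇log π(z)] = -mean(q) and E_q[∇²log π(z)] = -I. The sketch below is an assumption-laden illustration: `gaussian_expectation_gradient_and_hessian!` is an internal (unexported) helper, the toy `StdNormalFirstOrder` problem is hypothetical, and the `FullRankGaussian` constructor call is assumed from the package's documented variational family.

    using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

    # Hypothetical first-order-only target: log π(z) = -‖z‖²/2, so ∇log π(z) = -z and ∇²log π(z) = -I.
    struct StdNormalFirstOrder
        dim::Int
    end
    LogDensityProblems.logdensity(::StdNormalFirstOrder, x) = -sum(abs2, x) / 2
    LogDensityProblems.dimension(p::StdNormalFirstOrder) = p.dim
    LogDensityProblems.capabilities(::Type{StdNormalFirstOrder}) = LogDensityProblems.LogDensityOrder{1}()
    LogDensityProblems.logdensity_and_gradient(::StdNormalFirstOrder, x) = (-sum(abs2, x) / 2, -x)

    d = 3
    q = FullRankGaussian(randn(d), LowerTriangular(Matrix(0.5I, d, d)))
    grad_buf, hess_buf = zeros(d), zeros(d, d)

    logπ_avg, g, H = AdvancedVI.gaussian_expectation_gradient_and_hessian!(
        Random.default_rng(), q, 10_000, grad_buf, hess_buf, StdNormalFirstOrder(d)
    )

    # Up to Monte Carlo error, g ≈ -mean(q) and Symmetric(H) ≈ -I.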
src/algorithms/klminnaturalgraddescent.jl (new file)

Lines changed: 190 additions & 0 deletions

@@ -0,0 +1,190 @@
"""
    KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling)
    KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling)

KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton.
This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.

If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
This, however, involves an additional set of matrix-matrix system solves that could be costly.

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we use Stein's identity.

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`)
- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy.

!!! note
    The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

    callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    stepsize::Float64
    n_samples::Int = 1
    ensure_posdef::Bool = true
    subsampling::Sub = nothing
end

struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov,GradBuf,HessBuf}
    q::Q
    prob::P
    prec::Prec
    qcov::QCov
    iteration::Int
    sub_st::S
    grad_buf::GradBuf
    hess_buf::HessBuf
end

function init(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q_init::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    sub = alg.subsampling
    n_dims = LogDensityProblems.dimension(prob)
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        throw(
            ArgumentError(
                "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
            ),
        )
    end
    sub_st = isnothing(sub) ? nothing : init(rng, sub)
    grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
    hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
    scale = q_init.scale
    qcov = Hermitian(scale*scale')
    scale_inv = inv(scale)
    prec_chol = scale_inv'
    prec = Hermitian(prec_chol*prec_chol')
    return KLMinNaturalGradDescentState(
        q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf
    )
end

output(::KLMinNaturalGradDescent, state) = state.q

function step(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; ensure_posdef, n_samples, stepsize, subsampling) = alg
    (; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state

    m = mean(q)
    S = prec
    η = convert(eltype(m), stepsize)
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

    logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!(
        rng, q, n_samples, grad_buf, hess_buf, prob_sub
    )

    S′ = if ensure_posdef
        # Update rule guaranteeing positive definiteness in the proof of Theorem 1.
        # Lin, W., Schmidt, M., & Khan, M. E.
        # Handling the positive-definite constraint in the Bayesian learning rule.
        # In ICML 2020.
        G_hat = S - Symmetric(-hess_buf)
        Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat)
    else
        Hermitian(((1 - η) * S + η * Symmetric(-hess_buf)))
    end
    m′ = m - η * (S′ \ (-grad_buf))

    prec_chol = cholesky(S′).L
    prec_chol_inv = inv(prec_chol)
    scale = prec_chol_inv'
    qcov = Hermitian(scale*scale')
    q′ = MvLocationScale(m′, scale, q.dist)

    state = KLMinNaturalGradDescentState(
        q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
    )
    elbo = logπ_avg + entropy(q′)
    info = merge((elbo=elbo,), sub_inf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q=q′, info)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end

"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::KLMinNaturalGradDescent`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy())
    if isnothing(alg.subsampling)
        return estimate_objective(rng, obj, q, prob)
    else
        sub = alg.subsampling
        sub_st = init(rng, sub)
        return mapreduce(+, 1:length(sub)) do _
            batch, sub_st, _ = step(rng, sub, sub_st)
            prob_sub = subsample(prob, batch)
            estimate_objective(rng, obj, q, prob_sub) / length(sub)
        end
    end
end
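For readability, the precision and mean updates implemented in `step` above can be written out explicitly. The following is a transcription of the code using ad-hoc notation (it is not taken from the package documentation): g_t and H_t denote the Monte Carlo estimates of the expected gradient and Hessian returned by `gaussian_expectation_gradient_and_hessian!`, S_t the current precision, Σ_t = S_t⁻¹ the current covariance, and η the stepsize.

    \begin{aligned}
    g_t &= \mathbb{E}_{q_t}\,\nabla \log \pi(z), \qquad
    H_t  = \mathbb{E}_{q_t}\,\nabla^2 \log \pi(z), \qquad
    \Sigma_t = S_t^{-1}, \\[2pt]
    \text{ensure\_posdef = false:}\quad
    S_{t+1} &= (1-\eta)\,S_t + \eta\,(-H_t), \\[2pt]
    \text{ensure\_posdef = true:}\quad
    \widehat{G}_t &= S_t - (-H_t), \qquad
    S_{t+1} = S_t - \eta\,\widehat{G}_t + \tfrac{\eta^2}{2}\,\widehat{G}_t\,\Sigma_t\,\widehat{G}_t, \\[2pt]
    m_{t+1} &= m_t - \eta\, S_{t+1}^{-1}\,(-g_t).
    \end{aligned}

The new scale is then recovered from the Cholesky factor of S_{t+1} (its transposed inverse), so the iterate remains a full-rank Gaussian with covariance S_{t+1}^{-1}.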
