Skip to content

Commit 49236af

Browse files
committed
fix the positive definite preserving update rule in NGVI
1 parent f7f965a commit 49236af

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

src/algorithms/gauss_expected_grad_hess.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Otherwise, it uses Stein's identity.
1919
"""
2020
function gaussian_expectation_gradient_and_hessian!(
2121
rng::Random.AbstractRNG,
22-
q::MvLocationScale{<:LowerTriangular,<:Normal,L},
22+
q::MvLocationScale{<:LinearAlgebra.AbstractTriangular,<:Normal,L},
2323
n_samples::Int,
2424
grad_buf::AbstractVector{T},
2525
hess_buf::AbstractMatrix{T},

src/algorithms/klminnaturalgraddescent.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@ The keyword arguments are as follows:
4949
subsampling::Sub = nothing
5050
end
5151

52-
struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf}
52+
struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov,GradBuf,HessBuf}
5353
q::Q
5454
prob::P
5555
prec::Prec
56+
qcov::QCov
5657
iteration::Int
5758
sub_st::S
5859
grad_buf::GradBuf
@@ -78,8 +79,13 @@ function init(
7879
sub_st = isnothing(sub) ? nothing : init(rng, sub)
7980
grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
8081
hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
82+
scale = q_init.scale
83+
qcov = Hermitian(scale*scale')
84+
scale_inv = inv(scale)
85+
prec_chol = scale_inv'
86+
prec = Hermitian(prec_chol*prec_chol')
8187
return KLMinNaturalGradDescentState(
82-
q_init, prob, inv(cov(q_init)), 0, sub_st, grad_buf, hess_buf
88+
q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf
8389
)
8490
end
8591

@@ -94,7 +100,7 @@ function step(
94100
kwargs...,
95101
)
96102
(; ensure_posdef, n_samples, stepsize, subsampling) = alg
97-
(; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state
103+
(; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state
98104

99105
m = mean(q)
100106
S = prec
@@ -114,17 +120,26 @@ function step(
114120
rng, q, n_samples, grad_buf, hess_buf, prob_sub
115121
)
116122

117-
S′ = Hermitian(((1 - η) * S + η * Symmetric(-hess_buf)))
118-
if ensure_posdef
123+
S′ = if ensure_posdef
124+
# Update rule guaranteeing positive definiteness in the proof of Theorem 1.
125+
# Lin, W., Schmidt, M., & Khan, M. E.
126+
# Handling the positive-definite constraint in the Bayesian learning rule.
127+
# In ICML 2020.
119128
G_hat = S - Symmetric(-hess_buf)
120-
S′ += η^2 / 2 * Symmetric(G_hat * (S′ \ G_hat))
129+
Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat)
130+
else
131+
Hermitian(((1 - η) * S + η * Symmetric(-hess_buf)))
121132
end
122133
m′ = m - η * (S′ \ (-grad_buf))
123134

124-
q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist)
135+
prec_chol = cholesky(S′).L
136+
prec_chol_inv = inv(prec_chol)
137+
scale = prec_chol_inv'
138+
qcov = Hermitian(scale*scale')
139+
q′ = MvLocationScale(m′, scale, q.dist)
125140

126141
state = KLMinNaturalGradDescentState(
127-
q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf
142+
q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
128143
)
129144
elbo = logπ_avg + entropy(q′)
130145
info = merge((elbo=elbo,), sub_inf)

0 commit comments

Comments
 (0)