@@ -49,10 +49,11 @@ The keyword arguments are as follows:
4949 subsampling:: Sub = nothing
5050end
5151
52- struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf}
52+ struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov, GradBuf,HessBuf}
5353 q:: Q
5454 prob:: P
5555 prec:: Prec
56+ qcov:: QCov
5657 iteration:: Int
5758 sub_st:: S
5859 grad_buf:: GradBuf
@@ -78,8 +79,13 @@ function init(
7879 sub_st = isnothing (sub) ? nothing : init (rng, sub)
7980 grad_buf = Vector {eltype(q_init.location)} (undef, n_dims)
8081 hess_buf = Matrix {eltype(q_init.location)} (undef, n_dims, n_dims)
82+ scale = q_init. scale
83+ qcov = Hermitian (scale* scale' )
84+ scale_inv = inv (scale)
85+ prec_chol = scale_inv'
86+ prec = Hermitian (prec_chol* prec_chol' )
8187 return KLMinNaturalGradDescentState (
82- q_init, prob, inv ( cov (q_init)) , 0 , sub_st, grad_buf, hess_buf
88+ q_init, prob, prec, qcov , 0 , sub_st, grad_buf, hess_buf
8389 )
8490end
8591
@@ -94,7 +100,7 @@ function step(
94100 kwargs... ,
95101)
96102 (; ensure_posdef, n_samples, stepsize, subsampling) = alg
97- (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state
103+ (; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state
98104
99105 m = mean (q)
100106 S = prec
@@ -114,17 +120,26 @@ function step(
114120 rng, q, n_samples, grad_buf, hess_buf, prob_sub
115121 )
116122
117- S′ = Hermitian (((1 - η) * S + η * Symmetric (- hess_buf)))
118- if ensure_posdef
123+ S′ = if ensure_posdef
124+ # Udpate rule guaranteeing positive definiteness in the proof of Theorem 1.
125+ # Lin, W., Schmidt, M., & Khan, M. E.
126+ # Handling the positive-definite constraint in the Bayesian learning rule.
127+ # In ICML 2020.
119128 G_hat = S - Symmetric (- hess_buf)
120- S′ += η^ 2 / 2 * Symmetric (G_hat * (S′ \ G_hat))
129+ Hermitian (S - η* G_hat + η^ 2 / 2 * G_hat* qcov* G_hat)
130+ else
131+ Hermitian (((1 - η) * S + η * Symmetric (- hess_buf)))
121132 end
122133 m′ = m - η * (S′ \ (- grad_buf))
123134
124- q′ = MvLocationScale (m′, inv (cholesky (S′). L), q. dist)
135+ prec_chol = cholesky (S′). L
136+ prec_chol_inv = inv (prec_chol)
137+ scale = prec_chol_inv'
138+ qcov = Hermitian (scale* scale' )
139+ q′ = MvLocationScale (m′, scale, q. dist)
125140
126141 state = KLMinNaturalGradDescentState (
127- q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf
142+ q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
128143 )
129144 elbo = logπ_avg + entropy (q′)
130145 info = merge ((elbo= elbo,), sub_inf)
0 commit comments