You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -158,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
158
136
returnnothing
159
137
end
160
138
161
-
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T}
139
+
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat}
162
140
wv.n +=1
163
141
(; δ, μ, M, n) = wv
164
142
n =T(n)
@@ -176,8 +154,13 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
176
154
return n / ((n +5) * (n -1)) * M .+ ϵ * (5/ (n +5))
177
155
end
178
156
179
-
## Nutpie-style diagonal mass matrix estimator (using positions and gradients) - not exported yet due to https://github.com/TuringLang/AdvancedHMC.jl/issues/475
157
+
"""
158
+
Nutpie-style diagonal mass matrix estimator (using positions and gradients) - not exported yet due to https://github.com/TuringLang/AdvancedHMC.jl/issues/475
180
159
160
+
Expected to converge faster and to a better mass matrix than WelfordVar.
161
+
162
+
Can be initialized via NutpieVar(sz) where sz is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
Computes the condition number of a covariance matrix `cov::AbstractMatrix` after preconditioning with the (diagonal) mass matrix estimated in `a::DiagMatrixEstimator`.
45
+
46
+
This is a simple but serviceable proxy for eventual sampling efficiency, but see also https://arxiv.org/abs/1905.09813 for a more involved estimate.
47
+
48
+
(A lower number generally means that the estimated mass matrix is better).
0 commit comments