Skip to content

Commit 76a1373

Browse files
committed
improve docstring for NutpieVar
1 parent f1d1c80 commit 76a1373

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/adaptation/massmatrix.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,28 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
155155
end
156156

157157
"""
158+
NutpieVar
159+
158160
Nutpie-style diagonal mass matrix estimator (using positions and gradients) - not exported yet due to https://github.com/TuringLang/AdvancedHMC.jl/issues/475
159161
160-
Expected to converge faster and to a better mass matrix than WelfordVar.
162+
Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement.
163+
164+
Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
161165
162-
Can be initialized via NutpieVar(sz) where sz is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
166+
# Fields
167+
168+
$(FIELDS)
163169
"""
164170
mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
171+
"Online variance estimator of the posterior positions."
165172
position_estimator::WelfordVar{T,E,V}
173+
"Online variance estimator of the posterior gradients."
166174
gradient_estimator::WelfordVar{T,E,V}
175+
"The number of observations collected so far."
167176
n::Int
177+
"The minimal number of observations after which the estimate of the variances can be updated."
168178
n_min::Int
179+
"The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`."
169180
var::V
170181
function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V}
171182
return new{eltype(E),E,V}(
@@ -225,9 +236,7 @@ function Base.push!(nv::NutpieVar, z::PhasePoint)
225236
end
226237

227238
# Ref: https://github.com/pymc-devs/nutpie
228-
function get_estimation(nv::NutpieVar)
229-
return sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator))
230-
end
239+
get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator))
231240

232241
## Dense mass matrix adaptor
233242

0 commit comments

Comments
 (0)