@@ -33,6 +33,7 @@ getintegrator(state::HMCState) = state.κ.τ.integrator
3333function AbstractMCMC. getparams (state:: HMCState )
3434 return state. transition. z. θ
3535end
36+ AbstractMCMC. getstats (state:: AdvancedHMC.HMCState ) = state. transition. stat
3637
3738function AbstractMCMC. setparams!! (
3839 model:: AbstractMCMC.LogDensityModel , state:: HMCState , params
@@ -55,8 +56,8 @@ function AbstractMCMC.sample(
5556 sampler:: AbstractHMCSampler ,
5657 N:: Integer ;
5758 n_adapts:: Int = min (div (N, 10 ), 1_000 ),
58- progress= true ,
59- verbose= false ,
59+ progress:: Bool = true ,
60+ verbose:: Bool = false ,
6061 callback= nothing ,
6162 kwargs... ,
6263)
@@ -94,8 +95,8 @@ function AbstractMCMC.sample(
9495 N:: Integer ,
9596 nchains:: Integer ;
9697 n_adapts:: Int = min (div (N, 10 ), 1_000 ),
97- progress= true ,
98- verbose= false ,
98+ progress:: Bool = true ,
99+ verbose:: Bool = false ,
99100 callback= nothing ,
100101 kwargs... ,
101102)
@@ -217,7 +218,7 @@ logging behavior of the non-AbstractMCMC [`sample`](@ref).
217218# Fields
218219$(FIELDS)
219220"""
220- struct HMCProgressCallback{P}
221+ struct HMCProgressCallback{P<: Union{ProgressMeter.Progress,Nothing} }
221222 " `Progress` meter from ProgressMeters.jl, or `nothing`."
222223 pm:: P
223224 " If `pm === nothing` and this is `true` some information will be logged upon completion of adaptation."
@@ -227,12 +228,21 @@ struct HMCProgressCallback{P}
227228 num_divergent_transitions_during_adaption:: Base.RefValue{Int}
228229end
229230
230- function HMCProgressCallback (n_samples; progress= true , verbose= false )
231+ function HMCProgressCallback (n_samples:: Integer ; progress:: Bool = true , verbose:: Bool = false )
231232 pm = progress ? ProgressMeter. Progress (n_samples; desc= " Sampling" , barlen= 31 ) : nothing
232233 return HMCProgressCallback (pm, verbose, Ref (0 ), Ref (0 ))
233234end
234235
235- function (cb:: HMCProgressCallback )(rng, model, spl, t, state, i; n_adapts:: Int = 0 , kwargs... )
236+ function (cb:: HMCProgressCallback )(
237+ rng:: AbstractRNG ,
238+ model:: AbstractMCMC.LogDensityModel ,
239+ spl:: AbstractHMCSampler ,
240+ t:: Transition ,
241+ state:: HMCState ,
242+ i:: Int ;
243+ n_adapts:: Int = 0 ,
244+ kwargs... ,
245+ )
236246 verbose = cb. verbose
237247 pm = cb. pm
238248
@@ -260,16 +270,18 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0
260270 # Do include current iteration and mass matrix
261271 pm_next! (
262272 pm,
263- (
264- iterations= i,
265- ratio_divergent_transitions= round (
266- percentage_divergent_transitions; digits= 2
267- ),
268- ratio_divergent_transitions_during_adaption= round (
269- percentage_divergent_transitions_during_adaption; digits= 2
273+ merge (
274+ (;
275+ iterations= i,
276+ ratio_divergent_transitions= round (
277+ percentage_divergent_transitions; digits= 2
278+ ),
279+ ratio_divergent_transitions_during_adaption= round (
280+ percentage_divergent_transitions_during_adaption; digits= 2
281+ ),
282+ mass_matrix= metric,
270283 ),
271- tstat... ,
272- mass_matrix= metric,
284+ tstat,
273285 ),
274286 )
275287 # Report finish of adapation
0 commit comments