@@ -55,8 +55,8 @@ function AbstractMCMC.sample(
5555 sampler:: AbstractHMCSampler ,
5656 N:: Integer ;
5757 n_adapts:: Int = min (div (N, 10 ), 1_000 ),
58- progress= true ,
59- verbose= false ,
58+ progress:: Bool = true ,
59+ verbose:: Bool = false ,
6060 callback= nothing ,
6161 kwargs... ,
6262)
@@ -94,8 +94,8 @@ function AbstractMCMC.sample(
9494 N:: Integer ,
9595 nchains:: Integer ;
9696 n_adapts:: Int = min (div (N, 10 ), 1_000 ),
97- progress= true ,
98- verbose= false ,
97+ progress:: Bool = true ,
98+ verbose:: Bool = false ,
9999 callback= nothing ,
100100 kwargs... ,
101101)
@@ -217,7 +217,7 @@ logging behavior of the non-AbstractMCMC [`sample`](@ref).
217217# Fields
218218$(FIELDS)
219219"""
220- struct HMCProgressCallback{P}
220+ struct HMCProgressCallback{P<: Union{ProgressMeter.Progress,Nothing} }
221221 " `Progress` meter from ProgressMeters.jl, or `nothing`."
222222 pm:: P
223223 " If `pm === nothing` and this is `true` some information will be logged upon completion of adaptation."
@@ -227,12 +227,21 @@ struct HMCProgressCallback{P}
227227 num_divergent_transitions_during_adaption:: Base.RefValue{Int}
228228end
229229
230- function HMCProgressCallback (n_samples; progress= true , verbose= false )
230+ function HMCProgressCallback (n_samples:: Integer ; progress:: Bool = true , verbose:: Bool = false )
231231 pm = progress ? ProgressMeter. Progress (n_samples; desc= " Sampling" , barlen= 31 ) : nothing
232232 return HMCProgressCallback (pm, verbose, Ref (0 ), Ref (0 ))
233233end
234234
235- function (cb:: HMCProgressCallback )(rng, model, spl, t, state, i; n_adapts:: Int = 0 , kwargs... )
235+ function (cb:: HMCProgressCallback )(
236+ rng:: AbstractRNG ,
237+ model:: AbstractMCMC.LogDensityModel ,
238+ spl:: AbstractHMCSampler ,
239+ t:: Transition ,
240+ state:: HMCState ,
241+ i:: Int ;
242+ n_adapts:: Int = 0 ,
243+ kwargs... ,
244+ )
236245 verbose = cb. verbose
237246 pm = cb. pm
238247
@@ -260,16 +269,18 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0
260269 # Do include current iteration and mass matrix
261270 pm_next! (
262271 pm,
263- (
264- iterations= i,
265- ratio_divergent_transitions= round (
266- percentage_divergent_transitions; digits= 2
267- ),
268- ratio_divergent_transitions_during_adaption= round (
269- percentage_divergent_transitions_during_adaption; digits= 2
272+ merge (
273+ (;
274+ iterations= i,
275+ ratio_divergent_transitions= round (
276+ percentage_divergent_transitions; digits= 2
277+ ),
278+ ratio_divergent_transitions_during_adaption= round (
279+ percentage_divergent_transitions_during_adaption; digits= 2
280+ ),
281+ mass_matrix= metric,
270282 ),
271- tstat... ,
272- mass_matrix= metric,
283+ tstat,
273284 ),
274285 )
275286 # Report finish of adapation
0 commit comments