Skip to content

Commit 2f5b6c1

Browse files
authored
Fix JET errors (#468)
1 parent e737660 commit 2f5b6c1

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

src/abstractmcmc.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ function AbstractMCMC.sample(
5555
sampler::AbstractHMCSampler,
5656
N::Integer;
5757
n_adapts::Int=min(div(N, 10), 1_000),
58-
progress=true,
59-
verbose=false,
58+
progress::Bool=true,
59+
verbose::Bool=false,
6060
callback=nothing,
6161
kwargs...,
6262
)
@@ -94,8 +94,8 @@ function AbstractMCMC.sample(
9494
N::Integer,
9595
nchains::Integer;
9696
n_adapts::Int=min(div(N, 10), 1_000),
97-
progress=true,
98-
verbose=false,
97+
progress::Bool=true,
98+
verbose::Bool=false,
9999
callback=nothing,
100100
kwargs...,
101101
)
@@ -217,7 +217,7 @@ logging behavior of the non-AbstractMCMC [`sample`](@ref).
217217
# Fields
218218
$(FIELDS)
219219
"""
220-
struct HMCProgressCallback{P}
220+
struct HMCProgressCallback{P<:Union{ProgressMeter.Progress,Nothing}}
221221
"`Progress` meter from ProgressMeters.jl, or `nothing`."
222222
pm::P
223223
"If `pm === nothing` and this is `true` some information will be logged upon completion of adaptation."
@@ -227,12 +227,21 @@ struct HMCProgressCallback{P}
227227
num_divergent_transitions_during_adaption::Base.RefValue{Int}
228228
end
229229

230-
function HMCProgressCallback(n_samples; progress=true, verbose=false)
230+
function HMCProgressCallback(n_samples::Integer; progress::Bool=true, verbose::Bool=false)
231231
pm = progress ? ProgressMeter.Progress(n_samples; desc="Sampling", barlen=31) : nothing
232232
return HMCProgressCallback(pm, verbose, Ref(0), Ref(0))
233233
end
234234

235-
function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0, kwargs...)
235+
function (cb::HMCProgressCallback)(
236+
rng::AbstractRNG,
237+
model::AbstractMCMC.LogDensityModel,
238+
spl::AbstractHMCSampler,
239+
t::Transition,
240+
state::HMCState,
241+
i::Int;
242+
n_adapts::Int=0,
243+
kwargs...,
244+
)
236245
verbose = cb.verbose
237246
pm = cb.pm
238247

@@ -260,16 +269,18 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0
260269
# Do include current iteration and mass matrix
261270
pm_next!(
262271
pm,
263-
(
264-
iterations=i,
265-
ratio_divergent_transitions=round(
266-
percentage_divergent_transitions; digits=2
267-
),
268-
ratio_divergent_transitions_during_adaption=round(
269-
percentage_divergent_transitions_during_adaption; digits=2
272+
merge(
273+
(;
274+
iterations=i,
275+
ratio_divergent_transitions=round(
276+
percentage_divergent_transitions; digits=2
277+
),
278+
ratio_divergent_transitions_during_adaption=round(
279+
percentage_divergent_transitions_during_adaption; digits=2
280+
),
281+
mass_matrix=metric,
270282
),
271-
tstat...,
272-
mass_matrix=metric,
283+
tstat,
273284
),
274285
)
275286
# Report finish of adapation

src/sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,15 @@ end
9393
"""
9494
Progress meter update with all trajectory stats, iteration number and metric shown.
9595
"""
96-
function pm_next!(pm, stat::NamedTuple)
96+
function pm_next!(pm::ProgressMeter.Progress, stat::NamedTuple)
9797
ProgressMeter.next!(pm; showvalues=map(tuple, values(stat), keys(stat)))
9898
return nothing
9999
end
100100

101101
"""
102102
Simple progress meter update without any show values.
103103
"""
104-
simple_pm_next!(pm, stat::NamedTuple) = ProgressMeter.next!(pm)
104+
simple_pm_next!(pm::ProgressMeter.Progress, ::NamedTuple) = ProgressMeter.next!(pm)
105105

106106
##
107107
## Sampling functions

0 commit comments

Comments
 (0)