Skip to content

Commit 6df40b5

Browse files
authored
Merge branch 'main' into qqy/NEW_RMHMC
2 parents 567e2a8 + dc8dc1c commit 6df40b5

20 files changed

+113
-69
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- macOS-latest
3232
- windows-latest
3333
steps:
34-
- uses: actions/checkout@v4
34+
- uses: actions/checkout@v5
3535
- uses: julia-actions/setup-julia@v2
3636
with:
3737
version: ${{ matrix.version }}

.github/workflows/DocsNav.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717

1818
steps:
1919
- name: Checkout gh-pages branch
20-
uses: actions/checkout@v4
20+
uses: actions/checkout@v5
2121
with:
2222
ref: gh-pages
2323

.github/workflows/ExperimentalTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- macOS-latest
2929
- windows-latest
3030
steps:
31-
- uses: actions/checkout@v4
31+
- uses: actions/checkout@v5
3232
- uses: julia-actions/setup-julia@v2
3333
with:
3434
version: ${{ matrix.version }}

.github/workflows/Format.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
format:
1717
runs-on: ubuntu-latest
1818
steps:
19-
- uses: actions/checkout@v4
19+
- uses: actions/checkout@v5
2020
- uses: julia-actions/setup-julia@v2
2121
with:
2222
version: 1

.github/workflows/IntegrationTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- macOS-latest
2929
- windows-latest
3030
steps:
31-
- uses: actions/checkout@v4
31+
- uses: actions/checkout@v5
3232
- uses: julia-actions/setup-julia@v2
3333
with:
3434
version: ${{ matrix.version }}

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.8.0"
3+
version = "0.8.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -32,7 +32,7 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3232

3333
[compat]
3434
ADTypes = "1"
35-
AbstractMCMC = "5.6"
35+
AbstractMCMC = "5.9"
3636
ArgCheck = "1, 2"
3737
CUDA = "3, 4, 5"
3838
ComponentArrays = "0.15"

src/abstractmcmc.jl

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ getintegrator(state::HMCState) = state.κ.τ.integrator
3333
function AbstractMCMC.getparams(state::HMCState)
3434
return state.transition.z.θ
3535
end
36+
AbstractMCMC.getstats(state::AdvancedHMC.HMCState) = state.transition.stat
3637

3738
function AbstractMCMC.setparams!!(
3839
model::AbstractMCMC.LogDensityModel, state::HMCState, params
@@ -55,8 +56,8 @@ function AbstractMCMC.sample(
5556
sampler::AbstractHMCSampler,
5657
N::Integer;
5758
n_adapts::Int=min(div(N, 10), 1_000),
58-
progress=true,
59-
verbose=false,
59+
progress::Bool=true,
60+
verbose::Bool=false,
6061
callback=nothing,
6162
kwargs...,
6263
)
@@ -94,8 +95,8 @@ function AbstractMCMC.sample(
9495
N::Integer,
9596
nchains::Integer;
9697
n_adapts::Int=min(div(N, 10), 1_000),
97-
progress=true,
98-
verbose=false,
98+
progress::Bool=true,
99+
verbose::Bool=false,
99100
callback=nothing,
100101
kwargs...,
101102
)
@@ -217,7 +218,7 @@ logging behavior of the non-AbstractMCMC [`sample`](@ref).
217218
# Fields
218219
$(FIELDS)
219220
"""
220-
struct HMCProgressCallback{P}
221+
struct HMCProgressCallback{P<:Union{ProgressMeter.Progress,Nothing}}
221222
"`Progress` meter from ProgressMeters.jl, or `nothing`."
222223
pm::P
223224
"If `pm === nothing` and this is `true` some information will be logged upon completion of adaptation."
@@ -227,12 +228,21 @@ struct HMCProgressCallback{P}
227228
num_divergent_transitions_during_adaption::Base.RefValue{Int}
228229
end
229230

230-
function HMCProgressCallback(n_samples; progress=true, verbose=false)
231+
function HMCProgressCallback(n_samples::Integer; progress::Bool=true, verbose::Bool=false)
231232
pm = progress ? ProgressMeter.Progress(n_samples; desc="Sampling", barlen=31) : nothing
232233
return HMCProgressCallback(pm, verbose, Ref(0), Ref(0))
233234
end
234235

235-
function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0, kwargs...)
236+
function (cb::HMCProgressCallback)(
237+
rng::AbstractRNG,
238+
model::AbstractMCMC.LogDensityModel,
239+
spl::AbstractHMCSampler,
240+
t::Transition,
241+
state::HMCState,
242+
i::Int;
243+
n_adapts::Int=0,
244+
kwargs...,
245+
)
236246
verbose = cb.verbose
237247
pm = cb.pm
238248

@@ -260,16 +270,18 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int=0
260270
# Do include current iteration and mass matrix
261271
pm_next!(
262272
pm,
263-
(
264-
iterations=i,
265-
ratio_divergent_transitions=round(
266-
percentage_divergent_transitions; digits=2
267-
),
268-
ratio_divergent_transitions_during_adaption=round(
269-
percentage_divergent_transitions_during_adaption; digits=2
273+
merge(
274+
(;
275+
iterations=i,
276+
ratio_divergent_transitions=round(
277+
percentage_divergent_transitions; digits=2
278+
),
279+
ratio_divergent_transitions_during_adaption=round(
280+
percentage_divergent_transitions_during_adaption; digits=2
281+
),
282+
mass_matrix=metric,
270283
),
271-
tstat...,
272-
mass_matrix=metric,
284+
tstat,
273285
),
274286
)
275287
# Report finish of adapation

src/adaptation/Adaptation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct NaiveHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAd
3737
pc::M
3838
ssa::Tssa
3939
end
40-
function Base.show(io::IO, ::MIME"text/plain", a::NaiveHMCAdaptor)
40+
function Base.show(io::IO, a::NaiveHMCAdaptor)
4141
return print(io, "NaiveHMCAdaptor(pc=", a.pc, ", ssa=", a.ssa, ")")
4242
end
4343

src/adaptation/massmatrix.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
2525

26-
function Base.show(io::IO, mime::MIME"text/plain", ::UnitMassMatrix{T}) where {T}
26+
function Base.show(io::IO, ::UnitMassMatrix{T}) where {T}
2727
return print(io, "UnitMassMatrix{", T, "} adaptor")
2828
end
2929

@@ -93,7 +93,7 @@ mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVec
9393
end
9494
end
9595

96-
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordVar{T}) where {T}
96+
function Base.show(io::IO, ::WelfordVar{T}) where {T}
9797
return print(io, "WelfordVar{", T, "} adaptor")
9898
end
9999

@@ -194,7 +194,7 @@ mutable struct WelfordCov{F<:AbstractFloat,C<:AbstractMatrix{F}} <: DenseMatrixE
194194
cov::C
195195
end
196196

197-
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordCov{T}) where {T}
197+
function Base.show(io::IO, ::WelfordCov{T}) where {T}
198198
return print(io, "WelfordCov{", T, "} adaptor")
199199
end
200200

src/adaptation/stan_adaptor.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function initialize!(
4949
return nothing
5050
end
5151

52-
function Base.show(io::IO, mime::MIME"text/plain", state::StanHMCAdaptorState)
52+
function Base.show(io::IO, state::StanHMCAdaptorState)
5353
print(io, "window(", state.window_start, ", ", state.window_end, "), window_splits(")
5454
join(io, state.window_splits, ", ")
5555
return print(io, ")")
@@ -66,6 +66,23 @@ struct StanHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAda
6666
window_size::Int
6767
state::StanHMCAdaptorState
6868
end
69+
70+
function Base.show(io::IO, a::StanHMCAdaptor)
71+
return print(
72+
io,
73+
"StanHMCAdaptor(",
74+
a.pc,
75+
", ",
76+
a.ssa,
77+
"; init_buffer=",
78+
a.init_buffer,
79+
", term_buffer=",
80+
a.term_buffer,
81+
", window_size=",
82+
a.window_size,
83+
")",
84+
)
85+
end
6986
function Base.show(io::IO, mime::MIME"text/plain", a::StanHMCAdaptor)
7087
return print(
7188
io,

0 commit comments

Comments
 (0)