Skip to content

Commit f1d1c80

Browse files
committed
refactor and fix test rng handling
1 parent c395715 commit f1d1c80

File tree

7 files changed

+62
-120
lines changed

7 files changed

+62
-120
lines changed

src/AdvancedHMC.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export find_good_eps
7272
include("adaptation/Adaptation.jl")
7373
using .Adaptation
7474
import .Adaptation:
75-
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation
75+
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint
7676

7777
# Helpers for initializing adaptors via AHMC structs
7878

@@ -114,6 +114,7 @@ export StepSizeAdaptor,
114114
MassMatrixAdaptor,
115115
UnitMassMatrix,
116116
WelfordVar,
117+
NutpieVar,
117118
WelfordCov,
118119
NaiveHMCAdaptor,
119120
StanHMCAdaptor,

src/adaptation/Adaptation.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using DocStringExtensions
1010
"""
1111
$(TYPEDEF)
1212
13-
Abstract type for HMC adaptors.
13+
Abstract type for HMC adaptors.
1414
"""
1515
abstract type AbstractAdaptor end
1616
function getM⁻¹ end
@@ -21,12 +21,17 @@ function initialize! end
2121
function finalize! end
2222
export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹
2323

24+
get_position(x::PhasePoint) = x.θ
25+
get_position(x::AbstractVecOrMat{<:AbstractFloat}) = x
26+
const PositionOrPhasePoint = Union{AbstractVecOrMat{<:AbstractFloat}, PhasePoint}
27+
2428
struct NoAdaptation <: AbstractAdaptor end
2529
export NoAdaptation
2630
include("stepsize.jl")
2731
export StepSizeAdaptor, NesterovDualAveraging
32+
2833
include("massmatrix.jl")
29-
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov
34+
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, NutpieVar, WelfordCov
3035

3136
##
3237
## Composite adaptors
@@ -47,23 +52,14 @@ getϵ(ca::NaiveHMCAdaptor) = getϵ(ca.ssa)
4752
# TODO: implement consensus adaptor
4853
function adapt!(
4954
nca::NaiveHMCAdaptor,
50-
θ::AbstractVecOrMat{<:AbstractFloat},
55+
z_or_theta::PositionOrPhasePoint,
5156
α::AbstractScalarOrVec{<:AbstractFloat},
5257
)
53-
adapt!(nca.ssa, θ, α)
54-
adapt!(nca.pc, θ, α)
55-
return nothing
56-
end
57-
adapt!(
58-
nca::NaiveHMCAdaptor,
59-
z::PhasePoint,
60-
α::AbstractScalarOrVec{<:AbstractFloat},
61-
) = adapt!(nca, z.θ, α)
62-
function reset!(aca::NaiveHMCAdaptor)
63-
reset!(aca.ssa)
64-
reset!(aca.pc)
58+
adapt!(nca.ssa, z_or_theta, α)
59+
adapt!(nca.pc, z_or_theta, α)
6560
return nothing
6661
end
62+
6763
initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int) = nothing
6864
finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)
6965

src/adaptation/massmatrix.jl

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,17 @@ finalize!(::MassMatrixAdaptor) = nothing
99

1010
function adapt!(
1111
adaptor::MassMatrixAdaptor,
12-
θ::AbstractVecOrMat{<:AbstractFloat},
13-
α::AbstractScalarOrVec{<:AbstractFloat},
14-
is_update::Bool=true,
15-
)
16-
resize_adaptor!(adaptor, size(θ))
17-
push!(adaptor, θ)
18-
is_update && update!(adaptor)
19-
return nothing
20-
end
21-
22-
function adapt!(
23-
adaptor::MassMatrixAdaptor,
24-
z::PhasePoint,
25-
α::AbstractScalarOrVec{<:AbstractFloat},
12+
z_or_theta::PositionOrPhasePoint,
13+
::AbstractScalarOrVec{<:AbstractFloat},
2614
is_update::Bool=true,
2715
)
28-
resize_adaptor!(adaptor, size(z.θ))
29-
push!(adaptor, z)
16+
resize_adaptor!(adaptor, size(get_position(z_or_theta)))
17+
push!(adaptor, z_or_theta)
3018
is_update && update!(adaptor)
3119
return nothing
3220
end
3321

34-
Base.push!(a::MassMatrixAdaptor, z::PhasePoint) = push!(a, z.θ)
22+
Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta))
3523

3624
## Unit mass matrix adaptor
3725

@@ -53,24 +41,14 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T
5341

5442
function adapt!(
5543
::UnitMassMatrix,
56-
::AbstractVecOrMat{<:AbstractFloat},
57-
::AbstractScalarOrVec{<:AbstractFloat},
58-
is_update::Bool=true,
59-
)
60-
return nothing
61-
end
62-
63-
function adapt!(
64-
::UnitMassMatrix,
65-
::PhasePoint,
44+
::PositionOrPhasePoint,
6645
::AbstractScalarOrVec{<:AbstractFloat},
6746
is_update::Bool=true,
6847
)
6948
return nothing
7049
end
7150

7251
## Diagonal mass matrix adaptor
73-
7452
abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end
7553

7654
getM⁻¹(ve::DiagMatrixEstimator) = ve.var
@@ -93,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri
9371

9472
NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz)
9573

96-
Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s)
74+
Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s)
9775

9876
reset!(nv::NaiveVar) = resize!(nv.S, 0)
9977

@@ -158,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
158136
return nothing
159137
end
160138

161-
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T}
139+
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat}
162140
wv.n += 1
163141
(; δ, μ, M, n) = wv
164142
n = T(n)
@@ -176,8 +154,13 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
176154
return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5))
177155
end
178156

179-
## Nutpie-style diagonal mass matrix estimator (using positions and gradients) - not exported yet due to https://github.com/TuringLang/AdvancedHMC.jl/issues/475
157+
"""
158+
Nutpie-style diagonal mass matrix estimator (using positions and gradients) - not exported yet due to https://github.com/TuringLang/AdvancedHMC.jl/issues/475
180159
160+
Expected to converge faster and to a better mass matrix than WelfordVar.
161+
162+
Can be initialized via NutpieVar(sz) where sz is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
163+
"""
181164
mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
182165
position_estimator::WelfordVar{T,E,V}
183166
gradient_estimator::WelfordVar{T,E,V}
@@ -232,6 +215,8 @@ function reset!(nv::NutpieVar)
232215
reset!(nv.gradient_estimator)
233216
end
234217

218+
Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!")
219+
235220
function Base.push!(nv::NutpieVar, z::PhasePoint)
236221
nv.n += 1
237222
push!(nv.position_estimator, z.θ)
@@ -266,7 +251,7 @@ end
266251

267252
NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}())
268253

269-
Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s)
254+
Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s)
270255

271256
reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0)
272257

@@ -316,7 +301,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat}
316301
return nothing
317302
end
318303

319-
function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T}
304+
function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat}
320305
wc.n += 1
321306
(; δ, μ, n, M) = wc
322307
n = T(n)

src/adaptation/stan_adaptor.jl

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -136,45 +136,20 @@ is_window_end(a::StanHMCAdaptor) = a.state.i in a.state.window_splits
136136

137137
function adapt!(
138138
tp::StanHMCAdaptor,
139-
θ::AbstractVecOrMat{<:AbstractFloat},
139+
z_or_theta::PositionOrPhasePoint,
140140
α::AbstractScalarOrVec{<:AbstractFloat},
141141
)
142142
tp.state.i += 1
143143

144-
adapt!(tp.ssa, θ, α)
144+
adapt!(tp.ssa, z_or_theta, α)
145145

146-
resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary.
146+
resize_adaptor!(tp.pc, size(get_position(z_or_theta))) # Resize pre-conditioner if necessary.
147147

148148
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
149149
if is_in_window(tp)
150150
# We accumulate stats from θ online and only trigger the update of M⁻¹ at the end of the window.
151151
is_update_M⁻¹ = is_window_end(tp)
152-
adapt!(tp.pc, θ, α, is_update_M⁻¹)
153-
end
154-
155-
if is_window_end(tp)
156-
reset!(tp.ssa)
157-
reset!(tp.pc)
158-
end
159-
end
160-
161-
162-
function adapt!(
163-
tp::StanHMCAdaptor,
164-
z::PhasePoint,
165-
α::AbstractScalarOrVec{<:AbstractFloat},
166-
)
167-
tp.state.i += 1
168-
169-
adapt!(tp.ssa, z.θ, α)
170-
171-
resize_adaptor!(tp.pc, size(z.θ)) # Resize pre-conditioner if necessary.
172-
173-
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
174-
if is_in_window(tp)
175-
# We accumlate stats from θ online and only trigger the update of M⁻¹ in the end of window.
176-
is_update_M⁻¹ = is_window_end(tp)
177-
adapt!(tp.pc, z, α, is_update_M⁻¹)
152+
adapt!(tp.pc, z_or_theta, α, is_update_M⁻¹)
178153
end
179154

180155
if is_window_end(tp)

src/adaptation/stepsize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ end
174174
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
175175
# Note: This function is not merged with `adapt!` to emphasize the fact that
176176
# step size adaptation is not dependent on `θ`.
177-
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
177+
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
178178
function adapt_stepsize!(
179179
da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
180180
) where {T<:AbstractFloat}
@@ -211,7 +211,7 @@ end
211211

212212
function adapt!(
213213
da::NesterovDualAveraging,
214-
θ::AbstractVecOrMat{<:AbstractFloat},
214+
::PositionOrPhasePoint,
215215
α::AbstractScalarOrVec{<:AbstractFloat},
216216
)
217217
adapt_stepsize!(da, α)

src/sampler.jl

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ end
6060
function Adaptation.adapt!(
6161
h::Hamiltonian,
6262
κ::AbstractMCMCKernel,
63-
adaptor::Adaptation.NoAdaptation,
64-
i::Int,
65-
n_adapts::Int,
66-
θ::AbstractVecOrMat{<:AbstractFloat},
67-
α::AbstractScalarOrVec{<:AbstractFloat},
63+
::Adaptation.NoAdaptation,
64+
::Int,
65+
::Int,
66+
::PositionOrPhasePoint,
67+
::AbstractScalarOrVec{<:AbstractFloat},
6868
)
6969
return h, κ, false
7070
end
@@ -75,40 +75,18 @@ function Adaptation.adapt!(
7575
adaptor::AbstractAdaptor,
7676
i::Int,
7777
n_adapts::Int,
78-
θ::AbstractVecOrMat{<:AbstractFloat},
79-
α::AbstractScalarOrVec{<:AbstractFloat},
80-
)
81-
isadapted = false
82-
if i <= n_adapts
83-
i == 1 && Adaptation.initialize!(adaptor, n_adapts)
84-
adapt!(adaptor, θ, α)
85-
i == n_adapts && finalize!(adaptor)
86-
h = update(h, adaptor)
87-
κ = update(κ, adaptor)
88-
isadapted = true
89-
end
90-
return h, κ, isadapted
91-
end
92-
93-
function Adaptation.adapt!(
94-
h::Hamiltonian,
95-
κ::AbstractMCMCKernel,
96-
adaptor::AbstractAdaptor,
97-
i::Int,
98-
n_adapts::Int,
99-
z::PhasePoint,
78+
z_or_theta::PositionOrPhasePoint,
10079
α::AbstractScalarOrVec{<:AbstractFloat},
10180
)
102-
isadapted = false
103-
if i <= n_adapts
81+
adapt = i <= n_adapts
82+
if adapt
10483
i == 1 && Adaptation.initialize!(adaptor, n_adapts)
105-
adapt!(adaptor, z, α)
84+
adapt!(adaptor, z_or_theta, α)
10685
i == n_adapts && finalize!(adaptor)
10786
h = update(h, adaptor)
10887
κ = update(κ, adaptor)
109-
isadapted = true
11088
end
111-
return h, κ, isadapted
89+
return h, κ, adapt
11290
end
11391

11492
"""
@@ -169,7 +147,7 @@ end
169147
progress::Bool=false
170148
)
171149
Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
172-
- The randomness is controlled by `rng`.
150+
- The randomness is controlled by `rng`.
173151
- If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.
174152
- The initial point is given by `θ`.
175153
- The adaptor is set by `adaptor`, for which the default is no adaptation.

test/adaptation.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,23 @@ function runnuts_nutpie(ℓπ, metric::DiagEuclideanMetric; n_samples=10_000)
3434
κ = AdvancedHMC.make_kernel(nuts, integrator)
3535
# Constructing like this until we've settled on a different interface
3636
adaptor = AdvancedHMC.StanHMCAdaptor(
37-
AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
37+
AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
3838
AdvancedHMC.StepSizeAdaptor(nuts.δ, integrator)
3939
)
4040
samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose=false)
4141
return (samples=samples, stats=stats, adaptor=adaptor)
4242
end
43+
"""
44+
Computes the condition number of a covariance matrix `cov::AbstractMatrix` after preconditioning with the (diagonal) mass matrix estimated in `a::DiagMatrixEstimator`.
45+
46+
This is a simple but serviceable proxy for eventual sampling efficiency, but see also https://arxiv.org/abs/1905.09813 for a more involved estimate.
47+
48+
(A lower number generally means that the estimated mass matrix is better).
49+
"""
4350
preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Diagonal(a.var)) \ cov / sqrt(Diagonal(a.var)))
4451

4552
@testset "Adaptation" begin
53+
Random.seed!(1)
4654
# Check that the estimated variance is approximately correct.
4755
@testset "Online v.s. naive v.s. true var/cov estimation" begin
4856
D = 10
@@ -159,9 +167,8 @@ preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Dia
159167
@testset "Adapted mass v.s. true variance" begin
160168
D = 10
161169
n_tests = 5
162-
@testset "DiagEuclideanMetric" begin
170+
@testset "'Diagonal' MvNormal target" begin
163171
for _ in 1:n_tests
164-
Random.seed!(1)
165172

166173
# Random variance
167174
σ² = 1 .+ abs.(randn(D))
@@ -183,7 +190,7 @@ preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Dia
183190
end
184191
end
185192

186-
@testset "DenseEuclideanMetric" begin
193+
@testset "'Dense' MvNormal target" begin
187194
n_nutpie_superior = 0
188195
for _ in 1:n_tests
189196
# Random covariance
@@ -197,16 +204,16 @@ preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Dia
197204
@test res.adaptor.pc.var ≈ diag(Σ) rtol = 0.2
198205

199206
# For this target, Nutpie will NOT converge towards the true variances, even after infinite draws.
200-
# HOWEVER, it will asymptotically (but also generally more quickly than Stan)
207+
# HOWEVER, it will asymptotically (but also generally more quickly than Stan)
201208
# find the best preconditioner for the target.
202-
# As these are statistical algorithms, superiority is not always guaranteed, hence this way of testing.
209+
# As these are statistical algorithms, superiority is not always guaranteed, hence this way of testing.
203210
res_nutpie = runnuts_nutpie(ℓπ, DiagEuclideanMetric(D))
204211
n_nutpie_superior += preconditioned_cond(res_nutpie.adaptor.pc, Σ) < preconditioned_cond(res.adaptor.pc, Σ)
205212

206213
res = runnuts(ℓπ, DenseEuclideanMetric(D))
207214
@test res.adaptor.pc.cov ≈ Σ rtol = 0.25
208215
end
209-
@test n_nutpie_superior > n_tests / 2
216+
@test n_nutpie_superior > 1 + n_tests / 2
210217
end
211218
end
212219

0 commit comments

Comments
 (0)