Skip to content

Commit 9cac513

Browse files
committed
Merge remote-tracking branch 'origin/qqy/NEW_RMHMC' into Jamie_RHMC
2 parents 981932e + c9e6b0a commit 9cac513

File tree

9 files changed

+138
-22
lines changed

9 files changed

+138
-22
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.8.2"
3+
version = "0.8.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -32,7 +32,7 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3232

3333
[compat]
3434
ADTypes = "1"
35-
AbstractMCMC = "5.6"
35+
AbstractMCMC = "5.9"
3636
ArgCheck = "1, 2"
3737
CUDA = "3, 4, 5"
3838
ComponentArrays = "0.15"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c
88
- Unit metric: `UnitEuclideanMetric(dim)`
99
- Diagonal metric: `DiagEuclideanMetric(dim)`
1010
- Dense metric: `DenseEuclideanMetric(dim)`
11+
- Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)`
1112

1213
where `dim` is the dimensionality of the sampling space.
1314

src/AdvancedHMC.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,19 @@ module AdvancedHMC
22

33
using Statistics: mean, var, middle
44
using LinearAlgebra:
5-
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
5+
Symmetric,
6+
UpperTriangular,
7+
mul!,
8+
ldiv!,
9+
dot,
10+
I,
11+
diag,
12+
cholesky,
13+
UniformScaling,
14+
logdet,
15+
tr,
16+
eigen,
17+
diagm
618
using StatsFuns: logaddexp, logsumexp, loghalf
719
using Random: Random, AbstractRNG
820
using ProgressMeter: ProgressMeter
@@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end
4052
export GaussianKinetic
4153

4254
include("metric.jl")
43-
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
55+
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric
4456

4557
include("hamiltonian.jl")
4658
export Hamiltonian
@@ -54,6 +66,11 @@ include("riemannian/integrator.jl")
5466
export GeneralizedLeapfrog, ImplicitMidpoint
5567
include("riemannian/hamiltonian.jl")
5668

69+
include("riemannian/metric.jl")
70+
export IdentityMap, SoftAbsMap, DenseRiemannianMetric
71+
72+
include("riemannian/hamiltonian.jl")
73+
5774
include("trajectory.jl")
5875
export Trajectory,
5976
HMCKernel,

src/abstractmcmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ getintegrator(state::HMCState) = state.κ.τ.integrator
3333
function AbstractMCMC.getparams(state::HMCState)
3434
return state.transition.z.θ
3535
end
36+
AbstractMCMC.getstats(state::AdvancedHMC.HMCState) = state.transition.stat
3637

3738
function AbstractMCMC.setparams!!(
3839
model::AbstractMCMC.LogDensityModel, state::HMCState, params

src/riemannian/hamiltonian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565

6666
# Position gradient with Riemannian correction terms
6767
function ∂H∂θ(
68-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}},
68+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
6969
θ::AbstractVecOrMat{T},
7070
r::AbstractVecOrMat{T},
7171
) where {T}
@@ -107,15 +107,15 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
107107
end
108108

109109
function ∂H∂θ(
110-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
110+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
111111
θ::AbstractVecOrMat{T},
112112
r::AbstractVecOrMat{T},
113113
) where {T}
114114
return ∂H∂θ_cache(h, θ, r)
115115
end
116116

117117
function ∂H∂θ_cache(
118-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
118+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
119119
θ::AbstractVecOrMat{T},
120120
r::AbstractVecOrMat{T};
121121
return_cache=false,

src/trajectory.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ function transition(
292292
hamiltonian_energy=H,
293293
hamiltonian_energy_error=H - H0,
294294
# check numerical error in proposed phase point.
295-
numerical_error=!all(isfinite, H′),
295+
numerical_error=(!all(isfinite, H′)),
296296
),
297297
stat.integrator),
298298
)
@@ -727,7 +727,7 @@ function transition(
727727
(
728728
n_steps=tree.nα,
729729
is_accept=true,
730-
acceptance_rate=tree.sum_α / tree.nα,
730+
acceptance_rate=(tree.sum_α / tree.),
731731
log_density=zcand.ℓπ.value,
732732
hamiltonian_energy=H,
733733
hamiltonian_energy_error=H - H0,

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
66
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
9+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1314
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1415
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
16+
MCMCLogDensityProblems = "8a639fad-7908-4fe4-8003-906e9297f002"
1517
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1618
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1719
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/riemannian.jl

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,63 @@
1-
using ReTest, AdvancedHMC
2-
3-
include("../src/riemannian_hmc.jl")
4-
include("../src/riemannian_hmc_utility.jl")
5-
1+
using ReTest, Random
2+
using AdvancedHMC, ForwardDiff, AbstractMCMC
3+
using LinearAlgebra
4+
using MCMCLogDensityProblems
65
using FiniteDiff:
76
finite_difference_gradient, finite_difference_hessian, finite_difference_jacobian
8-
using Distributions: MvNormal
9-
using AdvancedHMC: neg_energy, energy
7+
using AdvancedHMC: neg_energy, energy, ∂H∂θ, ∂H∂r
8+
9+
# Fisher information metric
10+
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
11+
_Hfunc = gen_hess_fwd(Vfunc, x)
12+
Hfunc = x -> _Hfunc(x)[3]
13+
# QUES What's the best output format of this function?
14+
cfg = ForwardDiff.JacobianConfig(Hfunc, x)
15+
d = length(x)
16+
out = zeros(eltype(x), d^2, d)
17+
return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg)
18+
return out # default output shape [∂H∂x₁; ∂H∂x₂; ...]
19+
end
20+
21+
function gen_hess_fwd(func, x::AbstractVector)
22+
function hess(x::AbstractVector)
23+
return nothing, nothing, ForwardDiff.hessian(func, x)
24+
end
25+
return hess
26+
end
27+
28+
function reshape_∂G∂θ(H)
29+
d = size(H, 2)
30+
return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3)
31+
end
1032

11-
# Taken from https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/test/finitedifftests.jl
12-
δ(a, b) = maximum(abs.(a - b))
33+
function prepare_sample(ℓπ, initial_θ, λ)
34+
Vfunc = x -> -ℓπ(x)
35+
_Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) # x -> (value, gradient, hessian)
36+
Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug
1337

14-
@testset "Riemannian" begin
15-
hps = (; λ=1e-2, α=20.0, ϵ=0.1, n=6, L=8)
38+
fstabilize = H -> H + λ * I
39+
Gfunc = x -> begin
40+
H = fstabilize(Hfunc(x)[3])
41+
all(isfinite, H) ? H : diagm(ones(length(x)))
42+
end
43+
_∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize)
44+
∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x))
45+
46+
return Vfunc, Hfunc, Gfunc, ∂G∂θfunc
47+
end
1648

49+
@testset "Constructors tests" begin
50+
δ(a, b) = maximum(abs.(a - b))
1751
@testset "$(nameof(typeof(target)))" for target in [HighDimGaussian(2), Funnel()]
1852
rng = MersenneTwister(1110)
53+
λ = 1e-2
1954

2055
θ₀ = rand(rng, dim(target))
2156

2257
ℓπ = MCMCLogDensityProblems.gen_logpdf(target)
2358
∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀)
2459

25-
Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ)
60+
Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample(ℓπ, θ₀, λ)
2661

2762
D = dim(target) # ==2 for this test
2863
x = zeros(D) # randn(rng, D)
@@ -36,7 +71,7 @@ using AdvancedHMC: neg_energy, energy
3671
end
3772

3873
@testset "$(nameof(typeof(hessmap)))" for hessmap in
39-
[IdentityMap(), SoftAbsMap(hps.α)]
74+
[IdentityMap(), SoftAbsMap(20.0)]
4075
metric = DenseRiemannianMetric((D,), Gfunc, ∂G∂θfunc, hessmap)
4176
kinetic = GaussianKinetic()
4277
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
@@ -67,3 +102,62 @@ using AdvancedHMC: neg_energy, energy
67102
end
68103
end
69104
end
105+
106+
@testset "Multi variate Normal with Riemannian HMC" begin
107+
# Set the number of samples to draw and warmup iterations
108+
n_samples = 2_000
109+
rng = MersenneTwister(1110)
110+
initial_θ = rand(rng, D)
111+
λ = 1e-2
112+
_, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)
113+
# Define a Hamiltonian system
114+
metric = DenseRiemannianMetric((D,), G, ∂G∂θ)
115+
kinetic = GaussianKinetic()
116+
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
117+
118+
# Define a leapfrog solver, with the initial step size chosen heuristically
119+
initial_ϵ = 0.01
120+
integrator = GeneralizedLeapfrog(initial_ϵ, 6)
121+
122+
# Define an HMC sampler with the following components
123+
# - multinomial sampling scheme,
124+
# - generalised No-U-Turn criteria, and
125+
kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))
126+
127+
# Run the sampler to draw samples from the specified Gaussian, where
128+
# - `samples` will store the samples
129+
# - `stats` will store diagnostic statistics for each sample
130+
samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
131+
@test length(samples) == n_samples
132+
@test length(stats) == n_samples
133+
end
134+
135+
@testset "Multi variate Normal with Riemannian HMC softabs metric" begin
136+
# Set the number of samples to draw and warmup iterations
137+
n_samples = 2_000
138+
rng = MersenneTwister(1110)
139+
initial_θ = rand(rng, D)
140+
λ = 1e-2
141+
_, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)
142+
143+
# Define a Hamiltonian system
144+
metric = DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0))
145+
kinetic = GaussianKinetic()
146+
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
147+
148+
# Define a leapfrog solver, with the initial step size chosen heuristically
149+
initial_ϵ = 0.01
150+
integrator = GeneralizedLeapfrog(initial_ϵ, 6)
151+
152+
# Define an HMC sampler with the following components
153+
# - multinomial sampling scheme,
154+
# - generalised No-U-Turn criteria, and
155+
kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))
156+
157+
# Run the sampler to draw samples from the specified Gaussian, where
158+
# - `samples` will store the samples
159+
# - `stats` will store diagnostic statistics for each sample
160+
samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
161+
@test length(samples) == n_samples
162+
@test length(stats) == n_samples
163+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ if GROUP == "All" || GROUP == "AdvancedHMC"
3131
include("abstractmcmc.jl")
3232
include("mcmcchains.jl")
3333
include("constructors.jl")
34+
include("riemannian.jl")
3435
retest(; dry=false, verbose=Inf)
3536
end
3637

0 commit comments

Comments
 (0)