1- using ReTest, AdvancedHMC
2-
3- include (" ../src/riemannian_hmc.jl" )
4- include (" ../src/riemannian_hmc_utility.jl" )
5-
1+ using ReTest, Random
2+ using AdvancedHMC, ForwardDiff, AbstractMCMC
3+ using LinearAlgebra
4+ using MCMCLogDensityProblems
65using FiniteDiff:
76 finite_difference_gradient, finite_difference_hessian, finite_difference_jacobian
8- using Distributions: MvNormal
9- using AdvancedHMC: neg_energy, energy
7+ using AdvancedHMC: neg_energy, energy, ∂H∂θ, ∂H∂r
8+
# Fisher information metric
"""
    gen_∂G∂θ_fwd(Vfunc, x; f=identity)

Return a closure computing the Jacobian of the Hessian of `Vfunc` at a
point — i.e. ∂G/∂θ for a Hessian-based metric G.

The closure writes into a preallocated `d² × d` buffer and returns it,
with the per-coordinate Hessian derivatives stacked along the first
axis as `[∂H∂x₁; ∂H∂x₂; ...]` (use `reshape_∂G∂θ` to get a
`d × d × d` array). The buffer is reused across calls, so copy the
result if it must outlive the next invocation.

NOTE(review): the keyword `f` (intended as a Hessian stabiliser such as
`H -> H + λI`) is accepted for interface compatibility but is currently
not applied during differentiation — confirm whether it should be.
"""
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
    _Hfunc = gen_hess_fwd(Vfunc, x)
    # Keep only the Hessian slot of the (value, gradient, hessian) triple.
    Hfunc = x -> _Hfunc(x)[3]
    cfg = ForwardDiff.JacobianConfig(Hfunc, x)
    d = length(x)
    # Preallocated output buffer; `jacobian!` fills it in place and returns it.
    out = zeros(eltype(x), d^2, d)
    return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg)
end
20+
# Build a Hessian oracle with the (value, gradient, hessian) triple
# interface used elsewhere in these tests; only the Hessian slot is
# actually computed here (the first two entries are `nothing`).
function gen_hess_fwd(func, x::AbstractVector)
    hess(y::AbstractVector) = (nothing, nothing, ForwardDiff.hessian(func, y))
    return hess
end
27+
# Convert a stacked `d² × d` Jacobian-of-Hessian matrix (slices stacked
# along the first axis) into a `d × d × d` array, one `d × d` slice per
# coordinate along the third dimension.
function reshape_∂G∂θ(H)
    d = size(H, 2)
    slices = (H[((k - 1) * d + 1):(k * d), :] for k in 1:d)
    return cat(slices...; dims=3)
end
1032
11- # Taken from https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/test/finitedifftests.jl
12- δ (a, b) = maximum (abs .(a - b))
"""
    prepare_sample(ℓπ, initial_θ, λ)

Build the potential, Hessian, metric, and metric-derivative functions
needed to run Riemannian HMC for log-density `ℓπ`, stabilising the
Hessian as `H + λI`.

Returns the tuple `(Vfunc, Hfunc, Gfunc, ∂G∂θfunc)`.
"""
function prepare_sample(ℓπ, initial_θ, λ)
    Vfunc = θ -> -ℓπ(θ)
    # gen_hess returns θ -> (value, gradient, hessian); it computes
    # in place, so copy the results to avoid aliasing bugs.
    _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ)
    Hfunc = θ -> copy.(_Hfunc(θ))

    stabilize = H -> H + λ * I
    # Fall back to the identity metric when the Hessian is non-finite.
    Gfunc = θ -> begin
        G = stabilize(Hfunc(θ)[3])
        return all(isfinite, G) ? G : diagm(ones(length(θ)))
    end
    _∂G∂θfunc = gen_∂G∂θ_fwd(θ -> -ℓπ(θ), initial_θ; f=stabilize)
    ∂G∂θfunc = θ -> reshape_∂G∂θ(_∂G∂θfunc(θ))

    return Vfunc, Hfunc, Gfunc, ∂G∂θfunc
end
1648
49+ @testset " Constructors tests" begin
50+ δ (a, b) = maximum (abs .(a - b))
1751 @testset " $(nameof (typeof (target))) " for target in [HighDimGaussian (2 ), Funnel ()]
1852 rng = MersenneTwister (1110 )
53+ λ = 1e-2
1954
2055 θ₀ = rand (rng, dim (target))
2156
2257 ℓπ = MCMCLogDensityProblems. gen_logpdf (target)
2358 ∂ℓπ∂θ = MCMCLogDensityProblems. gen_logpdf_grad (target, θ₀)
2459
25- Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target (hps , θ₀, ℓπ )
60+ Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample (ℓπ , θ₀, λ )
2661
2762 D = dim (target) # ==2 for this test
2863 x = zeros (D) # randn(rng, D)
@@ -36,7 +71,7 @@ using AdvancedHMC: neg_energy, energy
3671 end
3772
3873 @testset " $(nameof (typeof (hessmap))) " for hessmap in
39- [IdentityMap (), SoftAbsMap (hps . α )]
74+ [IdentityMap (), SoftAbsMap (20.0 )]
4075 metric = DenseRiemannianMetric ((D,), Gfunc, ∂G∂θfunc, hessmap)
4176 kinetic = GaussianKinetic ()
4277 hamiltonian = Hamiltonian (metric, kinetic, ℓπ, ∂ℓπ∂θ)
@@ -67,3 +102,62 @@ using AdvancedHMC: neg_energy, energy
67102 end
68103 end
69104end
105+
@testset "Multi variate Normal with Riemannian HMC" begin
    # Number of samples to draw.
    n_samples = 2_000
    rng = MersenneTwister(1110)

    # FIX(review): `D`, `ℓπ` and `∂ℓπ∂θ` were read as globals but are only
    # defined as locals of another testset; define the target explicitly so
    # this testset is self-contained.
    target = HighDimGaussian(2)
    D = dim(target)
    ℓπ = MCMCLogDensityProblems.gen_logpdf(target)
    initial_θ = rand(rng, D)
    ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ)

    λ = 1e-2
    _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)

    # Define a Hamiltonian system with a dense Riemannian metric.
    metric = DenseRiemannianMetric((D,), G, ∂G∂θ)
    kinetic = GaussianKinetic()
    hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)

    # Generalized leapfrog integrator with a heuristically chosen step size.
    initial_ϵ = 0.01
    integrator = GeneralizedLeapfrog(initial_ϵ, 6)

    # End-point trajectory sampler with a fixed number of leapfrog steps.
    kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))

    # Run the sampler; `samples` holds the draws, `stats` the per-sample
    # diagnostic statistics.
    samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
    @test length(samples) == n_samples
    @test length(stats) == n_samples
end
134+
@testset "Multi variate Normal with Riemannian HMC softabs metric" begin
    # Number of samples to draw.
    n_samples = 2_000
    rng = MersenneTwister(1110)

    # FIX(review): `D`, `ℓπ` and `∂ℓπ∂θ` were read as globals but are only
    # defined as locals of another testset; define the target explicitly so
    # this testset is self-contained.
    target = HighDimGaussian(2)
    D = dim(target)
    ℓπ = MCMCLogDensityProblems.gen_logpdf(target)
    initial_θ = rand(rng, D)
    ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ)

    λ = 1e-2
    _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)

    # Dense Riemannian metric with the SoftAbs eigenvalue map (α = 20) to
    # keep the metric positive definite.
    metric = DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0))
    kinetic = GaussianKinetic()
    hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)

    # Generalized leapfrog integrator with a heuristically chosen step size.
    initial_ϵ = 0.01
    integrator = GeneralizedLeapfrog(initial_ϵ, 6)

    # End-point trajectory sampler with a fixed number of leapfrog steps.
    kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))

    # Run the sampler; `samples` holds the draws, `stats` the per-sample
    # diagnostic statistics.
    samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
    @test length(samples) == n_samples
    @test length(stats) == n_samples
end