Skip to content

Commit e1f3f5b

Browse files
committed
Make ∂H∂r consistent with ∂H∂θ
1 parent 6429519 commit e1f3f5b

File tree

4 files changed

+23
-38
lines changed

4 files changed

+23
-38
lines changed

src/hamiltonian.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,12 @@ function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::A
6767
return M⁻¹ * r
6868
end
6969

70-
# TODO (kai) make the order of θ and r consistent with neg_energy
7170
# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
7271
# The gradient of a position-dependent Hamiltonian system depends on both θ and r.
7372
∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ)
74-
∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r)
73+
function ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
74+
return DualValue(neg_energy(h, θ, r), ∂H∂r(h, r))
75+
end
7576

7677
struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
7778
θ::T # Position variables / model parameters.
@@ -101,7 +102,7 @@ function Base.similar(z::PhasePoint{<:AbstractVecOrMat{T}}) where {T<:AbstractFl
101102
end
102103

103104
function phasepoint(
104-
h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r))
105+
h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=∂H∂r(h, θ, r)
105106
) where {T<:AbstractVecOrMat}
106107
return PhasePoint(θ, r, ℓπ, ℓκ)
107108
end
@@ -110,12 +111,7 @@ end
110111
# move the momentum variable to that of the position variable.
111112
# This is needed for AHMC to work with CuArrays and other Arrays (without depending on it).
112113
function phasepoint(
113-
h::Hamiltonian,
114-
θ::T1,
115-
_r::T2;
116-
r=safe_rsimilar(θ, _r),
117-
ℓπ=∂H∂θ(h, θ),
118-
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)),
114+
h::Hamiltonian, θ::T1, _r::T2; r=safe_rsimilar(θ, _r), ℓπ=∂H∂θ(h, θ), ℓκ=∂H∂r(h, θ, r)
119115
) where {T1<:AbstractVecOrMat,T2<:AbstractVecOrMat}
120116
return PhasePoint(θ, r, ℓπ, ℓκ)
121117
end
@@ -141,31 +137,31 @@ neg_energy(h::Hamiltonian, θ::AbstractVecOrMat) = h.ℓπ(θ)
141137
# GaussianKinetic
142138

143139
function neg_energy(
144-
h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
140+
h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, θ::T, r::T
145141
) where {T<:AbstractVector}
146142
return -sum(abs2, r) / 2
147143
end
148144

149145
function neg_energy(
150-
h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
146+
h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, θ::T, r::T
151147
) where {T<:AbstractMatrix}
152148
return -vec(sum(abs2, r; dims=1)) / 2
153149
end
154150

155151
function neg_energy(
156-
h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
152+
h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, θ::T, r::T
157153
) where {T<:AbstractVector}
158154
return -sum(abs2.(r) .* h.metric.M⁻¹) / 2
159155
end
160156

161157
function neg_energy(
162-
h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
158+
h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, θ::T, r::T
163159
) where {T<:AbstractMatrix}
164160
return -vec(sum(abs2.(r) .* h.metric.M⁻¹; dims=1)) / 2
165161
end
166162

167163
function neg_energy(
168-
h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
164+
h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, θ::T, r::T
169165
) where {T<:AbstractVecOrMat}
170166
mul!(h.metric._temp, h.metric.M⁻¹, r)
171167
return -dot(r, h.metric._temp) / 2

src/riemannian/hamiltonian.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ function step(
7272
end
7373
#! Eq (17) of Girolami & Calderhead (2011)
7474
θ_full = copy(θ_init)
75-
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
75+
term_1 = ∂H∂r(h, θ_init, r_half).gradient # unchanged across the loop
7676
for j in 1:(lf.n)
77-
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
77+
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half).gradient)
7878
# println("θ_full :", θ_full)
7979
end
8080
#! Eq (18) of Girolami & Calderhead (2011)
@@ -103,7 +103,6 @@ function step(
103103
return res
104104
end
105105

106-
# TODO Make the order of θ and r consistent with neg_energy
107106
∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ)
108107
∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r)
109108

@@ -227,19 +226,15 @@ using LinearAlgebra: logabsdet, tr
227226
# QUES Do we want to change everything to position dependent by default?
228227
# Add θ to ∂H∂r for DenseRiemannianMetric
229228
function phasepoint(
230-
h::Hamiltonian{<:DenseRiemannianMetric},
231-
θ::T,
232-
r::T;
233-
ℓπ=∂H∂θ(h, θ),
234-
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
229+
h::Hamiltonian{<:DenseRiemannianMetric}, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=∂H∂r(h, θ, r)
235230
) where {T<:AbstractVecOrMat}
236231
return PhasePoint(θ, r, ℓπ, ℓκ)
237232
end
238233

239234
# Negative kinetic energy
240235
#! Eq (13) of Girolami & Calderhead (2011)
241236
function neg_energy(
242-
h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T
237+
h::Hamiltonian{<:DenseRiemannianMetric}, θ::T, r::T
243238
) where {T<:AbstractVecOrMat}
244239
G = h.metric.map(h.metric.G(θ))
245240
D = size(G, 1)
@@ -347,12 +342,6 @@ function ∂H∂r(
347342
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
348343
)
349344
H = h.metric.G(θ)
350-
# if !all(isfinite, H)
351-
# println("θ: ", θ)
352-
# println("H: ", H)
353-
# end
354345
G = h.metric.map(H)
355-
# return inv(G) * r
356-
# println("G \ r: ", G \ r)
357-
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
346+
return DualValue(neg_energy(h, θ, r), G \ r)
358347
end

src/riemannian/integrator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ function step(
7474
end
7575
# eq (17) of Girolami & Calderhead (2011)
7676
θ_full = θ_init
77-
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
77+
term_1 = ∂H∂r(h, θ_init, r_half).gradient # unchanged across the loop
7878
for j in 1:(lf.n)
79-
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
79+
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half).gradient)
8080
end
8181
# eq (18) of Girolami & Calderhead (2011)
8282
(; value, gradient) = ∂H∂θ(h, θ_full, r_half)

test/hamiltonian.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@ end
6060
r_init = randn(T, D)
6161

6262
h = Hamiltonian(UnitEuclideanMetric(T, D), ℓπ, ∂ℓπ∂θ)
63-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) == sum(abs2, r_init) / 2
63+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init) == sum(abs2, r_init) / 2
6464
@test AdvancedHMC.∂H∂r(h, r_init) == r_init
6565

6666
M⁻¹ = ones(T, D) + abs.(randn(T, D))
6767
h = Hamiltonian(DiagEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
68-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init)
68+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init)
6969
r_init' * diagm(0 => M⁻¹) * r_init / 2
7070
@test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ .* r_init
7171

7272
m = randn(T, D, D)
7373
M⁻¹ = m' * m
7474
h = Hamiltonian(DenseEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
75-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) r_init' * M⁻¹ * r_init / 2
75+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init) r_init' * M⁻¹ * r_init / 2
7676
@test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ * r_init
7777
end
7878
end
@@ -86,15 +86,15 @@ end
8686
r_init = ComponentArray(; a=randn(T, D), b=randn(T, D))
8787

8888
h = Hamiltonian(UnitEuclideanMetric(T, 2 * D), ℓπ, ∂ℓπ∂θ)
89-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) == sum(abs2, r_init) / 2
89+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init) == sum(abs2, r_init) / 2
9090
@test AdvancedHMC.∂H∂r(h, r_init) == r_init
9191
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
9292

9393
M⁻¹ = ComponentArray(;
9494
a=ones(T, D) + abs.(randn(T, D)), b=ones(T, D) + abs.(randn(T, D))
9595
)
9696
h = Hamiltonian(DiagEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
97-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init)
97+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init)
9898
r_init' * diagm(0 => M⁻¹) * r_init / 2
9999
@test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ .* r_init
100100
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
@@ -103,7 +103,7 @@ end
103103
ax = getaxes(r_init)[1]
104104
M⁻¹ = ComponentArray(m' * m, ax, ax)
105105
h = Hamiltonian(DenseEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
106-
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) r_init' * M⁻¹ * r_init / 2
106+
@test -AdvancedHMC.neg_energy(h, θ_init, r_init) r_init' * M⁻¹ * r_init / 2
107107
@test all(AdvancedHMC.∂H∂r(h, r_init) .== M⁻¹ * r_init)
108108
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
109109
end

0 commit comments

Comments
 (0)