Skip to content

Commit a29b953

Browse files
committed
logjac accumulator
1 parent e60eab0 commit a29b953

File tree

7 files changed

+195
-41
lines changed

7 files changed

+195
-41
lines changed

src/abstract_varinfo.jl

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
9999
"""
100100
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)
101101

102+
"""
103+
getlogjoint_internal(vi::AbstractVarInfo)
104+
105+
Return the log of the joint probability of the observed data and parameters as
106+
they are stored internally in `vi`, including the log-Jacobian for any linked
107+
parameters.
108+
109+
In general, we have that:
110+
111+
```julia
112+
getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi)
113+
```
114+
"""
115+
getlogjoint_internal(vi::AbstractVarInfo) =
116+
getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi)
117+
102118
"""
103119
getlogp(vi::AbstractVarInfo)
104120
105-
Return a NamedTuple of the log prior and log likelihood probabilities.
121+
Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities.
106122
107-
The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
108-
error will be thrown.
123+
The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them
124+
are not present in `vi` an error will be thrown.
109125
"""
110126
function getlogp(vi::AbstractVarInfo)
111-
return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
127+
return (;
128+
logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi)
129+
)
112130
end
113131

114132
"""
@@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@
164182
"""
165183
getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp
166184

185+
"""
186+
getlogprior_internal(vi::AbstractVarInfo)
187+
188+
Return the log of the prior probability of the parameters as stored internally
189+
in `vi`. This includes the log-Jacobian for any linked parameters.
190+
191+
In general, we have that:
192+
193+
```julia
194+
getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi)
195+
```
196+
"""
197+
getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi)
198+
199+
"""
200+
getlogjac(vi::AbstractVarInfo)
201+
202+
Return the accumulated log-Jacobian term for any linked parameters in `vi`. The
203+
Jacobian here is taken with respect to the forward (link) transform.
204+
205+
See also: [`setlogjac!!`](@ref).
206+
"""
207+
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ
208+
167209
"""
168210
getloglikelihood(vi::AbstractVarInfo)
169211
@@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
196238
"""
197239
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))
198240

241+
"""
242+
setlogjac!!(vi::AbstractVarInfo, logJ)
243+
244+
Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
245+
Jacobian here is taken with respect to the forward (link) transform.
246+
247+
See also: [`getlogjac!!`](@ref).
248+
"""
249+
setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ))
250+
199251
"""
200252
setloglikelihood!!(vi::AbstractVarInfo, logp)
201253
@@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`.
215267
See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
216268
"""
217269
function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
218-
if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
219-
error("logp must have the fields logprior and loglikelihood and no other fields.")
270+
if Set(names) != Set([:logprior, :logjac, :loglikelihood])
271+
error(
272+
"The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.",
273+
)
220274
end
221275
vi = setlogprior!!(vi, logp.logprior)
276+
vi = setlogjac!!(vi, logp.logjac)
222277
vi = setloglikelihood!!(vi, logp.loglikelihood)
223278
return vi
224279
end
225280

226281
function setlogp!!(vi::AbstractVarInfo, logp::Number)
227282
return error("""
228283
`setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
229-
`setloglikelihood!!` and/or `setlogprior!!` instead.
284+
`setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead.
230285
""")
231286
end
232287

@@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp)
306361
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
307362
end
308363

364+
"""
365+
acclogjac!!(vi::AbstractVarInfo, logJ)
366+
367+
Add `logJ` to the value of the log Jacobian in `vi`.
368+
369+
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
370+
"""
371+
function acclogjac!!(vi::AbstractVarInfo, logJ)
372+
return map_accumulator!!(
373+
acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian)
374+
)
375+
end
376+
309377
"""
310378
accloglikelihood!!(vi::AbstractVarInfo, logp)
311379
@@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo)
368436
if hasacc(vi, Val(:LogPrior))
369437
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
370438
end
439+
if hasacc(vi, Val(:LogJacobian))
440+
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
441+
end
371442
if hasacc(vi, Val(:LogLikelihood))
372443
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
373444
end
@@ -836,8 +907,10 @@ function link!!(
836907
x = vi[:]
837908
y, logjac = with_logabsdet_jacobian(b, x)
838909

839-
lp_new = getlogprior(vi) - logjac
840-
vi_new = setlogprior!!(unflatten(vi, y), lp_new)
910+
# Set parameters
911+
vi_new = unflatten(vi, y)
912+
# Update logjac
913+
vi_new = setlogjac!!(vi_new, logjac)
841914
return settrans!!(vi_new, t)
842915
end
843916

@@ -846,10 +919,12 @@ function invlink!!(
846919
)
847920
b = t.bijector
848921
y = vi[:]
849-
x, logjac = with_logabsdet_jacobian(b, y)
922+
x = b(y)
850923

851-
lp_new = getlogprior(vi) + logjac
852-
vi_new = setlogprior!!(unflatten(vi, x), lp_new)
924+
# Set parameters
925+
vi_new = unflatten(vi, x)
926+
# Reset logjac to 0
927+
vi_new = setlogjac!!(vi_new, 0.0)
853928
return settrans!!(vi_new, NoTransformation())
854929
end
855930

src/accumulators.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,21 @@ seen so far.
1111
1212
An accumulator type `T <: AbstractAccumulator` must implement the following methods:
1313
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
14-
- `accumulate_observe!!(acc::T, right, left, vn)`
15-
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
14+
- `accumulate_observe!!(acc::T, dist, val, vn)`
15+
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
1616
- `Base.copy(acc::T)`
1717
18+
In these functions:
19+
- `val` is the new value of the random variable sampled from a new distribution (always
20+
in the original unlinked space), or the value on the left-hand side of an observe
21+
statement.
22+
- `dist` is the distribution on the RHS of the tilde statement.
23+
- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the
24+
tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`.
25+
- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the
26+
variable is stored as a linked value in the VarInfo. If the variable is stored in its
27+
original, unlinked form, then `logjac` is zero.
28+
1829
To be able to work with multi-threading, it should also implement:
1930
- `split(acc::T)`
2031
- `combine(acc::T, acc2::T)`

src/context_implementations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function assume(dist::Distribution, vn::VarName, vi)
124124
y = getindex_internal(vi, vn)
125125
f = from_maybe_linked_internal_transform(vi, vn, dist)
126126
x, logjac = with_logabsdet_jacobian(f, y)
127-
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
127+
vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
128128
return x, vi
129129
end
130130

@@ -166,6 +166,6 @@ function assume(
166166

167167
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
168168
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
169-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
169+
vi = accumulate_assume!!(vi, r, logjac, vn, dist)
170170
return r, vi
171171
end

0 commit comments

Comments
 (0)