@@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
9999"""
100100getlogjoint (vi:: AbstractVarInfo ) = getlogprior (vi) + getloglikelihood (vi)
101101
102+ """
103+ getlogjoint_internal(vi::AbstractVarInfo)
104+
105+ Return the log of the joint probability of the observed data and parameters as
106+ they are stored internally in `vi`, including the log-Jacobian for any linked
107+ parameters.
108+
109+ In general, we have that:
110+
111+ ```julia
112+ getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi)
113+ ```
114+ """
115+ getlogjoint_internal (vi:: AbstractVarInfo ) =
116+ getlogprior (vi) + getloglikelihood (vi) - getlogjac (vi)
117+
102118"""
103119 getlogp(vi::AbstractVarInfo)
104120
105- Return a NamedTuple of the log prior and log likelihood probabilities.
121+ Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities.
106122
107- The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
108- error will be thrown.
123+ The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them
124+ are not present in `vi` an error will be thrown.
109125"""
110126function getlogp (vi:: AbstractVarInfo )
111- return (; logprior= getlogprior (vi), loglikelihood= getloglikelihood (vi))
127+ return (;
128+ logprior= getlogprior (vi), logjac= getlogjac (vi), loglikelihood= getloglikelihood (vi)
129+ )
112130end
113131
114132"""
@@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@
164182"""
165183getlogprior (vi:: AbstractVarInfo ) = getacc (vi, Val (:LogPrior )). logp
166184
185+ """
186+ getlogprior_internal(vi::AbstractVarInfo)
187+
188+ Return the log of the prior probability of the parameters as stored internally
189+ in `vi`. This includes the log-Jacobian for any linked parameters.
190+
191+ In general, we have that:
192+
193+ ```julia
194+ getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi)
195+ ```
196+ """
197+ getlogprior_internal (vi:: AbstractVarInfo ) = getlogprior (vi) - getlogjac (vi)
198+
199+ """
200+ getlogjac(vi::AbstractVarInfo)
201+
202+ Return the accumulated log-Jacobian term for any linked parameters in `vi`. The
203+ Jacobian here is taken with respect to the forward (link) transform.
204+
205+ See also: [`setlogjac!!`](@ref).
206+ """
207+ getlogjac (vi:: AbstractVarInfo ) = getacc (vi, Val (:LogJacobian )). logJ
208+
167209"""
168210 getloglikelihood(vi::AbstractVarInfo)
169211
@@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
196238"""
197239setlogprior!! (vi:: AbstractVarInfo , logp) = setacc!! (vi, LogPriorAccumulator (logp))
198240
241+ """
242+ setlogjac!!(vi::AbstractVarInfo, logJ)
243+
244+ Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
245+ Jacobian here is taken with respect to the forward (link) transform.
246+
247+ See also: [`getlogjac!!`](@ref).
248+ """
249+ setlogjac!! (vi:: AbstractVarInfo , logJ) = setacc!! (vi, LogJacobianAccumulator (logJ))
250+
199251"""
200252 setloglikelihood!!(vi::AbstractVarInfo, logp)
201253
@@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`.
215267See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
216268"""
217269function setlogp!! (vi:: AbstractVarInfo , logp:: NamedTuple{names} ) where {names}
218- if ! (names == (:logprior , :loglikelihood ) || names == (:loglikelihood , :logprior ))
219- error (" logp must have the fields logprior and loglikelihood and no other fields." )
270+ if Set (names) != Set ([:logprior , :logjac , :loglikelihood ])
271+ error (
272+ " The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood." ,
273+ )
220274 end
221275 vi = setlogprior!! (vi, logp. logprior)
276+ vi = setlogjac!! (vi, logp. logjac)
222277 vi = setloglikelihood!! (vi, logp. loglikelihood)
223278 return vi
224279end
225280
226281function setlogp!! (vi:: AbstractVarInfo , logp:: Number )
227282 return error ("""
228283 `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
229- `setloglikelihood!!` and/or `setlogprior!!` instead.
284+ `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead.
230285 """ )
231286end
232287
@@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp)
306361 return map_accumulator!! (acc -> acc + LogPriorAccumulator (logp), vi, Val (:LogPrior ))
307362end
308363
364+ """
365+ acclogjac!!(vi::AbstractVarInfo, logJ)
366+
367+ Add `logJ` to the value of the log Jacobian in `vi`.
368+
369+ See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
370+ """
371+ function acclogjac!! (vi:: AbstractVarInfo , logJ)
372+ return map_accumulator!! (
373+ acc -> acc + LogJacobianAccumulator (logJ), vi, Val (:LogJacobian )
374+ )
375+ end
376+
309377"""
310378 accloglikelihood!!(vi::AbstractVarInfo, logp)
311379
@@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo)
368436 if hasacc (vi, Val (:LogPrior ))
369437 vi = map_accumulator!! (zero, vi, Val (:LogPrior ))
370438 end
439+ if hasacc (vi, Val (:LogJacobian ))
440+ vi = map_accumulator!! (zero, vi, Val (:LogJacobian ))
441+ end
371442 if hasacc (vi, Val (:LogLikelihood ))
372443 vi = map_accumulator!! (zero, vi, Val (:LogLikelihood ))
373444 end
@@ -836,8 +907,10 @@ function link!!(
836907 x = vi[:]
837908 y, logjac = with_logabsdet_jacobian (b, x)
838909
839- lp_new = getlogprior (vi) - logjac
840- vi_new = setlogprior!! (unflatten (vi, y), lp_new)
910+ # Set parameters
911+ vi_new = unflatten (vi, y)
912+ # Update logjac
913+ vi_new = setlogjac!! (vi_new, logjac)
841914 return settrans!! (vi_new, t)
842915end
843916
@@ -846,10 +919,12 @@ function invlink!!(
846919)
847920 b = t. bijector
848921 y = vi[:]
849- x, logjac = with_logabsdet_jacobian (b, y)
922+ x = b ( y)
850923
851- lp_new = getlogprior (vi) + logjac
852- vi_new = setlogprior!! (unflatten (vi, x), lp_new)
924+ # Set parameters
925+ vi_new = unflatten (vi, x)
926+ # Reset logjac to 0
927+ vi_new = setlogjac!! (vi_new, 0.0 )
853928 return settrans!! (vi_new, NoTransformation ())
854929end
855930
0 commit comments