@@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818"""
1919 LogDensityFunction(
2020 model::Model,
21- getlogdensity::Function=getlogjoint ,
21+ getlogdensity::Function=getlogjoint_internal ,
2222 varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
2323 adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
2424 )
@@ -29,21 +29,44 @@ A struct which contains a model, along with all the information necessary to:
2929 - and if `adtype` is provided, calculate the gradient of the log density at
3030 that point.
3131
32- At its most basic level, a LogDensityFunction wraps the model together with a
33- function that specifies how to extract the log density, and the type of
34- VarInfo to be used. These must be known in order to calculate the log density
35- (using [`DynamicPPL.evaluate!!`](@ref)).
32+ This information can be extracted using the LogDensityProblems.jl interface,
33+ specifically, using `LogDensityProblems.logdensity` and
34+ `LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only
35+ `logdensity` is implemented. If `adtype` is a concrete AD backend type, then
36+ `logdensity_and_gradient` is also implemented.
37+
38+ There are several options for `getlogdensity` that are 'supported' out of the
39+ box:
40+
41+ - [`getlogjoint_internal`](@ref): calculate the log joint, including the
42+ log-Jacobian term for any variables that have been linked in the provided
43+ VarInfo.
44+ - [`getlogprior_internal`](@ref): calculate the log prior, including the
45+ log-Jacobian term for any variables that have been linked in the provided
46+ VarInfo.
47+ - [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring
48+ any effects of linking
49+ - [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
50+ any effects of linking
51+ - [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
52+ by linking, since transforms are only applied to random variables)
53+
54+ !!! note
55+ By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
56+ result of `LogDensityProblems.logdensity(f, x)` will depend on whether the
57+ `LogDensityFunction` was created with a linked or unlinked VarInfo. This
58+ is done primarily to ease interoperability with MCMC samplers.
59+
60+ If you provide one of these functions, a `VarInfo` will be automatically created
61+ for you. If you provide a different function, you have to manually create a
62+ VarInfo and pass it as the third argument.
3663
3764If the `adtype` keyword argument is provided, then this struct will also store
3865the adtype along with other information for efficient calculation of the
3966gradient of the log density. Note that preparing a `LogDensityFunction` with an
4067AD type `AutoBackend()` requires the AD backend itself to have been loaded
4168(e.g. with `import Backend`).
4269
43- `DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
44- If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
45- concrete AD backend type, then `logdensity_and_gradient` is also implemented.
46-
4770# Fields
4871$(FIELDS)
4972
@@ -74,7 +97,7 @@ julia> LogDensityProblems.dimension(f)
74971
7598
7699julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
77- f = LogDensityFunction(model, getlogjoint , SimpleVarInfo(model));
100+ f = LogDensityFunction(model, getlogjoint_internal , SimpleVarInfo(model));
78101
79102julia> LogDensityProblems.logdensity(f, [0.0])
80103-2.3378770664093453
@@ -99,7 +122,7 @@ struct LogDensityFunction{
99122} <: AbstractModel
100123 " model used for evaluation"
101124 model:: M
102- " function to be called on `varinfo` to extract the log density. By default `getlogjoint `."
125+ " function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal `."
103126 getlogdensity:: F
104127 " varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
105128 varinfo:: V
@@ -110,7 +133,7 @@ struct LogDensityFunction{
110133
111134 function LogDensityFunction (
112135 model:: Model ,
113- getlogdensity:: Function = getlogjoint ,
136+ getlogdensity:: Function = getlogjoint_internal ,
114137 varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
115138 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
116139 )
@@ -180,10 +203,18 @@ function ldf_default_varinfo(::Model, getlogdensity::Function)
180203 return error (msg)
181204end
182205
183- ldf_default_varinfo (model:: Model , :: typeof (getlogjoint)) = VarInfo (model)
206+ ldf_default_varinfo (model:: Model , :: typeof (getlogjoint_internal)) = VarInfo (model)
207+
208+ function ldf_default_varinfo (model:: Model , :: typeof (getlogjoint))
209+ return setaccs!! (VarInfo (model), (LogPriorAccumulator (), LogLikelihoodAccumulator ()))
210+ end
211+
212+ function ldf_default_varinfo (model:: Model , :: typeof (getlogprior_internal))
213+ return setaccs!! (VarInfo (model), (LogPriorAccumulator (), LogJacobianAccumulator ()))
214+ end
184215
185216function ldf_default_varinfo (model:: Model , :: typeof (getlogprior))
186- return setaccs!! (VarInfo (model), (LogPriorAccumulator (), ))
217+ return setaccs!! (VarInfo (model), (LogPriorAccumulator ()))
187218end
188219
189220function ldf_default_varinfo (model:: Model , :: typeof (getloglikelihood))
0 commit comments