Skip to content

Commit d6c9cfa

Browse files
committed
Fix tests
1 parent a29b953 commit d6c9cfa

File tree

13 files changed

+142
-89
lines changed

13 files changed

+142
-89
lines changed

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8686
vi = DynamicPPL.link(vi, model)
8787
end
8888

89-
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
89+
f = DynamicPPL.LogDensityFunction(
90+
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
91+
)
9092
# The parameters at which we evaluate f.
9193
θ = vi[:]
9294

docs/src/api.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators.
367367

368368
```@docs
369369
LogPriorAccumulator
370+
LogJacobianAccumulator
370371
LogLikelihoodAccumulator
371372
VariableOrderAccumulator
372373
```
@@ -380,7 +381,11 @@ getlogp
380381
setlogp!!
381382
acclogp!!
382383
getlogjoint
384+
getlogjoint_internal
385+
getlogjac
386+
setlogjac!!
383387
getlogprior
388+
getlogprior_internal
384389
setlogprior!!
385390
acclogprior!!
386391
getloglikelihood

src/DynamicPPL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ export AbstractVarInfo,
5858
getlogjoint,
5959
getlogprior,
6060
getloglikelihood,
61+
getlogjac,
62+
getlogjoint_internal,
63+
getlogprior_internal,
6164
setlogp!!,
6265
setlogprior!!,
6366
setloglikelihood!!,

src/logdensityfunction.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818
"""
1919
LogDensityFunction(
2020
model::Model,
21-
getlogdensity::Function=getlogjoint,
21+
getlogdensity::Function=getlogjoint_internal,
2222
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
2323
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
2424
)
@@ -29,21 +29,44 @@ A struct which contains a model, along with all the information necessary to:
2929
- and if `adtype` is provided, calculate the gradient of the log density at
3030
that point.
3131
32-
At its most basic level, a LogDensityFunction wraps the model together with a
33-
function that specifies how to extract the log density, and the type of
34-
VarInfo to be used. These must be known in order to calculate the log density
35-
(using [`DynamicPPL.evaluate!!`](@ref)).
32+
This information can be extracted using the LogDensityProblems.jl interface,
33+
specifically, using `LogDensityProblems.logdensity` and
34+
`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only
35+
`logdensity` is implemented. If `adtype` is a concrete AD backend type, then
36+
`logdensity_and_gradient` is also implemented.
37+
38+
There are several options for `getlogdensity` that are 'supported' out of the
39+
box:
40+
41+
- [`getlogjoint_internal`](@ref): calculate the log joint, including the
42+
log-Jacobian term for any variables that have been linked in the provided
43+
VarInfo.
44+
- [`getlogprior_internal`](@ref): calculate the log prior, including the
45+
log-Jacobian term for any variables that have been linked in the provided
46+
VarInfo.
47+
- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring
48+
any effects of linking
49+
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
50+
any effects of linking
51+
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
52+
by linking, since transforms are only applied to random variables)
53+
54+
!!! note
55+
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
56+
result of `LogDensityProblems.logdensity(f, x)` will depend on whether the
57+
`LogDensityFunction` was created with a linked or unlinked VarInfo. This
58+
is done primarily to ease interoperability with MCMC samplers.
59+
60+
If you provide one of these functions, a `VarInfo` will be automatically created
61+
for you. If you provide a different function, you have to manually create a
62+
VarInfo and pass it as the third argument.
3663
3764
If the `adtype` keyword argument is provided, then this struct will also store
3865
the adtype along with other information for efficient calculation of the
3966
gradient of the log density. Note that preparing a `LogDensityFunction` with an
4067
AD type `AutoBackend()` requires the AD backend itself to have been loaded
4168
(e.g. with `import Backend`).
4269
43-
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
44-
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
45-
concrete AD backend type, then `logdensity_and_gradient` is also implemented.
46-
4770
# Fields
4871
$(FIELDS)
4972
@@ -74,7 +97,7 @@ julia> LogDensityProblems.dimension(f)
7497
1
7598
7699
julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
77-
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
100+
f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model));
78101
79102
julia> LogDensityProblems.logdensity(f, [0.0])
80103
-2.3378770664093453
@@ -99,7 +122,7 @@ struct LogDensityFunction{
99122
} <: AbstractModel
100123
"model used for evaluation"
101124
model::M
102-
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
125+
"function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`."
103126
getlogdensity::F
104127
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
105128
varinfo::V
@@ -110,7 +133,7 @@ struct LogDensityFunction{
110133

111134
function LogDensityFunction(
112135
model::Model,
113-
getlogdensity::Function=getlogjoint,
136+
getlogdensity::Function=getlogjoint_internal,
114137
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
115138
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
116139
)
@@ -180,10 +203,18 @@ function ldf_default_varinfo(::Model, getlogdensity::Function)
180203
return error(msg)
181204
end
182205

183-
ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)
206+
ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model)
207+
208+
function ldf_default_varinfo(model::Model, ::typeof(getlogjoint))
209+
return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator()))
210+
end
211+
212+
function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal))
213+
return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator()))
214+
end
184215

185216
function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
186-
return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
217+
return setaccs!!(VarInfo(model), (LogPriorAccumulator()))
187218
end
188219

189220
function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))

src/model.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,10 @@ Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model)
995995
996996
Return the log joint probability of variables `varinfo` for the probabilistic `model`.
997997
998+
Note that this probability always refers to the parameters in unlinked space, i.e.,
999+
the return value of `logjoint` does not depend on whether `VarInfo` has been linked
1000+
or not.
1001+
9981002
See [`logprior`](@ref) and [`loglikelihood`](@ref).
9991003
"""
10001004
function logjoint(model::Model, varinfo::AbstractVarInfo)
@@ -1042,6 +1046,10 @@ end
10421046
10431047
Return the log prior probability of variables `varinfo` for the probabilistic `model`.
10441048
1049+
Note that this probability always refers to the parameters in unlinked space, i.e.,
1050+
the return value of `logprior` does not depend on whether `VarInfo` has been linked
1051+
or not.
1052+
10451053
See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
10461054
"""
10471055
function logprior(model::Model, varinfo::AbstractVarInfo)

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
125125
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))
126126
127127
julia> # (✓) Positive probability mass on negative numbers!
128-
getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
128+
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
129129
-1.3678794411714423
130130
131131
julia> # While if we forget to indicate that it's transformed:
132132
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
133133
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))
134134
135135
julia> # (✓) No probability mass on negative numbers!
136-
getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
136+
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
137137
-Inf
138138
```
139139

src/test_utils/ad.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff
44
using Chairmarks: @be
55
import DifferentiationInterface as DI
66
using DocStringExtensions
7-
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
7+
using DynamicPPL:
8+
Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link
89
using LogDensityProblems: logdensity, logdensity_and_gradient
910
using Random: AbstractRNG, default_rng
1011
using Statistics: median
@@ -224,7 +225,7 @@ function run_ad(
224225
benchmark::Bool=false,
225226
atol::AbstractFloat=100 * eps(),
226227
rtol::AbstractFloat=sqrt(eps()),
227-
getlogdensity::Function=getlogjoint,
228+
getlogdensity::Function=getlogjoint_internal,
228229
rng::AbstractRNG=default_rng(),
229230
varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
230231
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,

test/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
3030

3131
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
3232
linked_varinfo = DynamicPPL.link(varinfo, m)
33-
f = LogDensityFunction(m, getlogjoint, linked_varinfo)
33+
f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo)
3434
x = DynamicPPL.getparams(f)
3535

3636
# Calculate reference logp + gradient of logp using ForwardDiff
@@ -52,17 +52,17 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
5252
if is_mooncake && is_1_11 && is_svi_vnv
5353
# https://github.com/compintell/Mooncake.jl/issues/470
5454
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
55-
m, getlogjoint, linked_varinfo; adtype=adtype
55+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
5656
)
5757
elseif is_mooncake && is_1_10 && is_svi_vnv
5858
# TODO: report upstream
5959
@test_throws UndefRefError DynamicPPL.LogDensityFunction(
60-
m, getlogjoint, linked_varinfo; adtype=adtype
60+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
6161
)
6262
elseif is_mooncake && is_1_10 && is_svi_od
6363
# TODO: report upstream
6464
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
65-
m, getlogjoint, linked_varinfo; adtype=adtype
65+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
6666
)
6767
else
6868
@test run_ad(
@@ -113,7 +113,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
113113
spl = Sampler(MyEmptyAlg())
114114
sampling_model = contextualize(model, SamplingContext(model.context))
115115
ldf = LogDensityFunction(
116-
sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true)
116+
sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true)
117117
)
118118
x = ldf.varinfo[:]
119119
@test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any

test/linking.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,11 @@ end
8484
else
8585
DynamicPPL.link(vi, model)
8686
end
87-
# Difference should just be the log-absdet-jacobian "correction".
88-
@test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) log(2)
87+
# Difference between the internal logjoints should just be the log-absdet-jacobian "correction".
88+
@test DynamicPPL.getlogjoint_internal(vi) -
89+
DynamicPPL.getlogjoint_internal(vi_linked) log(2)
90+
# The non-internal logjoint should be the same.
91+
@test DynamicPPL.getlogjoint(vi) DynamicPPL.getlogjoint_internal(vi_linked)
8992
@test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist])
9093
# Linked one should be working with a lower-dimensional representation.
9194
@test length(vi_linked[:]) < length(vi[:])
@@ -99,6 +102,8 @@ end
99102
@test length(vi_invlinked[:]) == length(vi[:])
100103
@test vi_invlinked[@varname(m), dist] LowerTriangular(vi[@varname(m), dist])
101104
@test DynamicPPL.getlogjoint(vi_invlinked) DynamicPPL.getlogjoint(vi)
105+
@test DynamicPPL.getlogjoint_internal(vi_invlinked)
106+
DynamicPPL.getlogjoint_internal(vi)
102107
end
103108
end
104109

@@ -130,15 +135,15 @@ end
130135
end
131136
@test length(vi_linked[:]) == d * (d - 1) ÷ 2
132137
# Should now include the log-absdet-jacobian correction.
133-
@test !(getlogjoint(vi_linked) lp)
138+
@test !(getlogjoint_internal(vi_linked) lp)
134139
# Invlinked.
135140
vi_invlinked = if mutable
136141
DynamicPPL.invlink!!(deepcopy(vi_linked), model)
137142
else
138143
DynamicPPL.invlink(vi_linked, model)
139144
end
140145
@test length(vi_invlinked[:]) == d^2
141-
@test getlogjoint(vi_invlinked) lp
146+
@test getlogjoint_internal(vi_invlinked) lp
142147
end
143148
end
144149
end
@@ -164,15 +169,15 @@ end
164169
end
165170
@test length(vi_linked[:]) == d - 1
166171
# Should now include the log-absdet-jacobian correction.
167-
@test !(getlogjoint(vi_linked) lp)
172+
@test !(getlogjoint_internal(vi_linked) lp)
168173
# Invlinked.
169174
vi_invlinked = if mutable
170175
DynamicPPL.invlink!!(deepcopy(vi_linked), model)
171176
else
172177
DynamicPPL.invlink(vi_linked, model)
173178
end
174179
@test length(vi_invlinked[:]) == d
175-
@test getlogjoint(vi_invlinked) lp
180+
@test getlogjoint_internal(vi_invlinked) lp
176181
end
177182
end
178183
end

test/logdensityfunction.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ end
2626
loglikelihood(model, vi)
2727

2828
@testset "$(varinfo)" for varinfo in varinfos
29+
# Note use of `getlogjoint` rather than `getlogjoint_internal` here ...
2930
logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo)
3031
θ = varinfo[:]
32+
# ... because it has to match with `logjoint(model, vi)`, which always returns
33+
# the unlinked value
3134
@test LogDensityProblems.logdensity(logdensity, θ) logjoint(model, varinfo)
3235
@test LogDensityProblems.dimension(logdensity) == length(θ)
3336
end

0 commit comments

Comments
 (0)