Skip to content

Commit 70c3dd9

Browse files
committed
Implement ParamsWithStats and to_chains functions
1 parent 1b159a6 commit 70c3dd9

File tree

9 files changed

+310
-29
lines changed

9 files changed

+310
-29
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.2
4+
5+
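Added a new exported function, `DynamicPPL.varinfos_to_chains`, which automatically converts a collection of VarInfos to a given Chains type. Also exported a new struct, `DynamicPPL.ParamsWithStats`, and a new function, `DynamicPPL.to_chains`, which converts an array of `ParamsWithStats` to a given Chains type.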
6+
37
## 0.38.1
48

59
Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.1"
3+
version = "0.38.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,11 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
505505
DynamicPPL.Experimental.determine_suitable_varinfo
506506
DynamicPPL.Experimental.is_suitable_varinfo
507507
```
508+
509+
### Converting VarInfo to chains
510+
511+
The following function is useful for package developers seeking to extend DynamicPPL:
512+
513+
```@docs
514+
DynamicPPL.varinfos_to_chains
515+
```

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,76 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

39+
"""
    DynamicPPL.to_chains(
        ::Type{MCMCChains.Chains},
        params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
    )
    DynamicPPL.to_chains(
        ::Type{MCMCChains.Chains},
        params_and_stats::AbstractVector{<:DynamicPPL.ParamsWithStats},
    )

Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.

For a matrix, the first dimension indexes iterations and the second dimension indexes
chains. A vector is treated as a single chain.

Each parameter is split into its leaf `VarName`s using
`AbstractPPL.varname_and_value_leaves`, and every leaf becomes one column of the chain.
Leaves or statistics that are absent from a particular sample are filled in with `missing`.

The statistics of each sample are stored in the `internals` section of the chain, and the
mapping from each leaf `VarName` to its column `Symbol` is stored in
`chain.info.varname_to_symbol`.
"""
function DynamicPPL.to_chains(
    ::Type{MCMCChains.Chains},
    params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
)
    # Handle parameters
    all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    split_dicts = map(params_and_stats) do ps
        # Separate into individual VarNames.
        # The locals of this closure must not share a name with any local of the
        # enclosing function (`vn_leaves`, `vals`): a closure inherits the enclosing
        # function's locals, so a shared name would be a single captured variable that
        # is reassigned, which Julia has to box.
        leaf_pairs = if isempty(ps.params)
            Tuple{DynamicPPL.VarName,Any}[]
        else
            iters = map(
                AbstractPPL.varname_and_value_leaves,
                keys(ps.params),
                values(ps.params),
            )
            mapreduce(collect, vcat, iters)
        end
        leaf_vns = map(first, leaf_pairs)
        leaf_vals = map(last, leaf_pairs)
        # Pushing one at a time avoids splatting a potentially very long collection.
        for vn_leaf in leaf_vns
            push!(all_vn_leaves, vn_leaf)
        end
        DynamicPPL.OrderedCollections.OrderedDict(zip(leaf_vns, leaf_vals))
    end
    vn_leaves = collect(all_vn_leaves)
    # Array layout is (iteration, parameter, chain).
    param_vals = [
        get(split_dicts[i, j], key, missing) for i in axes(split_dicts, 1),
        key in vn_leaves, j in axes(split_dicts, 2)
    ]
    param_symbols = map(Symbol, vn_leaves)
    # Handle statistics
    stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
    for ps in params_and_stats
        for k in keys(ps.stats)
            push!(stat_keys, k)
        end
    end
    stat_keys = collect(stat_keys)
    # Same (iteration, statistic, chain) layout as `param_vals`.
    stat_vals = [
        get(params_and_stats[i, j].stats, key, missing) for
        i in axes(params_and_stats, 1), key in stat_keys,
        j in axes(params_and_stats, 2)
    ]
    # Construct name map and info
    name_map = (internals=stat_keys,)
    info = (
        varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
            zip(all_vn_leaves, param_symbols)
        ),
    )
    # Concatenate parameter and statistic values
    vals = cat(param_vals, stat_vals; dims=2)
    symbols = vcat(param_symbols, stat_keys)
    return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
end
103+
# Vector method: a vector of samples is interpreted as a single chain and forwarded to
# the matrix method.
function DynamicPPL.to_chains(
    ::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats}
)
    # `hcat` of a single vector produces a one-column matrix, i.e. exactly one chain.
    single_chain = hcat(ps)
    return DynamicPPL.to_chains(MCMCChains.Chains, single_chain)
end
108+
39109
"""
40110
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41111

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ export AbstractVarInfo,
126126
prefix,
127127
returned,
128128
to_submodel,
129+
# Chain construction
130+
ParamsWithStats,
131+
to_chains,
129132
# Convenience macros
130133
@addlogprob!,
131134
value_iterator_from_chain,
@@ -194,6 +197,7 @@ include("model_utils.jl")
194197
include("extract_priors.jl")
195198
include("values_as_in_model.jl")
196199
include("bijector.jl")
200+
include("to_chains.jl")
197201

198202
include("debug_utils.jl")
199203
using .DebugUtils

src/to_chains.jl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
    ParamsWithStats

Container pairing the parameter values extracted from a `VarInfo` (field `params`) with
the statistics associated with that VarInfo (field `stats`). The statistics are a
`NamedTuple`, and supplying them is optional.

    ParamsWithStats(
        varinfo::AbstractVarInfo,
        model::Model,
        stats::NamedTuple=NamedTuple();
        include_colon_eq::Bool=true,
        include_log_probs::Bool=true,
    )

Build a `ParamsWithStats` by re-evaluating `model` with the given `varinfo`. Re-evaluating
the model is often required in order to obtain correct parameter values and log
probabilities. This is particularly the case for linked VarInfos, i.e., when variables have
been transformed to unconstrained space; skipping re-evaluation can lead to subtle
correctness bugs: see, e.g., https://github.com/TuringLang/Turing.jl/issues/2195.

# Keywords

  - `include_colon_eq`: whether variables on the left-hand side of `:=` are included in the
    resulting parameters.
  - `include_log_probs`: whether the log probabilities (log prior, log likelihood, and log
    joint) are added to the resulting statistics NamedTuple, under the keys `logprior`,
    `loglikelihood`, and `logjoint`.

    ParamsWithStats(
        varinfo::AbstractVarInfo,
        ::Nothing,
        stats::NamedTuple=NamedTuple();
        include_log_probs::Bool=true,
    )

Build a `ParamsWithStats` without re-evaluating any model. This is possible when the
VarInfo already contains a `DynamicPPL.ValuesAsInModelAccumulator`: that accumulator stores
the values as seen during model evaluation, so they can simply be read off. Pass `nothing`
in place of the model to select this method. It is the caller's responsibility to ensure
that a `ValuesAsInModelAccumulator` is indeed present.

# Keywords

  - `include_log_probs`: whether log probabilities are added to the resulting statistics
    NamedTuple. Only the log probabilities whose accumulators are present in `varinfo` are
    added, and `logjoint` is added only when both the log prior and the log likelihood
    accumulators are present.
"""
struct ParamsWithStats{P<:OrderedDict{VarName,Any},S<:NamedTuple}
    params::P
    stats::S

    function ParamsWithStats(
        varinfo::AbstractVarInfo,
        model::DynamicPPL.Model,
        stats::NamedTuple=NamedTuple();
        include_colon_eq::Bool=true,
        include_log_probs::Bool=true,
    )
        # The values accumulator is always needed; the log-probability accumulators are
        # only attached when the caller asked for log probabilities.
        values_acc = DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq)
        accs = if include_log_probs
            (
                DynamicPPL.LogPriorAccumulator(),
                DynamicPPL.LogLikelihoodAccumulator(),
                values_acc,
            )
        else
            (values_acc,)
        end
        # Replace the accumulators and re-run the model so that they get populated.
        evaluated = last(DynamicPPL.evaluate!!(model, DynamicPPL.setaccs!!(varinfo, accs)))
        params = DynamicPPL.getacc(evaluated, Val(:ValuesAsInModel)).values
        all_stats = if include_log_probs
            log_probs = (
                logprior=DynamicPPL.getlogprior(evaluated),
                loglikelihood=DynamicPPL.getloglikelihood(evaluated),
                logjoint=DynamicPPL.getlogjoint(evaluated),
            )
            merge(stats, log_probs)
        else
            stats
        end
        return new{typeof(params),typeof(all_stats)}(params, all_stats)
    end

    function ParamsWithStats(
        varinfo::AbstractVarInfo,
        ::Nothing,
        stats::NamedTuple=NamedTuple();
        include_log_probs::Bool=true,
    )
        # No re-evaluation: read the values straight from the existing accumulator.
        params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
        all_stats = stats
        if include_log_probs
            prior_present = DynamicPPL.hasacc(varinfo, Val(:LogPrior))
            likelihood_present = DynamicPPL.hasacc(varinfo, Val(:LogLikelihood))
            if prior_present
                all_stats = merge(all_stats, (logprior=DynamicPPL.getlogprior(varinfo),))
            end
            if likelihood_present
                all_stats = merge(
                    all_stats, (loglikelihood=DynamicPPL.getloglikelihood(varinfo),)
                )
            end
            # The joint is only reported when both of its components are available.
            if prior_present && likelihood_present
                all_stats = merge(all_stats, (logjoint=DynamicPPL.getlogjoint(varinfo),))
            end
        end
        return new{typeof(params),typeof(all_stats)}(params, all_stats)
    end
end
104+
105+
"""
    to_chains(
        Tout::Type{<:AbstractChains},
        params_and_stats::AbstractArray{<:ParamsWithStats}
    )

Convert an array of `ParamsWithStats` to a chains object of type `Tout`.

This generic function is only declared here and has no methods of its own. The methods
are implemented in package extensions for the individual chains packages (for example,
the MCMCChains extension provides the methods for `MCMCChains.Chains`).
"""
function to_chains end

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,43 @@
1111
chain_generated = @test_nowarn returned(model, chain)
1212
@test size(chain_generated) == (1000, 1)
1313
@test mean(chain_generated) ≈ 0 atol = 0.1
14+
15+
# Checks that `DynamicPPL.to_chains` builds an `MCMCChains.Chains` from both a vector
# (single chain) and a matrix (iterations × chains) of `ParamsWithStats`.
@testset "varinfos_to_chains" begin
    @model function f(z)
        x ~ Normal()
        y := x + 1
        return z ~ Normal(y)
    end

    z = 1.0
    model = f(z)

    @testset "vector" begin
        ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50]
        c = DynamicPPL.to_chains(MCMCChains.Chains, ps)
        @test c isa MCMCChains.Chains
        # A vector is a single chain of 50 iterations.
        @test size(c, 1) == 50
        @test size(c, 3) == 1
        @test Set(c.name_map.parameters) == Set([:x, :y])
        @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :logjoint])
        # `x` is the only random variable, so its log density is the whole log prior.
        @test logpdf.(Normal(), c[:x]) ≈ c[:logprior]
        @test c.info.varname_to_symbol[@varname(x)] == :x
        @test c.info.varname_to_symbol[@varname(y)] == :y
    end

    @testset "matrix" begin
        ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50, _ in 1:3]
        c = DynamicPPL.to_chains(MCMCChains.Chains, ps)
        @test c isa MCMCChains.Chains
        # Rows are iterations, columns are chains.
        @test size(c, 1) == 50
        @test size(c, 3) == 3
        @test Set(c.name_map.parameters) == Set([:x, :y])
        @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :logjoint])
        @test logpdf.(Normal(), c[:x]) ≈ c[:logprior]
        @test c.info.varname_to_symbol[@varname(x)] == :x
        @test c.info.varname_to_symbol[@varname(y)] == :y
    end
end
1451
end
1552

1653
# test for `predict` is in `test/model.jl`

test/test_util.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,35 +62,8 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for
6262
`n_iters` iterations.
6363
"""
6464
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
65-
# Sample from the prior
6665
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
67-
# Extract all varnames found in any dictionary. Doing it this way guards
68-
# against the possibility of having different varnames in different
69-
# dictionaries, e.g. for models that have dynamic variables / array sizes
70-
varnames = OrderedSet{VarName}()
71-
# Convert each varinfo into an OrderedDict of vns => params.
72-
# We have to use varname_and_value_leaves so that each parameter is a scalar
73-
dicts = map(varinfos) do t
74-
vals = DynamicPPL.values_as(t, OrderedDict)
75-
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
76-
tuples = mapreduce(collect, vcat, iters)
77-
# The following loop is a replacement for:
78-
# push!(varnames, map(first, tuples)...)
79-
# which causes a stack overflow if `map(first, tuples)` is too large.
80-
# Unfortunately there isn't a union() function for OrderedSet.
81-
for vn in map(first, tuples)
82-
push!(varnames, vn)
83-
end
84-
OrderedDict(tuples)
85-
end
86-
# Convert back to list
87-
varnames = collect(varnames)
88-
# Construct matrix of values
89-
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
90-
# Construct dict of varnames -> symbol
91-
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
92-
# Construct and return the Chains object
93-
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
66+
return DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, varinfos)
9467
end
9568
function make_chain_from_prior(model::Model, n_iters::Int)
9669
return make_chain_from_prior(Random.default_rng(), model, n_iters)

test/to_chains.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Tests for the `ParamsWithStats` constructors: with and without model re-evaluation,
# and with the `include_colon_eq` / `include_log_probs` keywords switched off.
module DynamicPPLToChainsTests

using DynamicPPL
using Distributions
using Test

@testset "ParamsWithStats" begin
    @model function f(z)
        x ~ Normal()
        y := x + 1
        return z ~ Normal(y)
    end
    z = 1.0
    model = f(z)

    @testset "with reevaluation" begin
        ps = ParamsWithStats(VarInfo(model), model)
        @test haskey(ps.params, @varname(x))
        @test haskey(ps.params, @varname(y))
        @test length(ps.params) == 2
        @test haskey(ps.stats, :logprior)
        @test haskey(ps.stats, :loglikelihood)
        @test haskey(ps.stats, :logjoint)
        @test length(ps.stats) == 3
        @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood
        @test ps.params[@varname(y)] ≈ ps.params[@varname(x)] + 1
        @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)])
        @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(y)]), z)
    end

    @testset "without colon_eq" begin
        ps = ParamsWithStats(VarInfo(model), model; include_colon_eq=false)
        @test haskey(ps.params, @varname(x))
        # `y` is defined with `:=`, so it must be absent here.
        @test length(ps.params) == 1
        @test haskey(ps.stats, :logprior)
        @test haskey(ps.stats, :loglikelihood)
        @test haskey(ps.stats, :logjoint)
        @test length(ps.stats) == 3
        @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood
        @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)])
        @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(x)] + 1), z)
    end

    @testset "without log probs" begin
        ps = ParamsWithStats(VarInfo(model), model; include_log_probs=false)
        @test haskey(ps.params, @varname(x))
        @test haskey(ps.params, @varname(y))
        @test length(ps.params) == 2
        @test isempty(ps.stats)
    end

    @testset "no reevaluation" begin
        # Without VAIM, it should error
        @test_throws ErrorException ParamsWithStats(VarInfo(model), nothing)
        # With VAIM, it should work
        vi = DynamicPPL.setaccs!!(
            VarInfo(model), (DynamicPPL.ValuesAsInModelAccumulator(true),)
        )
        vi = last(DynamicPPL.evaluate!!(model, vi))
        ps = ParamsWithStats(vi, nothing)
        @test haskey(ps.params, @varname(x))
        @test haskey(ps.params, @varname(y))
        @test length(ps.params) == 2
        # Because we didn't evaluate with log prob accumulators, there should be no stats
        @test isempty(ps.stats)
    end
end

end # module

0 commit comments

Comments
 (0)