Skip to content

Commit 7b337bd

Browse files
committed
Fix tests
1 parent 7049125 commit 7b337bd

File tree

3 files changed

+9
-59
lines changed

3 files changed

+9
-59
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ function DynamicPPL.predict(
180180
DynamicPPL.VarInfo(),
181181
(
182182
DynamicPPL.LogPriorAccumulator(),
183-
DynamicPPL.LogJacobianAccumulator(),
184183
DynamicPPL.LogLikelihoodAccumulator(),
185184
DynamicPPL.ValuesAsInModelAccumulator(false),
186185
),
@@ -199,23 +198,9 @@ function DynamicPPL.predict(
199198
varinfo,
200199
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
201200
)
202-
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
203-
varname_vals = mapreduce(
204-
collect,
205-
vcat,
206-
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
207-
)
208-
209-
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
201+
DynamicPPL.ParamsWithStats(varinfo, nothing)
210202
end
211-
212-
chain_result = reduce(
213-
MCMCChains.chainscat,
214-
[
215-
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
216-
chain_idx in 1:size(predictive_samples, 2)
217-
],
218-
)
203+
chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples)
219204
parameter_names = if include_all
220205
MCMCChains.names(chain_result, :parameters)
221206
else
@@ -234,45 +219,6 @@ function DynamicPPL.predict(
234219
)
235220
end
236221

237-
function _predictive_samples_to_arrays(predictive_samples)
238-
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
239-
240-
sample_dicts = map(predictive_samples) do sample
241-
varname_value_pairs = sample.varname_and_values
242-
varnames = map(first, varname_value_pairs)
243-
values = map(last, varname_value_pairs)
244-
for varname in varnames
245-
push!(variable_names_set, varname)
246-
end
247-
248-
return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
249-
end
250-
251-
variable_names = collect(variable_names_set)
252-
variable_values = [
253-
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
254-
key in variable_names
255-
]
256-
257-
return variable_names, variable_values
258-
end
259-
260-
function _predictive_samples_to_chains(predictive_samples)
261-
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
262-
variable_names_symbols = map(Symbol, variable_names)
263-
264-
internal_parameters = [:lp]
265-
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
266-
267-
parameter_names = [variable_names_symbols; internal_parameters]
268-
parameter_values = hcat(variable_values, log_probabilities)
269-
parameter_values = MCMCChains.concretize(parameter_values)
270-
271-
return MCMCChains.Chains(
272-
parameter_values, parameter_names, (internals=internal_parameters,)
273-
)
274-
end
275-
276222
"""
277223
returned(model::Model, chain::MCMCChains.Chains)
278224

src/to_chains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct ParamsWithStats{P<:OrderedDict{VarName,Any},S<:NamedTuple}
7171
(
7272
logprior=DynamicPPL.getlogprior(varinfo),
7373
loglikelihood=DynamicPPL.getloglikelihood(varinfo),
74-
logjoint=DynamicPPL.getlogjoint(varinfo),
74+
lp=DynamicPPL.getlogjoint(varinfo),
7575
),
7676
)
7777
end

test/test_util.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,12 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for
6262
`n_iters` iterations.
6363
"""
6464
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
65-
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
66-
return DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, varinfos)
65+
vi = VarInfo(model)
66+
vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.ValuesAsInModelAccumulator(false),))
67+
ps = [
68+
ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi)), nothing) for _ in 1:n_iters
69+
]
70+
return DynamicPPL.to_chains(MCMCChains.Chains, ps)
6771
end
6872
function make_chain_from_prior(model::Model, n_iters::Int)
6973
return make_chain_from_prior(Random.default_rng(), model, n_iters)

0 commit comments

Comments
 (0)