Description
Currently we have two classes of conversion functions. The first class consists of from_XXX functions, which dispatch on the posterior type and have a number of keywords specific to that type, e.g. from_namedtuple or from_mcmcchains. Then we have the generic functions convert_to_inference_data and convert_to_dataset, which have methods that dispatch to the from_XXX functions. These functions can even be used within other from_XXX functions to allow groups of one type to be mixed with a posterior of another type.
When a user wants their type to be convertible to an InferenceData, they generally implement both a from_XXX function and a corresponding convert_to_inference_data method.
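To make the current extension pattern concrete, here is a minimal sketch. `MyChains`, `from_mychains`, and the stub `convert_to_inference_data` are hypothetical stand-ins for this illustration; real packages define e.g. `from_mcmcchains` and extend the generic `convert_to_inference_data` from the actual package.

```julia
# Hypothetical sample-storage type, standing in for e.g. MCMCChains.Chains
struct MyChains
    draws::NamedTuple
end

# type-specific entry point with its own keywords (stubbed: just wraps the draws)
from_mychains(chn::MyChains; kwargs...) = (; posterior=chn.draws, kwargs...)

# generic entry point (stubbed here); the package-specific method dispatches
# on the storage type and forwards to the from_XXX function
convert_to_inference_data(chn::MyChains; kwargs...) = from_mychains(chn; kwargs...)

convert_to_inference_data(MyChains((x=[1.0],)))  # forwards to from_mychains
```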
Here I propose a major redesign of this pipeline, based on the following principles:
- There are 2 types of objects we want to convert: objects that contain data for a single group (or part of a group) and objects that contain data for multiple groups.
- A user may want to split the data in the first type of object into several groups.
- We drop the `convert_to_` prefix, since in general we're not doing conversion.
- The first type of object can be hooked into the pipeline by implementing a single function, `dataset`.
- The second type of object can be hooked into the pipeline by implementing the functions `inferencedata` and `dataset`.
- All conversion functions should absorb unused keywords into `kwargs`, so that a single `inferencedata` call can use keywords for multiple conversion methods so long as they don't clash.
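The keyword-absorption principle can be illustrated in isolation with a minimal sketch (the names `convert_a`, `convert_b`, and `pipeline` are made up for this example and are not part of the proposal):

```julia
# Each converter pulls out only the keywords it knows and absorbs the rest
# into `kwargs...`, so a single top-level call can carry keywords intended
# for several converters at once, as long as the names don't clash.
convert_a(x; scale=1, kwargs...) = x .* scale   # ignores :offset
convert_b(x; offset=0, kwargs...) = x .+ offset # ignores :scale
pipeline(x; kwargs...) = convert_b(convert_a(x; kwargs...); kwargs...)

pipeline([1.0, 2.0]; scale=2, offset=3)  # each keyword reaches its converter
```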
Working prototype of the pipeline
using InferenceObjects
# fallback to current pipeline for demonstration purposes
dataset(data; kwargs...) = convert_to_dataset(data; kwargs...)
inferencedata(data::InferenceData; kwargs...) = data
inferencedata(data; kwargs...) = inferencedata(:posterior => data; kwargs...)
function inferencedata(data::Pair{Symbol}; kwargs...)
k, v = data
ds = if k ∈ (:constant_data, :observed_data)
dataset(v; default_dims=(), kwargs...)
else
dataset(v; kwargs...)
end
return InferenceData(; k => ds)
end
function inferencedata(data, next::Pair{Symbol}, others::Pair{Symbol}...; kwargs...)
inferencedata(inferencedata(data; kwargs...), next, others...; kwargs...)
end
function inferencedata(data::InferenceData, next::Pair{Symbol}, others::Pair{Symbol}...; kwargs...)
merge(data, inferencedata(next; kwargs...), others...; kwargs...)
end
struct Subset{V}
source::Symbol
var_map::V
end
function subset(source::Symbol, var_map::Tuple{Vararg{Union{Symbol,Pair{Symbol,Symbol}}}})
var_map_new = map(var_map) do v
v isa Pair && return v
return v => v
end
return Subset(source, var_map_new)
end
function inferencedata(data::InferenceData, next::Pair{Symbol,<:Subset}, others::Pair{Symbol}...; kwargs...)
k, s = next
source_vars = map(last, s.var_map)
source_ds = data[s.source]
source_ds_new = source_ds[filter(∉(source_vars), keys(source_ds))]
subset_nt = NamedTuple(source_ds[source_vars])
subset = Dataset(NamedTuple{map(first, s.var_map)}(values(subset_nt)))
idata_merged = merge(data, InferenceData(; s.source => source_ds_new, k => subset))
return inferencedata(idata_merged, others...; kwargs...)
end
Demonstration
Now here's a demonstration of how we use it:
julia> ndraws, nchains = 1_000, 4;
julia> data_all = (
x = randn(4, ndraws, nchains),
z = randn(2, ndraws, nchains),
lp = randn(ndraws, nchains),
log_like = randn(10, ndraws, nchains),
y_hat = randn(10, ndraws, nchains),
);
julia> idata = inferencedata(
data_all,
:posterior_predictive => subset(:posterior, (:y => :y_hat,)),
:log_likelihood => subset(:posterior, (:y => :log_like,)),
:sample_stats => subset(:posterior, (:lp,)),
)
InferenceData with groups:
> posterior
> posterior_predictive
> log_likelihood
> sample_stats
julia> idata.posterior
Dataset with dimensions:
Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
:x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
:z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)
with metadata Dict{String, Any} with 1 entry:
"created_at" => "2022-11-03T23:09:29.868"
julia> idata.posterior_predictive
Dataset with dimensions:
Dim{:y_hat_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
:y Float64 dims: Dim{:y_hat_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)
julia> idata.log_likelihood
Dataset with dimensions:
Dim{:log_like_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
:y Float64 dims: Dim{:log_like_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)
julia> idata.sample_stats
Dataset with dimensions:
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
:lp Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
The subsetting machinery is generic, so we don't need to customize it for every type as we currently do in the from_XXX methods.
There are still some kinks to work out in this pipeline, such as correctly handling dimensions when variables are renamed, but let's check the extensibility of the pipeline.
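For the dimension-renaming kink, one possible fix is to rewrite auto-generated dimension names alongside the variable rename. This is only a sketch, not part of the prototype; `rename_default_dims` is a hypothetical helper that assumes the `<var>_dim_<i>` default naming convention seen in the output above:

```julia
# When variable `old` is renamed to `new`, rewrite any default dimension
# names of the form `old_dim_i` to `new_dim_i`, leaving other dims untouched.
function rename_default_dims(dims::Tuple{Vararg{Symbol}}, old::Symbol, new::Symbol)
    oldprefix = string(old, "_dim_")
    return map(dims) do d
        s = String(d)
        if startswith(s, oldprefix)
            Symbol(new, "_dim_", s[(length(oldprefix) + 1):end])
        else
            d
        end
    end
end

rename_default_dims((:y_hat_dim_1, :draw, :chain), :y_hat, :y)
# rewrites :y_hat_dim_1 while leaving :draw and :chain alone
```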
Demonstration of pipeline extensibility
Here we define two types of storage of MCMC results, representing the two types defined above.
# represents some object containing data from a single dataset
struct PosteriorStorage
nt
end
dataset(post::PosteriorStorage; kwargs...) = dataset(post.nt; kwargs...)
# represents some object containing data for multiple groups, here a posterior and sample_stats
# it can be converted to an InferenceData, or to a Dataset, in which case a single group is extracted (here the posterior)
# e.g. MCMCChains.Chains or SampleChains.MultiChain
struct MultiGroupStorage
nt
end
function inferencedata(post::MultiGroupStorage; kwargs...)
inferencedata(post.nt, :sample_stats=>subset(:posterior, (:lp,)); kwargs...);
end
dataset(post::MultiGroupStorage; kwargs...) = inferencedata(post; kwargs...).posterior
Now let's wrap our NamedTuple in these types and execute the pipeline:
julia> idata2 = inferencedata(
PosteriorStorage(data_all),
:posterior_predictive => subset(:posterior, (:y => :y_hat,)),
:log_likelihood => subset(:posterior, (:y => :log_like,)),
:sample_stats => subset(:posterior, (:lp,)),
)
InferenceData with groups:
> posterior
> posterior_predictive
> log_likelihood
> sample_stats
julia> idata2.posterior
Dataset with dimensions:
Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
:x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
:z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)
with metadata Dict{String, Any} with 1 entry:
"created_at" => "2022-11-03T23:23:57.729"
julia> idata2.posterior_predictive
Dataset with dimensions:
Dim{:y_hat_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
:y Float64 dims: Dim{:y_hat_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)
julia> idata3 = inferencedata(
MultiGroupStorage(data_all),
:posterior_predictive => subset(:posterior, (:y => :y_hat,)),
:log_likelihood => subset(:posterior, (:y => :log_like,)),
)
InferenceData with groups:
> posterior
> posterior_predictive
> log_likelihood
> sample_stats
julia> idata3.posterior
Dataset with dimensions:
Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
:x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
:z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)
with metadata Dict{String, Any} with 1 entry:
"created_at" => "2022-11-03T23:24:29.005"
julia> idata3.sample_stats
Dataset with dimensions:
Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
:lp Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)