Description
The natural way to represent a draw from a posterior distribution is as a NamedTuple whose keys are parameter names and whose values are the corresponding parameter values. The values can be scalars, arrays, or arbitrary Julia objects. All draws for a chain are then a vector of such NamedTuples, and we may have a vector of chains. When we convert to InferenceData, we "flatten" the draws until we get numeric arrays. Each element of such an array is a marginal draw, which is useful for plotting and diagnostics.
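As a rough illustration of this flattening, here is a minimal sketch in plain Julia with no InferenceObjects involved. The helper names stack_last and flatten_param and the toy draws are hypothetical and exist only for this illustration; the actual conversion in InferenceObjects may work differently.

```julia
# Toy draws (made up for illustration): 2 chains of 3 draws each.
# Each draw is a NamedTuple with a scalar `mu` and a length-2 vector `theta`.
chains = [[(mu = 10.0 * c + d, theta = [Float64(c), Float64(d)]) for d in 1:3] for c in 1:2]

# Concatenate equally sized values along a new trailing dimension.
stack_last(xs) = cat(xs...; dims=ndims(first(xs)) + 1)

# Flatten one parameter to an array with dimensions (parameter dims..., draw, chain).
function flatten_param(chains, name::Symbol)
    per_chain = [stack_last([getproperty(draw, name) for draw in draws]) for draws in chains]
    return stack_last(per_chain)
end

mu = flatten_param(chains, :mu)        # size (3, 2): draw × chain
theta = flatten_param(chains, :theta)  # size (2, 3, 2): element × draw × chain
```

Each entry of mu or theta is then a single marginal draw, and the NamedTuple structure of the individual draws is no longer visible in the arrays.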
Sometimes, though, users need the unflattened draws; e.g., when interacting with the PPL, one often needs draws in the format produced by the PPL, which will in general not look like a Dataset. In #11 we discuss ideas for not flattening. A simpler alternative is to provide utility functions for "unflattening". Here's an example of such a function:
julia> using DimensionalData, InferenceObjects
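julia> function unflatten(f, v, keep_dims=(:chain, :draw))
           # dimensions to collapse into each element (everything except keep_dims)
           dims = Dimensions.otherdims(v, keep_dims)
           isempty(dims) && return v
           # the dimensions of v that are kept, in the order they appear in v
           keep_dims_actual = Dimensions.otherdims(v, dims)
           dimnums = Dimensions.dimnum(v, dims)
           # apply f to each slice over the collapsed dimensions; wrapping the result in a
           # 1-element vector makes mapslices store it as a single element
           data_new = dropdims(mapslices(Base.vect ∘ f, parent(v); dims=dimnums); dims=dimnums)
           return DimArray(data_new, keep_dims_actual)
       end;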
By passing f=identity, we can handle the case where draws are scalars or arrays of scalars:
julia> x = convert_to_dataset((; x=randn(2, 3, 8, 4)); dims=(x=[:a, :b],)).x
2×3×8×4 DimArray{Float64,4} x with dimensions:
Dim{:a} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points,
Dim{:b} Sampled{Int64} Base.OneTo(3) ForwardOrdered Regular Points,
Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1, 1]
1 2 3
1 0.844121 1.79069 -0.349435
2 -0.435955 2.21937 0.102086
[and 31 more slices...]
julia> x_unflat = unflatten(identity, x)
8×4 DimArray{Matrix{Float64},2} with dimensions:
Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
… 4
1 [-0.710667 2.17712 1.25004; 1.15662 0.138343 0.511868]
2 [-1.59593 -0.847627 0.185637; 1.62011 0.733101 -0.82679]
3 [1.6434 0.188635 0.434926; -2.70953 0.223494 -1.00055]
4 [-0.753704 -2.25251 0.32903; 1.97774 -0.744595 1.0287]
5 … [0.837521 -0.252849 0.0989726; -1.10382 0.511166 0.566629]
6 [-1.58429 -0.164573 1.83263; 0.875992 -0.174146 -1.10488]
7 [-2.21422 -0.398891 -1.26135; 1.27395 -0.150042 0.243492]
8 [0.789781 0.052268 -1.51552; 0.5554 1.08581 -1.16574]
julia> x_unflat[1]
2×3 Matrix{Float64}:
0.844121 1.79069 -0.349435
-0.435955 2.21937 0.102086
Other choices of f let us handle cases where draws are not array types. For example, here's how we might unflatten a real array representing complex draws:
julia> z = convert_to_dataset((; z=randn(2, 8, 4)); dims=(z=[:reim],), coords=(reim=[:re, :im],)).z
2×8×4 DimArray{Float64,3} z with dimensions:
Dim{:reim} Categorical{Symbol} Symbol[re, im] ReverseOrdered,
Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1]
1 2 3 4 5 6 7 8
:re 0.149895 -0.758902 -0.162169 -1.58568 -1.9113 -0.873895 -1.15336 -0.723117
:im -0.0615223 -0.191197 0.552402 0.754498 -0.139014 0.496133 1.69164 -1.05489
[and 3 more slices...]
julia> z_unflat = unflatten(Base.splat(complex), z)
8×4 DimArray{ComplexF64,2} with dimensions:
Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
1 2 3 4
1 0.149895-0.0615223im -1.49337+0.0310133im -0.186236-0.632437im 0.122908-1.8747im
2 -0.758902-0.191197im 0.0847507+1.8477im 0.699646+0.0940246im -0.700787+0.589689im
3 -0.162169+0.552402im -0.426661-0.215763im 1.24455-0.30482im 0.87671-0.0396714im
4 -1.58568+0.754498im -1.08887-0.0911398im -1.18796+0.0439568im 0.583836+0.226613im
5 -1.9113-0.139014im -1.11748+0.521976im -0.453853-0.668656im -1.40155+0.216688im
6 -0.873895+0.496133im 0.471934+0.508555im -1.1003-0.844055im 2.6073-0.25573im
7 -1.15336+1.69164im 0.107038+0.070659im -2.15358-1.19693im -0.0646238-0.749879im
8 -0.723117-1.05489im 1.0455-0.601896im -0.931837+0.621233im 0.789712+0.442579im
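One way to sanity-check such a conversion is to flatten the result again and compare it with the input. The following sketch does this on a plain Array instead of a DimArray, so it needs only Base Julia. The array sizes mirror the example above, but the variable names are just for illustration.

```julia
# Plain-array stand-in for z, with dimensions (reim, draw, chain).
z_plain = randn(2, 8, 4)

# Unflatten: combine the real and imaginary parts of each draw into one complex number.
z_complex = [
    complex(z_plain[1, d, c], z_plain[2, d, c]) for d in axes(z_plain, 2), c in axes(z_plain, 3)
]

# Flatten again: view each ComplexF64 as two Float64s along a new leading dimension.
z_roundtrip = reinterpret(reshape, Float64, z_complex)

z_roundtrip == z_plain  # the round trip recovers the original values
```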
By applying this approach to all parameters in a Dataset, we can unflatten everything:
julia> using ArviZExampleData
julia> idata = load_example_data("centered_eight");
julia> post = idata.posterior
Dataset with dimensions:
Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
:mu Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
:theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
:tau Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
with metadata Dict{String, Any} with 6 entries:
"created_at" => "2022-10-13T14:37:37.315398"
"inference_library_version" => "4.2.2"
"sampling_time" => 7.48011
"tuning_steps" => 1000
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
julia> post_new = Dataset(map(v -> unflatten(identity, v), NamedTuple(post)))
Dataset with dimensions:
Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points
and 3 layers:
:mu Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
:theta Vector{Float64} dims: Dim{:draw}, Dim{:chain} (500×4)
:tau Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
julia> post_new[1]
(mu = 7.871796366146925, theta = [12.320685578094814, 9.905366892588605, 14.9516154956564, 11.011484941973162, 5.5796015919074735, 16.901795293711004, 13.198059333176934, 15.06136583596694], tau = 4.725740062893666)
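If a PPL expects the layout described at the start (a vector of chains, each a vector of NamedTuple draws), unflattened parameters can be regrouped with a comprehension. This is only a sketch on a plain NamedTuple of arrays rather than a Dataset; the name params and its toy contents are hypothetical.

```julia
# Toy unflattened parameters (made up), each with dimensions (draw, chain).
ndraws, nchains = 5, 2
params = (
    mu = randn(ndraws, nchains),
    theta = [randn(3) for d in 1:ndraws, c in 1:nchains],
)

# One NamedTuple per draw, grouped into one vector per chain.
# `map` over a NamedTuple applies the function to each value and keeps the keys.
draws_by_chain = [[map(p -> p[d, c], params) for d in 1:ndraws] for c in 1:nchains]

draws_by_chain[2][3]  # NamedTuple with keys (:mu, :theta) for draw 3 of chain 2
```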
I propose we add something like this utility to the API to make it easier to use InferenceObjects with PPLs.