Utility for unflattening Datasets #27

@sethaxen

The natural way to represent a single draw from a posterior distribution is as a NamedTuple whose keys are the parameter names and whose values are the corresponding parameter values, which can be scalars, arrays, or arbitrary Julia objects. All draws for a chain then form a vector of such NamedTuples, and multiple chains form a vector of chains. When we convert to InferenceData, we "flatten" these draws until we get numeric arrays. Each element of such an array is a marginal draw, which is useful for plotting and diagnostics.
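
As a rough illustration of what "flattening" means here, the sketch below stacks per-draw values into one numeric array per parameter, using only Base Julia. The draws and the helper `flatten_param` are hypothetical and are not part of InferenceObjects; the actual conversion also attaches dimension names and coordinates.

```julia
# Hypothetical draws for one chain: each draw is a NamedTuple of parameter values.
draws = [(mu = 1.0 * i, theta = [10.0 * i, 20.0 * i]) for i in 1:3]

# Hypothetical helper: collect one parameter across draws into a numeric array
# whose last dimension indexes the draws.
function flatten_param(draws, name)
    vals = [getproperty(d, name) for d in draws]
    first(vals) isa AbstractArray || return vals   # scalar draws are already numeric
    nd = ndims(first(vals))
    return cat(vals...; dims=nd + 1)               # array draws: append a draw dimension
end

mu_flat = flatten_param(draws, :mu)        # one entry per draw
theta_flat = flatten_param(draws, :theta)  # parameter dimension first, draw dimension last
```

Here the draw dimension comes last, matching the layout in the examples that follow, where parameter dimensions precede `draw` and `chain`.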

Sometimes, though, users need the unflattened draws. For example, when interacting with a PPL, one often needs draws in the format that PPL produces, which in general will not look like a Dataset. In #11 we discuss ideas for not flattening. A simpler alternative is to provide utility functions for "unflattening". Here's an example of such a function:

julia> using DimensionalData, InferenceObjects

julia> function unflatten(f, v, keep_dims=(:chain, :draw))
           dims = Dimensions.otherdims(v, keep_dims)
           isempty(dims) && return v
           keep_dims_actual = Dimensions.otherdims(v, dims)
           dimnums = Dimensions.dimnum(v, dims)
           data_new = dropdims(mapslices(Base.vect ∘ f, parent(v); dims=dimnums); dims=dimnums)
           return DimArray(data_new, keep_dims_actual)
       end;
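
To make the core step of this function easier to follow, here is a minimal sketch of the same `mapslices`/`dropdims` pattern on a plain array, without DimensionalData. The array `A` and its dimension layout are made up for illustration.

```julia
# Plain-array stand-in for a variable with dims (a, b, draw); values are arbitrary.
A = reshape(collect(1.0:24.0), 2, 3, 4)

# Base.vect wraps each (a, b) slice in a 1-element vector, so mapslices stores
# the slice as a single element instead of splicing it back into the output.
wrapped = mapslices(Base.vect ∘ identity, A; dims=(1, 2))

# Drop the two singleton dimensions left behind, keeping only the draw dimension.
unflat = dropdims(wrapped; dims=(1, 2))

unflat[2] == A[:, :, 2]  # each element is one full (a, b) slice
```

Without the `Base.vect` wrapper, `mapslices` would write the returned slice back into the corresponding region of the output, and the result would simply be a copy of `A`.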

By passing f=identity, we can handle the case where draws are scalars or arrays of scalars:

julia> x = convert_to_dataset((; x=randn(2, 3, 8, 4)); dims=(x=[:a, :b],)).x
2×3×8×4 DimArray{Float64,4} x with dimensions: 
  Dim{:a} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points,
  Dim{:b} Sampled{Int64} Base.OneTo(3) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1, 1]
     1         2         3
 1   0.844121  1.79069  -0.349435
 2  -0.435955  2.21937   0.102086
[and 31 more slices...]

julia> x_unflat = unflatten(identity, x)
8×4 DimArray{Matrix{Float64},2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
      4
 1      [-0.710667 2.17712 1.25004; 1.15662 0.138343 0.511868]
 2      [-1.59593 -0.847627 0.185637; 1.62011 0.733101 -0.82679]
 3      [1.6434 0.188635 0.434926; -2.70953 0.223494 -1.00055]
 4      [-0.753704 -2.25251 0.32903; 1.97774 -0.744595 1.0287]
 5     [0.837521 -0.252849 0.0989726; -1.10382 0.511166 0.566629]
 6      [-1.58429 -0.164573 1.83263; 0.875992 -0.174146 -1.10488]
 7      [-2.21422 -0.398891 -1.26135; 1.27395 -0.150042 0.243492]
 8      [0.789781 0.052268 -1.51552; 0.5554 1.08581 -1.16574]

julia> x_unflat[1]
2×3 Matrix{Float64}:
  0.844121  1.79069  -0.349435
 -0.435955  2.21937   0.102086

Other choices of f let us handle cases where the draws are not arrays. For example, here's how we might unflatten a real array representing complex draws:

julia> z = convert_to_dataset((; z=randn(2, 8, 4)); dims=(z=[:reim],), coords=(reim=[:re, :im],)).z
2×8×4 DimArray{Float64,3} z with dimensions: 
  Dim{:reim} Categorical{Symbol} Symbol[re, im] ReverseOrdered,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1]
        1           2          3          4          5          6          7         8
  :re   0.149895   -0.758902  -0.162169  -1.58568   -1.9113    -0.873895  -1.15336  -0.723117
  :im  -0.0615223  -0.191197   0.552402   0.754498  -0.139014   0.496133   1.69164  -1.05489
[and 3 more slices...]

julia> z_unflat = unflatten(Base.splat(complex), z)
8×4 DimArray{ComplexF64,2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
             1                      2                      3                       4
 1   0.149895-0.0615223im   -1.49337+0.0310133im  -0.186236-0.632437im     0.122908-1.8747im
 2  -0.758902-0.191197im   0.0847507+1.8477im      0.699646+0.0940246im   -0.700787+0.589689im
 3  -0.162169+0.552402im   -0.426661-0.215763im     1.24455-0.30482im       0.87671-0.0396714im
 4   -1.58568+0.754498im    -1.08887-0.0911398im   -1.18796+0.0439568im    0.583836+0.226613im
 5    -1.9113-0.139014im    -1.11748+0.521976im   -0.453853-0.668656im     -1.40155+0.216688im
 6  -0.873895+0.496133im    0.471934+0.508555im     -1.1003-0.844055im       2.6073-0.25573im
 7   -1.15336+1.69164im     0.107038+0.070659im    -2.15358-1.19693im    -0.0646238-0.749879im
 8  -0.723117-1.05489im       1.0455-0.601896im   -0.931837+0.621233im     0.789712+0.442579im

By applying this approach to all parameters in a Dataset, we can unflatten everything:

julia> using ArviZExampleData

julia> idata = load_example_data("centered_eight");

julia> post = idata.posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> post_new = Dataset(map(v -> unflatten(identity, v), NamedTuple(post)))
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Vector{Float64} dims: Dim{:draw}, Dim{:chain} (500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)


julia> post_new[1]
(mu = 7.871796366146925, theta = [12.320685578094814, 9.905366892588605, 14.9516154956564, 11.011484941973162, 5.5796015919074735, 16.901795293711004, 13.198059333176934, 15.06136583596694], tau = 4.725740062893666)
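
For readers who want the same per-draw NamedTuples outside a Dataset, the following Base-only sketch zips equally shaped unflattened arrays into an array of NamedTuples. The container `cols` and its values are hypothetical stand-ins for unflattened layers, not output from InferenceObjects.

```julia
# Hypothetical unflattened layers, each with shape (draw, chain) = (2, 2).
cols = (
    mu = [1.0 2.0; 3.0 4.0],
    theta = [[Float64(i), Float64(j)] for i in 1:2, j in 1:2],
)

# Combine the layers elementwise into one NamedTuple per (draw, chain) pair.
draws = map((vals...) -> NamedTuple{keys(cols)}(vals), values(cols)...)

draws[2, 1]  # NamedTuple with fields mu and theta for draw 2 of chain 1
```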

I propose we add something like this utility to the API to make it easier to use InferenceObjects with PPLs.
