Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.4.13"
version = "0.4.14"

[deps]
ANSIColoredPrinters = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
Expand Down Expand Up @@ -32,7 +32,7 @@ MLJBase = "1"
NCDatasets = "0.12.6, 0.13, 0.14"
OffsetArrays = "1"
OrderedCollections = "1.6"
PosteriorStats = "0.3"
PosteriorStats = "0.3, 0.4"
Random = "1"
StatsBase = "0.33.7, 0.34"
Tables = "1.11.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@ using InferenceObjects: InferenceObjects
using PosteriorStats: PosteriorStats
using StatsBase: StatsBase

import PosteriorStats: eti, hdi, loo, loo_pit, r2_score, summarize, waic
import StatsBase: summarystats

export eti, hdi, loo, loo_pit, r2_score, summarize, waic, summarystats

maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map

include("utils.jl")
include("ci.jl")
include("loo.jl")
include("waic.jl")
@static if isdefined(PosteriorStats, :waic)
include("waic.jl")
end
include("loo_pit.jl")
include("r2_score.jl")
include("summarize.jl")
Expand Down
8 changes: 4 additions & 4 deletions ext/InferenceObjectsPosteriorStatsExt/ci.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@

for (ci_fun, ci_desc) in
(:eti => "equal-tailed interval (ETI)", :hdi => "highest density interval (HDI)")
ci_name = string(ci_fun)
@eval begin
# this pattern ensures that the type is completely specified at compile time
@doc """
$($ci_fun)(data::InferenceData; kwargs...) -> Dataset
$($ci_fun)(data::Dataset; kwargs...) -> Dataset
$($ci_name)(data::InferenceData; kwargs...) -> Dataset
$($ci_name)(data::Dataset; kwargs...) -> Dataset

Calculate the $($ci_desc) for each parameter in the data.

For more details and a description of the `kwargs`, see
[`PosteriorStats.$($ci_fun)`](@extref).
[`PosteriorStats.$($ci_name)`](@extref).
"""
function PosteriorStats.$(ci_fun)(data::InferenceObjects.InferenceData; kwargs...)
return PosteriorStats.$(ci_fun)(data.posterior; kwargs...)
Expand Down
12 changes: 5 additions & 7 deletions ext/InferenceObjectsPosteriorStatsExt/r2_score.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@doc """
r2_score(idata::InferenceData; y_name, y_pred_name) -> (; r2, r2_std)
r2_score(idata::InferenceData; y_name, y_pred_name, kwargs...) -> (; r2, <ci>)

Compute ``R²`` from `idata`, automatically formatting the predictions to the correct shape.

Expand All @@ -9,8 +9,8 @@ Compute ``R²`` from `idata`, automatically formatting the predictions to the co
the only observed data variable is used.
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
If not provided, then `y_name` is used.

See [`PosteriorStats.r2_score`](@extref) for more details.
- `kwargs...`: Additional keyword arguments to pass to
[`PosteriorStats.r2_score`](@extref).

# Examples

Expand All @@ -19,10 +19,8 @@ julia> using ArviZExampleData, PosteriorStats

julia> idata = load_example_data("regression10d");

julia> r2_score(idata) |> pairs
pairs(::NamedTuple) with 2 entries:
:r2 => 0.998385
:r2_std => 0.000100621
julia> r2_score(idata)
(r2 = 0.998384805658226, eti = 0.9982167674001565 .. 0.9985401916739318)
```
"""
function PosteriorStats.r2_score(
Expand Down
22 changes: 11 additions & 11 deletions ext/InferenceObjectsPosteriorStatsExt/summarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ julia> data = load_example_data("centered_eight");

julia> summarize(data)
SummaryStats
mean std eti94 ess_tail ess_bulk rhat mcse_mean mcse_std
mu 4.2 3.3 -2.11 .. 9.90 622 241 1.03 0.21 0.088
theta[Choate] 6.4 5.9 -3.05 .. 19.1 937 572 1.01 0.25 0.20
theta[Deerfield] 5.0 4.9 -4.49 .. 14.2 1214 532 1.01 0.21 0.15
theta[Phillips Andover] 3.4 5.4 -8.17 .. 12.7 1017 511 1.01 0.23 0.17
theta[Phillips Exeter] 4.8 5.2 -4.84 .. 14.5 911 572 1.01 0.21 0.21
theta[Hotchkiss] 3.5 4.8 -6.11 .. 12.0 789 347 1.02 0.25 0.15
theta[Lawrenceville] 3.7 5.2 -6.62 .. 12.6 957 506 1.01 0.22 0.21
theta[St. Paul's] 6.5 5.2 -2.38 .. 18.3 1031 528 1.01 0.22 0.15
theta[Mt. Hermon] 4.8 5.7 -5.52 .. 16.0 1045 538 1.01 0.24 0.23
tau 4.3 3.0 1.06 .. 11.5 214 128 1.03 0.22 0.14
mean std eti89 ess_tail ess_bulk rhat mcse_mean mcse_std
mu 4.2 3.3 -1.15 .. 9.15 622 241 1.03 0.21 0.088
theta[Choate] 6.4 5.9 -1.72 .. 16.6 937 572 1.01 0.25 0.20
theta[Deerfield] 5.0 4.9 -3.03 .. 12.4 1214 532 1.01 0.21 0.15
theta[Phillips Andover] 3.4 5.4 -5.69 .. 11.3 1017 511 1.01 0.23 0.17
theta[Phillips Exeter] 4.8 5.2 -3.08 .. 12.7 911 572 1.01 0.21 0.21
theta[Hotchkiss] 3.5 4.8 -4.29 .. 10.6 789 347 1.02 0.25 0.15
theta[Lawrenceville] 3.7 5.2 -4.41 .. 11.4 957 506 1.01 0.22 0.21
theta[St. Paul's] 6.5 5.2 -1.05 .. 15.8 1031 528 1.01 0.22 0.15
theta[Mt. Hermon] 4.8 5.7 -3.78 .. 13.9 1045 538 1.01 0.24 0.23
tau 4.3 3.0 1.27 .. 9.95 214 128 1.03 0.22 0.14
```

Compute the mean, standard deviation, median, and median absolute deviation of the `theta`
Expand Down
2 changes: 1 addition & 1 deletion test/posteriorstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ _as_array(x::AbstractArray) = x
end
end

@testset "waic" begin
isdefined(PosteriorStats, :waic) && @testset "waic" begin
@testset for sz in ((1000, 4), (1000, 4, 2), (100, 4, 2, 3))
atol_perm = cbrt(eps())

Expand Down
Loading