diff --git a/Project.toml b/Project.toml index b8ed300a..d0e0dcee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InferenceObjects" uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" authors = ["Seth Axen and contributors"] -version = "0.4.13" +version = "0.4.14" [deps] ANSIColoredPrinters = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" @@ -32,7 +32,7 @@ MLJBase = "1" NCDatasets = "0.12.6, 0.13, 0.14" OffsetArrays = "1" OrderedCollections = "1.6" -PosteriorStats = "0.3" +PosteriorStats = "0.3, 0.4" Random = "1" StatsBase = "0.33.7, 0.34" Tables = "1.11.0" diff --git a/ext/InferenceObjectsPosteriorStatsExt/InferenceObjectsPosteriorStatsExt.jl b/ext/InferenceObjectsPosteriorStatsExt/InferenceObjectsPosteriorStatsExt.jl index 0300a647..320dcaa0 100644 --- a/ext/InferenceObjectsPosteriorStatsExt/InferenceObjectsPosteriorStatsExt.jl +++ b/ext/InferenceObjectsPosteriorStatsExt/InferenceObjectsPosteriorStatsExt.jl @@ -6,17 +6,16 @@ using InferenceObjects: InferenceObjects using PosteriorStats: PosteriorStats using StatsBase: StatsBase -import PosteriorStats: eti, hdi, loo, loo_pit, r2_score, summarize, waic import StatsBase: summarystats -export eti, hdi, loo, loo_pit, r2_score, summarize, waic, summarystats - maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map include("utils.jl") include("ci.jl") include("loo.jl") -include("waic.jl") +@static if isdefined(PosteriorStats, :waic) + include("waic.jl") +end include("loo_pit.jl") include("r2_score.jl") include("summarize.jl") diff --git a/ext/InferenceObjectsPosteriorStatsExt/ci.jl b/ext/InferenceObjectsPosteriorStatsExt/ci.jl index f87e4f4d..02baf2f9 100644 --- a/ext/InferenceObjectsPosteriorStatsExt/ci.jl +++ b/ext/InferenceObjectsPosteriorStatsExt/ci.jl @@ -1,16 +1,16 @@ for (ci_fun, ci_desc) in (:eti => "equal-tailed interval (ETI)", :hdi => "highest density interval (HDI)") + ci_name = string(ci_fun) @eval begin - # this pattern ensures that the type is completely specified at compile time @doc """ - $($ci_fun)(data::InferenceData; kwargs...) -> Dataset - $($ci_fun)(data::Dataset; kwargs...) -> Dataset + $($ci_name)(data::InferenceData; kwargs...) -> Dataset + $($ci_name)(data::Dataset; kwargs...) -> Dataset Calculate the $($ci_desc) for each parameter in the data. For more details and a description of the `kwargs`, see - [`PosteriorStats.$($ci_fun)`](@extref). + [`PosteriorStats.$($ci_name)`](@extref). """ function PosteriorStats.$(ci_fun)(data::InferenceObjects.InferenceData; kwargs...) return PosteriorStats.$(ci_fun)(data.posterior; kwargs...) diff --git a/ext/InferenceObjectsPosteriorStatsExt/r2_score.jl b/ext/InferenceObjectsPosteriorStatsExt/r2_score.jl index be1cb97e..2e799c69 100644 --- a/ext/InferenceObjectsPosteriorStatsExt/r2_score.jl +++ b/ext/InferenceObjectsPosteriorStatsExt/r2_score.jl @@ -1,5 +1,5 @@ @doc """ - r2_score(idata::InferenceData; y_name, y_pred_name) -> (; r2, r2_std) + r2_score(idata::InferenceData; y_name, y_pred_name, kwargs...) -> (; r2, ) Compute ``R²`` from `idata`, automatically formatting the predictions to the correct shape. @@ -9,8 +9,8 @@ Compute ``R²`` from `idata`, automatically formatting the predictions to the co the only observed data variable is used. - `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`. If not provided, then `y_name` is used. - -See [`PosteriorStats.r2_score`](@extref) for more details. + - `kwargs...`: Additional keyword arguments to pass to + [`PosteriorStats.r2_score`](@extref). # Examples @@ -19,10 +19,8 @@ julia> using ArviZExampleData, PosteriorStats julia> idata = load_example_data("regression10d"); -julia> r2_score(idata) |> pairs -pairs(::NamedTuple) with 2 entries: - :r2 => 0.998385 - :r2_std => 0.000100621 +julia> r2_score(idata) +(r2 = 0.998384805658226, eti = 0.9982167674001565 .. 0.9985401916739318) ``` """ function PosteriorStats.r2_score( diff --git a/ext/InferenceObjectsPosteriorStatsExt/summarize.jl b/ext/InferenceObjectsPosteriorStatsExt/summarize.jl index 6a838f8b..e39a10f2 100644 --- a/ext/InferenceObjectsPosteriorStatsExt/summarize.jl +++ b/ext/InferenceObjectsPosteriorStatsExt/summarize.jl @@ -33,17 +33,17 @@ julia> data = load_example_data("centered_eight"); julia> summarize(data) SummaryStats - mean std eti94 ess_tail ess_bulk rhat mcse_mean mcse_std - mu 4.2 3.3 -2.11 .. 9.90 622 241 1.03 0.21 0.088 - theta[Choate] 6.4 5.9 -3.05 .. 19.1 937 572 1.01 0.25 0.20 - theta[Deerfield] 5.0 4.9 -4.49 .. 14.2 1214 532 1.01 0.21 0.15 - theta[Phillips Andover] 3.4 5.4 -8.17 .. 12.7 1017 511 1.01 0.23 0.17 - theta[Phillips Exeter] 4.8 5.2 -4.84 .. 14.5 911 572 1.01 0.21 0.21 - theta[Hotchkiss] 3.5 4.8 -6.11 .. 12.0 789 347 1.02 0.25 0.15 - theta[Lawrenceville] 3.7 5.2 -6.62 .. 12.6 957 506 1.01 0.22 0.21 - theta[St. Paul's] 6.5 5.2 -2.38 .. 18.3 1031 528 1.01 0.22 0.15 - theta[Mt. Hermon] 4.8 5.7 -5.52 .. 16.0 1045 538 1.01 0.24 0.23 - tau 4.3 3.0 1.06 .. 11.5 214 128 1.03 0.22 0.14 + mean std eti89 ess_tail ess_bulk rhat mcse_mean mcse_std + mu 4.2 3.3 -1.15 .. 9.15 622 241 1.03 0.21 0.088 + theta[Choate] 6.4 5.9 -1.72 .. 16.6 937 572 1.01 0.25 0.20 + theta[Deerfield] 5.0 4.9 -3.03 .. 12.4 1214 532 1.01 0.21 0.15 + theta[Phillips Andover] 3.4 5.4 -5.69 .. 11.3 1017 511 1.01 0.23 0.17 + theta[Phillips Exeter] 4.8 5.2 -3.08 .. 12.7 911 572 1.01 0.21 0.21 + theta[Hotchkiss] 3.5 4.8 -4.29 .. 10.6 789 347 1.02 0.25 0.15 + theta[Lawrenceville] 3.7 5.2 -4.41 .. 11.4 957 506 1.01 0.22 0.21 + theta[St. Paul's] 6.5 5.2 -1.05 .. 15.8 1031 528 1.01 0.22 0.15 + theta[Mt. Hermon] 4.8 5.7 -3.78 .. 13.9 1045 538 1.01 0.24 0.23 + tau 4.3 3.0 1.27 .. 9.95 214 128 1.03 0.22 0.14 ``` Compute the mean, standard deviation, median, and median absolute deviation of the `theta` diff --git a/test/posteriorstats.jl b/test/posteriorstats.jl index a6888c35..0b3d301f 100644 --- a/test/posteriorstats.jl +++ b/test/posteriorstats.jl @@ -172,7 +172,7 @@ _as_array(x::AbstractArray) = x end end - @testset "waic" begin + isdefined(PosteriorStats, :waic) && @testset "waic" begin @testset for sz in ((1000, 4), (1000, 4, 2), (100, 4, 2, 3)) atol_perm = cbrt(eps())