Skip to content

Commit feb3084

Browse files
authored
Improve behaviour of DataTree accessor (#32)
* better default for datatree group handling * add test for datatree accessor group behaviour * lint
1 parent d622d99 commit feb3084

File tree

8 files changed

+80
-13
lines changed

8 files changed

+80
-13
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v4.6.0
3+
rev: v5.0.0
44
hooks:
55
- id: check-added-large-files
66
- id: check-toml
@@ -11,7 +11,7 @@ repos:
1111
- id: trailing-whitespace
1212

1313
- repo: https://github.com/astral-sh/ruff-pre-commit
14-
rev: v0.4.9
14+
rev: v0.6.9
1515
hooks:
1616
- id: ruff
1717
args: [ --fix, --exit-non-zero-on-fix ]

src/arviz_stats/accessors.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""ArviZ stats accessors."""
22

33
import warnings
4+
from collections.abc import Hashable
45

56
import numpy as np
67
import xarray as xr
@@ -214,27 +215,40 @@ def _process_input(self, group, method, allow_non_matching=True):
214215
)
215216
return self._obj
216217
raise ValueError(
217-
f"Group {group} not available in DataTree. Present groups are {self._obj.children}"
218+
f"Group {group} not available in DataTree. Present groups are {self._obj.children} "
219+
f"and the DataTree itself is named {self._obj.name}"
218220
)
219221

220222
def _apply(self, fun_name, group, **kwargs):
221-
allow_non_matching = False
222-
if isinstance(group, str):
223+
hashable_group = False
224+
if isinstance(group, Hashable):
223225
group = [group]
224-
allow_non_matching = True
225-
return DataTree.from_dict(
226+
hashable_group = True
227+
out_dt = DataTree.from_dict(
226228
{
227229
group_i: xr.Dataset(
228230
{
229231
var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs))
230232
for var_name, da in self._process_input(
231-
group_i, fun_name, allow_non_matching=allow_non_matching
233+
# if group is a single str/hashable that doesn't match the group
234+
# name, still allow it and apply the function to the top level of
235+
# the provided datatree
236+
group_i,
237+
fun_name,
238+
allow_non_matching=hashable_group,
232239
).items()
233240
}
234241
)
235242
for group_i in group
236243
}
237244
)
245+
if hashable_group:
246+
# if group was a string/hashable, return a datatree with a single node
247+
# (from the provided group) as the root of the DataTree
248+
return out_dt[group[0]]
249+
# if group was a sequence, return a DataTree with multiple groups in the 1st level,
250+
# as many groups as requested
251+
return out_dt
238252

239253
def filter_vars(self, group="posterior", var_names=None, filter_vars=None):
240254
"""Access and filter variables of the provided group."""

tests/base/test_diagnostics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88
import pytest
99
from arviz_base import load_arviz_data, xarray_var_iter
10+
1011
from arviz_stats.base import array_stats
1112

1213
# For tests only, recommended value should be closer to 1.01-1.05

tests/base/test_stats_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# pylint: disable=no-member,unnecessary-lambda-assignment
55
import numpy as np
66
import pytest
7-
from arviz_stats.base.stats_utils import logsumexp as _logsumexp
8-
from arviz_stats.base.stats_utils import make_ufunc, not_valid
97
from numpy.testing import assert_array_almost_equal
108
from scipy.special import logsumexp
119

10+
from arviz_stats.base.stats_utils import logsumexp as _logsumexp
11+
from arviz_stats.base.stats_utils import make_ufunc, not_valid
12+
1213

1314
@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
1415
@pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)])

tests/test_accessors.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# pylint: disable=redefined-outer-name
2+
"""Test accessors.
3+
4+
Accessor methods are very short, with the bulk of the computation/processing
5+
handled by private methods. Testing these shared infrastructural methods
6+
is the main goal of this module even if it does so via specific "regular" methods.
7+
"""
8+
9+
import numpy as np
10+
import pytest
11+
from arviz_base import from_dict
12+
from datatree import DataTree
13+
14+
15+
@pytest.fixture(scope="module")
16+
def idata():
17+
return from_dict(
18+
{
19+
"posterior": {
20+
"a": np.random.normal(size=(4, 100)),
21+
"b": np.random.normal(size=(4, 100, 3)),
22+
},
23+
"posterior_predictive": {
24+
"y": np.random.normal(size=(4, 100, 7)),
25+
},
26+
}
27+
)
28+
29+
30+
def test_accessors_available(idata):
31+
assert hasattr(idata, "azstats")
32+
assert hasattr(idata.posterior.ds, "azstats")
33+
assert hasattr(idata.posterior["a"], "azstats")
34+
35+
36+
def test_datatree_single_group(idata):
37+
out = idata.azstats.ess(group="posterior")
38+
assert isinstance(out, DataTree)
39+
assert not out.children
40+
assert out.name == "posterior"
41+
42+
43+
def test_datatree_multiple_groups(idata):
44+
out = idata.azstats.ess(group=["posterior", "posterior_predictive"])
45+
assert isinstance(out, DataTree)
46+
assert len(out.children) == 2
47+
assert "posterior" in out.children
48+
assert "posterior_predictive" in out.children

tests/test_psense.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
22

33
from arviz_base import convert_to_datatree
4-
from arviz_stats import psense, psense_summary
54
from numpy import isclose
65
from numpy.testing import assert_almost_equal
76

7+
from arviz_stats import psense, psense_summary
8+
89
file_path = os.path.join(os.path.dirname(__file__), "univariate_normal.nc")
910
uni_dt = convert_to_datatree(file_path)
1011

tests/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77
from arviz_base import from_dict, rcParams
8+
89
from arviz_stats.base.dataarray import dataarray_stats
910
from arviz_stats.utils import ELPDData, get_function, get_log_likelihood
1011

tox.ini

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ isolated_build_env = build
99

1010
[gh-actions]
1111
python =
12-
3.10: check, py310
12+
3.10: py310
1313
3.11: py311
14-
3.12: py312
14+
3.12: py312, check
15+
3.13: py313
1516

1617
[testenv]
1718
basepython =

0 commit comments

Comments
 (0)