|
1 | 1 | """ArviZ stats accessors.""" |
2 | 2 |
|
3 | 3 | import warnings |
| 4 | +from collections.abc import Hashable |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import xarray as xr |
@@ -214,27 +215,40 @@ def _process_input(self, group, method, allow_non_matching=True): |
214 | 215 | ) |
215 | 216 | return self._obj |
216 | 217 | raise ValueError( |
217 | | - f"Group {group} not available in DataTree. Present groups are {self._obj.children}" |
| 218 | + f"Group {group} not available in DataTree. Present groups are {self._obj.children} " |
| 219 | + f"and the DataTree itself is named {self._obs.name}" |
218 | 220 | ) |
219 | 221 |
|
220 | 222 | def _apply(self, fun_name, group, **kwargs): |
221 | | - allow_non_matching = False |
222 | | - if isinstance(group, str): |
| 223 | + hashable_group = False |
| 224 | + if isinstance(group, Hashable): |
223 | 225 | group = [group] |
224 | | - allow_non_matching = True |
225 | | - return DataTree.from_dict( |
| 226 | + hashable_group = True |
| 227 | + out_dt = DataTree.from_dict( |
226 | 228 | { |
227 | 229 | group_i: xr.Dataset( |
228 | 230 | { |
229 | 231 | var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs)) |
230 | 232 | for var_name, da in self._process_input( |
231 | | - group_i, fun_name, allow_non_matching=allow_non_matching |
| 233 | + # if group is a single str/hashable that doesn't match the group |
| 234 | + # name, still allow it and apply the function to the top level of |
| 235 | + # the provided datatree |
| 236 | + group_i, |
| 237 | + fun_name, |
| 238 | + allow_non_matching=hashable_group, |
232 | 239 | ).items() |
233 | 240 | } |
234 | 241 | ) |
235 | 242 | for group_i in group |
236 | 243 | } |
237 | 244 | ) |
| 245 | + if hashable_group: |
| 246 | + # if group was a string/hashable, return a datatree with a single node |
| 247 | + # (from the provided group) as the root of the DataTree |
| 248 | + return out_dt[group[0]] |
| 249 | + # if group was a sequence, return a DataTree with multiple groups in the 1st level, |
| 250 | + # as many groups as requested |
| 251 | + return out_dt |
238 | 252 |
|
239 | 253 | def filter_vars(self, group="posterior", var_names=None, filter_vars=None): |
240 | 254 | """Access and filter variables of the provided group.""" |
|
0 commit comments