Skip to content

Commit 5b110df

Browse files
authored
get_bins dataarray version (#21)
* get_bins dataarray draft * behaviour fixes * pre-commit * remove repeated .items() * ruff
1 parent feb3084 commit 5b110df

File tree

4 files changed

+86
-33
lines changed

4 files changed

+86
-33
lines changed

src/arviz_stats/accessors.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,29 @@ def update_kwargs_with_dims(da, kwargs):
3535
return kwargs
3636

3737

38+
def check_var_name_subset(obj, var_name):
39+
if isinstance(obj, xr.Dataset):
40+
return obj[var_name]
41+
if isinstance(obj, DataTree):
42+
return obj.ds[var_name]
43+
return obj
44+
45+
46+
def apply_function_to_dataset(func, ds, kwargs):
47+
return xr.Dataset(
48+
{
49+
var_name: func(
50+
da,
51+
**{
52+
key: check_var_name_subset(value, var_name)
53+
for key, value in update_kwargs_with_dims(da, kwargs).items()
54+
},
55+
)
56+
for var_name, da in ds.items()
57+
}
58+
)
59+
60+
3861
unset = UnsetDefault()
3962

4063

@@ -121,12 +144,7 @@ def _apply(self, fun, **kwargs):
121144
"""Apply a function to all variables subsetting dims to existing dimensions."""
122145
if isinstance(fun, str):
123146
fun = get_function(fun)
124-
return xr.Dataset(
125-
{
126-
var_name: fun(da, **update_kwargs_with_dims(da, kwargs))
127-
for var_name, da in self._obj.items()
128-
}
129-
)
147+
return apply_function_to_dataset(fun, self._obj, kwargs=kwargs)
130148

131149
def eti(self, prob=None, dims=None, **kwargs):
132150
"""Compute the equal tail interval of all the variables in the dataset."""
@@ -154,8 +172,12 @@ def kde(self, dims=None, **kwargs):
154172
"""Compute the KDE for all variables in the dataset."""
155173
return self._apply("kde", dims=dims, **kwargs)
156174

175+
def get_bins(self, dims=None, **kwargs):
176+
"""Compute the histogram bin edges for all variables in the dataset."""
177+
return self._apply(get_function("get_bins"), dims=dims, **kwargs)
178+
157179
def histogram(self, dims=None, **kwargs):
158-
"""Compute the KDE for all variables in the dataset."""
180+
"""Compute the histogram for all variables in the dataset."""
159181
return self._apply("histogram", dims=dims, **kwargs)
160182

161183
def compute_ranks(self, dims=None, relative=False):
@@ -219,25 +241,19 @@ def _process_input(self, group, method, allow_non_matching=True):
219241
f"and the DataTree itself is named {self._obs.name}"
220242
)
221243

222-
def _apply(self, fun_name, group, **kwargs):
244+
def _apply(self, func_name, group, **kwargs):
223245
hashable_group = False
224246
if isinstance(group, Hashable):
225247
group = [group]
226248
hashable_group = True
227249
out_dt = DataTree.from_dict(
228250
{
229-
group_i: xr.Dataset(
230-
{
231-
var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs))
232-
for var_name, da in self._process_input(
233-
# if group is a single str/hashable that doesn't match the group
234-
# name, still allow it and apply the function to the top level of
235-
# the provided datatree
236-
group_i,
237-
fun_name,
238-
allow_non_matching=hashable_group,
239-
).items()
240-
}
251+
group_i: apply_function_to_dataset(
252+
get_function(func_name),
253+
# if group is a single str/hashable that doesn't match the group name,
254+
# still allow it and apply the function to the top level of the provided input
255+
self._process_input(group_i, func_name, allow_non_matching=hashable_group),
256+
kwargs=kwargs,
241257
)
242258
for group_i in group
243259
}

src/arviz_stats/base/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def compute_ranks(self, ary, axes=-1, relative=False):
179179
)
180180
return compute_ranks_ufunc(ary, out_shape=(ary.shape[i] for i in axes), relative=relative)
181181

182-
def get_bins(self, ary, axes=-1):
182+
def get_bins(self, ary, axes=-1, bins="arviz"):
183183
"""Compute default bins."""
184184
ary, axes = process_ary_axes(ary, axes)
185185
get_bininfo_ufunc = make_ufunc(
@@ -188,10 +188,11 @@ def get_bins(self, ary, axes=-1):
188188
n_input=1,
189189
n_dims=len(axes),
190190
)
191-
x_min, x_max, width = get_bininfo_ufunc(ary)
191+
# TODO: improve handling of array_like bins
192+
x_min, x_max, width = get_bininfo_ufunc(ary, bins=bins)
192193
n_bins = np.ceil((x_max - x_min) / width)
193194
n_bins = np.ceil(np.mean(n_bins)).astype(int)
194-
return np.moveaxis(np.linspace(x_min, x_max, n_bins), 0, -1)
195+
return np.moveaxis(np.linspace(x_min, x_max, n_bins + 1), 0, -1)
195196

196197
# pylint: disable=redefined-builtin, too-many-return-statements
197198
# noqa: PLR0911

src/arviz_stats/base/core.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,28 @@ def _compute_ranks(self, ary, relative=False):
137137
return out / out.size
138138
return out
139139

140-
def _get_bininfo(self, values):
140+
def _get_bininfo(self, values, bins="arviz"):
141141
dtype = values.dtype.kind
142142

143+
if isinstance(bins, str) and bins != "arviz":
144+
bins = np.histogram_bin_edges(values, bins=bins)
145+
146+
if isinstance(bins, np.ndarray):
147+
return bins[0], bins[-1], bins[1] - bins[0]
148+
143149
if dtype == "i":
144150
x_min = values.min().astype(int)
145151
x_max = values.max().astype(int)
146152
else:
147153
x_min = values.min().astype(float)
148154
x_max = values.max().astype(float)
149155

156+
if isinstance(bins, int):
157+
width = (x_max - x_min) / bins
158+
if dtype == "i":
159+
width = max(1, width)
160+
return x_min, x_max, width
161+
150162
# Sturges histogram bin estimator
151163
width_sturges = (x_max - x_min) / (np.log2(values.size) + 1)
152164

@@ -161,13 +173,20 @@ def _get_bininfo(self, values):
161173

162174
return x_min, x_max, width
163175

164-
def _get_bins(self, values):
176+
def _get_bins(self, values, bins="arviz"):
165177
"""
166178
Automatically compute the number of bins for histograms.
167179
168180
Parameters
169181
----------
170-
values = array_like
182+
values : array_like
183+
bins : int, str or array_like, default "arviz"
184+
If `bins` "arviz", use ArviZ default rule (explained in detail in notes),
185+
if it is a different string it is passed to :func:`numpy.histogram_bin_edges`.
186+
If `bins` is an integer it is interpreted as the number of bins, however,
187+
if `values` holds discrete data, there is an extra check to prevent
188+
the width of the bins to be smaller than ``1``.
189+
If it is an array it is returned as it.
171190
172191
Returns
173192
-------
@@ -190,7 +209,7 @@ def _get_bins(self, values):
190209
"""
191210
dtype = values.dtype.kind
192211

193-
x_min, x_max, width = self._get_bininfo(values)
212+
x_min, x_max, width = self._get_bininfo(values, bins)
194213

195214
if dtype == "i":
196215
bins = np.arange(x_min, x_max + width + 1, width)

src/arviz_stats/base/dataarray.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,31 @@ def mcse(self, da, dims=None, method="mean", prob=None):
115115
},
116116
)
117117

118+
def get_bins(self, da, dims=None, bins="arviz"):
119+
"""Compute bins or align provided ones with DataArray input."""
120+
dims = validate_dims(dims)
121+
return apply_ufunc(
122+
self.array_class.get_bins,
123+
da,
124+
input_core_dims=[dims],
125+
output_core_dims=[["edges_dim" if da.name is None else f"edges_dim_{da.name}"]],
126+
kwargs={
127+
"bins": bins,
128+
"axes": np.arange(-len(dims), 0, 1),
129+
},
130+
)
131+
118132
# pylint: disable=redefined-builtin
119133
def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=None):
120134
"""Compute histogram on DataArray input."""
121135
dims = validate_dims(dims)
122-
edges_dim = "edges_dim"
123-
hist_dim = "hist_dim"
136+
edges_dim = "edges_dim" if da.name is None else f"edges_dim_{da.name}"
137+
hist_dim = "hist_dim" if da.name is None else f"hist_dim_{da.name}"
124138
input_core_dims = [dims]
125139
if isinstance(bins, DataArray):
126-
bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]]
127-
assert len(bins_dims) == 1
128140
if "plot_axis" in bins.dims:
141+
bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]]
142+
assert len(bins_dims) == 1
129143
hist_dim = bins_dims[0]
130144
bins = (
131145
concat(
@@ -138,8 +152,11 @@ def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=
138152
.rename({hist_dim: edges_dim})
139153
.drop_vars(edges_dim)
140154
)
141-
else:
142-
edges_dim = bins_dims[0]
155+
elif edges_dim not in bins.dims:
156+
raise ValueError(
157+
"Invalid 'bins' DataArray, it should contain either 'plot_axis' or "
158+
f"'{edges_dim}' dimension"
159+
)
143160
input_core_dims.append([edges_dim])
144161
else:
145162
input_core_dims.append([])

0 commit comments

Comments
 (0)