Skip to content

Commit c5a4f5f

Browse files
sethaxenOriolAbril
andauthored
Improvements to multimodal HDI (#28)
* Fix typo * Refactor multimodal HDI code More modular functions and vectorization * Ensure interval contains >=hdi_prob * For integer/bool HDI, default to bin width of 1 * Split continuous and discrete multimodal HDI * Default to ISJ bandwidth for multimodal HDI * Return highest probability modes * Fix bugs in circular KDE * Support circular continuous multimodal HDI * Assume input probabilities sum to 1 * Merge lines * Scale KDE density to bin probabilities * Use bins returned by `_histogram` * Avoid duplication of HDI defaults * Fix and test passing bins to discrete multimodal * Simplify HDI nearest code * Fix circular standardization * Correctly compute bin centers * Fix pylint issues * Move interval splitting to own function * Use circular standardization * Add method for computing HDI from point densities * Add multimodal_nearest HDI method * rename and add check for warning in tests --------- Co-authored-by: Oriol (ProDesk) <[email protected]>
1 parent 5b110df commit c5a4f5f

File tree

4 files changed

+255
-76
lines changed

4 files changed

+255
-76
lines changed

src/arviz_stats/base/array.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,24 @@ def hdi(
4747
circular=False,
4848
max_modes=10,
4949
skipna=False,
50+
**kwargs,
5051
):
51-
"""Compute of HDI function on array-like input."""
52+
"""Compute HDI function on array-like input."""
5253
if not 1 >= prob > 0:
5354
raise ValueError("The value of `prob` must be in the (0, 1] interval.")
54-
if method == "multimodal" and circular:
55-
raise ValueError("Multimodal hdi not supported for circular data.")
5655
ary, axes = process_ary_axes(ary, axes)
56+
is_discrete = np.issubdtype(ary.dtype, np.integer) or np.issubdtype(ary.dtype, np.bool_)
57+
is_multimodal = method.startswith("multimodal")
58+
if is_multimodal and circular and is_discrete:
59+
raise ValueError("Multimodal hdi not supported for discrete circular data.")
5760
hdi_func = {
5861
"nearest": self._hdi_nearest,
59-
"multimodal": self._hdi_multimodal,
62+
"multimodal": (
63+
self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
64+
),
65+
"multimodal_sample": (
66+
self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
67+
),
6068
}[method]
6169
hdi_array = make_ufunc(
6270
hdi_func,
@@ -67,15 +75,23 @@ def hdi(
6775
func_kwargs = {
6876
"prob": prob,
6977
"skipna": skipna,
70-
"out_shape": (max_modes, 2) if method == "multimodal" else (2,),
78+
"out_shape": (max_modes, 2) if is_multimodal else (2,),
79+
"circular": circular,
7180
}
72-
if method != "multimodal":
73-
func_kwargs["circular"] = circular
74-
else:
81+
if is_multimodal:
7582
func_kwargs["max_modes"] = max_modes
83+
if is_discrete:
84+
func_kwargs.pop("circular")
85+
func_kwargs.pop("skipna")
86+
else:
87+
func_kwargs["bw"] = "isj" if not circular else "taylor"
88+
func_kwargs.update(kwargs)
89+
90+
if method == "multimodal_sample":
91+
func_kwargs["from_sample"] = True
7692

7793
result = hdi_array(ary, **func_kwargs)
78-
if method == "multimodal":
94+
if is_multimodal:
7995
mode_mask = [np.all(np.isnan(result[..., i, :])) for i in range(result.shape[-2])]
8096
result = result[..., ~np.array(mode_mask), :]
8197
return result

src/arviz_stats/base/core.py

Lines changed: 133 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def circular_mean(self, ary): # pylint: disable=no-self-use
8686
"""
8787
return circmean(ary, high=np.pi, low=-np.pi)
8888

89+
def _circular_standardize(self, ary): # pylint: disable=no-self-use
90+
"""Standardize circular data to the interval [-pi, pi]."""
91+
return np.mod(ary + np.pi, 2 * np.pi) - np.pi
92+
8993
def quantile(self, ary, quantile, **kwargs): # pylint: disable=no-self-use
9094
"""Compute the quantile of an array of samples.
9195
@@ -226,20 +230,9 @@ def _histogram(self, ary, bins=None, range=None, weights=None, density=None):
226230
bins = self._get_bins(ary)
227231
return np.histogram(ary, bins=bins, range=range, weights=weights, density=density)
228232

229-
def _hdi_linear_nearest_common(self, ary, prob, skipna, circular):
230-
ary = ary.flatten()
231-
if skipna:
232-
nans = np.isnan(ary)
233-
if not nans.all():
234-
ary = ary[~nans]
233+
def _hdi_linear_nearest_common(self, ary, prob): # pylint: disable=no-self-use
235234
n = len(ary)
236235

237-
mean = None
238-
if circular:
239-
mean = self.circular_mean(ary)
240-
ary = ary - mean
241-
ary = np.arctan2(np.sin(ary), np.cos(ary))
242-
243236
ary = np.sort(ary)
244237
interval_idx_inc = int(np.floor(prob * n))
245238
n_intervals = n - interval_idx_inc
@@ -249,62 +242,147 @@ def _hdi_linear_nearest_common(self, ary, prob, skipna, circular):
249242
raise ValueError("Too few elements for interval calculation. ")
250243

251244
min_idx = np.argmin(interval_width)
245+
hdi_interval = ary[[min_idx, min_idx + interval_idx_inc]]
252246

253-
return ary, mean, min_idx, interval_idx_inc
247+
return hdi_interval
254248

255249
def _hdi_nearest(self, ary, prob, circular, skipna):
256250
"""Compute HDI over the flattened array as closest samples that contain the given prob."""
257-
ary, mean, min_idx, interval_idx_inc = self._hdi_linear_nearest_common(
258-
ary, prob, skipna, circular
259-
)
260-
261-
hdi_min = ary[min_idx]
262-
hdi_max = ary[min_idx + interval_idx_inc]
251+
ary = ary.flatten()
252+
if skipna:
253+
nans = np.isnan(ary)
254+
if not nans.all():
255+
ary = ary[~nans]
263256

264257
if circular:
265-
hdi_min = hdi_min + mean
266-
hdi_max = hdi_max + mean
267-
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
268-
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
258+
mean = self.circular_mean(ary)
259+
ary = self._circular_standardize(ary - mean)
260+
261+
hdi_interval = self._hdi_linear_nearest_common(ary, prob)
269262

270-
hdi_interval = np.array([hdi_min, hdi_max])
263+
if circular:
264+
hdi_interval = self._circular_standardize(hdi_interval + mean)
271265

272266
return hdi_interval
273267

274-
def _hdi_multimodal(self, ary, prob, skipna, max_modes):
268+
def _hdi_multimodal_continuous(
269+
self, ary, prob, skipna, max_modes, circular, from_sample=False, **kwargs
270+
):
275271
"""Compute HDI if the distribution is multimodal."""
276272
ary = ary.flatten()
277273
if skipna:
278274
ary = ary[~np.isnan(ary)]
279275

280-
if ary.dtype.kind == "f":
281-
bins, density, _ = self.kde(ary)
282-
lower, upper = bins[0], bins[-1]
283-
range_x = upper - lower
284-
dx = range_x / len(density)
276+
bins, density, _ = self.kde(ary, circular=circular, **kwargs)
277+
if from_sample:
278+
ary_density = np.interp(ary, bins, density)
279+
hdi_intervals, interval_probs = self._hdi_from_point_densities(
280+
ary, ary_density, prob, circular
281+
)
285282
else:
286-
bins = self._get_bins(ary)
287-
density, _ = self._histogram(ary, bins=bins, density=True)
288-
dx = np.diff(bins)[0]
289-
290-
density *= dx
291-
292-
idx = np.argsort(-density)
293-
intervals = bins[idx][density[idx].cumsum() <= prob]
294-
intervals.sort()
295-
296-
intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
297-
298-
hdi_intervals = np.full((max_modes, 2), np.nan)
299-
for i, interval in enumerate(intervals_splitted):
300-
if i == max_modes:
301-
warnings.warn(
302-
f"found more modes than {max_modes}, returning only the first {max_modes} modes"
303-
)
304-
break
305-
if interval.size == 0:
306-
hdi_intervals[i] = np.asarray([bins[0], bins[0]])
307-
else:
308-
hdi_intervals[i] = np.asarray([interval[0], interval[-1]])
309-
310-
return np.array(hdi_intervals)
283+
dx = (bins[-1] - bins[0]) / (len(bins) - 1)
284+
bin_probs = density * dx
285+
286+
hdi_intervals, interval_probs = self._hdi_from_bin_probabilities(
287+
bins, bin_probs, prob, circular, dx
288+
)
289+
290+
return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes)
291+
292+
def _hdi_multimodal_discrete(self, ary, prob, max_modes, bins=None):
293+
"""Compute HDI if the distribution is multimodal."""
294+
ary = ary.flatten()
295+
296+
if bins is None:
297+
bins, counts = np.unique(ary, return_counts=True)
298+
bin_probs = counts / len(ary)
299+
dx = 1
300+
else:
301+
counts, edges = self._histogram(ary, bins=bins)
302+
bins = 0.5 * (edges[1:] + edges[:-1])
303+
bin_probs = counts / counts.sum()
304+
dx = bins[1] - bins[0]
305+
306+
hdi_intervals, interval_probs = self._hdi_from_bin_probabilities(
307+
bins, bin_probs, prob, False, dx
308+
)
309+
310+
return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes)
311+
312+
def _hdi_from_point_densities(self, points, densities, prob, circular):
313+
if circular:
314+
points = self._circular_standardize(points)
315+
316+
sorted_idx = np.argsort(points)
317+
points = points[sorted_idx]
318+
densities = densities[sorted_idx]
319+
320+
# find idx of points in the interval
321+
interval_size = int(np.ceil(prob * len(points)))
322+
sorted_idx = np.argsort(densities)[::-1]
323+
idx_in_interval = sorted_idx[:interval_size]
324+
idx_in_interval.sort()
325+
326+
# find idx of interval bounds
327+
probs_in_interval = np.full(idx_in_interval.shape, 1 / len(points))
328+
interval_bounds_idx, interval_probs = self._interval_points_to_bounds(
329+
idx_in_interval, probs_in_interval, 1, circular, period=len(points)
330+
)
331+
332+
return points[interval_bounds_idx], interval_probs
333+
334+
def _hdi_from_bin_probabilities(self, bins, bin_probs, prob, circular, dx):
335+
if circular:
336+
bins = self._circular_standardize(bins)
337+
sorted_idx = np.argsort(bins)
338+
bins = bins[sorted_idx]
339+
bin_probs = bin_probs[sorted_idx]
340+
341+
# find idx of bins in the interval
342+
sorted_idx = np.argsort(bin_probs)[::-1]
343+
cum_probs = bin_probs[sorted_idx].cumsum()
344+
interval_size = np.searchsorted(cum_probs, prob, side="left") + 1
345+
idx_in_interval = sorted_idx[:interval_size]
346+
idx_in_interval.sort()
347+
348+
# get points in intervals
349+
intervals = bins[idx_in_interval]
350+
probs_in_interval = bin_probs[idx_in_interval]
351+
352+
return self._interval_points_to_bounds(intervals, probs_in_interval, dx, circular)
353+
354+
def _interval_points_to_bounds(self, points, probs, dx, circular, period=2 * np.pi): # pylint: disable=no-self-use
355+
cum_probs = probs.cumsum()
356+
357+
is_bound = np.diff(points) > dx * 1.01
358+
is_lower_bound = np.insert(is_bound, 0, True)
359+
is_upper_bound = np.append(is_bound, True)
360+
interval_bounds = np.column_stack([points[is_lower_bound], points[is_upper_bound]])
361+
interval_probs = (
362+
cum_probs[is_upper_bound] - cum_probs[is_lower_bound] + probs[is_lower_bound]
363+
)
364+
365+
if (
366+
circular
367+
and np.mod(dx * 1.01 + interval_bounds[-1, -1] - interval_bounds[0, 0], period)
368+
<= dx * 1.01
369+
):
370+
interval_bounds[-1, 1] = interval_bounds[0, 1]
371+
interval_bounds = interval_bounds[1:, :]
372+
interval_probs[-1] += interval_probs[0]
373+
interval_probs = interval_probs[1:]
374+
375+
return interval_bounds, interval_probs
376+
377+
def _pad_hdi_to_maxmodes(self, hdi_intervals, interval_probs, max_modes): # pylint: disable=no-self-use
378+
if hdi_intervals.shape[0] > max_modes:
379+
warnings.warn(
380+
f"found more modes than {max_modes}, returning only the {max_modes} highest "
381+
"probability modes"
382+
)
383+
hdi_intervals = hdi_intervals[np.argsort(interval_probs)[::-1][:max_modes], :]
384+
elif hdi_intervals.shape[0] < max_modes:
385+
hdi_intervals = np.vstack(
386+
[hdi_intervals, np.full((max_modes - hdi_intervals.shape[0], 2), np.nan)]
387+
)
388+
return hdi_intervals

src/arviz_stats/base/dataarray.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def eti(self, da, prob=None, dims=None, method="linear"):
3434
kwargs={"axis": np.arange(-len(dims), 0, 1), "method": method},
3535
)
3636

37-
def hdi(
38-
self, da, prob=None, dims=None, method="nearest", circular=False, max_modes=10, skipna=False
39-
):
37+
def hdi(self, da, prob=None, dims=None, method="nearest", **kwargs):
4038
"""Compute hdi on DataArray input."""
4139
dims = validate_dims(dims)
4240
prob = validate_ci_prob(prob)
@@ -48,13 +46,11 @@ def hdi(
4846
da,
4947
prob,
5048
input_core_dims=[dims, []],
51-
output_core_dims=[[mode_dim, "hdi"] if method == "multimodal" else ["hdi"]],
49+
output_core_dims=[[mode_dim, "hdi"] if method.startswith("multimodal") else ["hdi"]],
5250
kwargs={
53-
"method": method,
54-
"circular": circular,
55-
"skipna": skipna,
56-
"max_modes": max_modes,
5751
"axes": np.arange(-len(dims), 0, 1),
52+
"method": method,
53+
**kwargs,
5854
},
5955
).assign_coords({"hdi": hdi_coord})
6056

0 commit comments

Comments
 (0)