Skip to content

Commit 9169807

Browse files
Add plot_interactive to Dirichlet (#332)
* Add xy_lim to plot_dirichlet * Added plot_interactive to Dirichlet * Tests for plot_interactive * Change x to y, and renamed dim to alpha
1 parent 1a102fa commit 9169807

File tree

3 files changed

+151
-6
lines changed

3 files changed

+151
-6
lines changed

preliz/distributions/continuous_multivariate.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,18 @@
88
from copy import copy
99

1010
import numpy as np
11+
12+
try:
13+
from ipywidgets import interactive, widgets
14+
except ImportError:
15+
pass
1116
from scipy import stats
1217

1318
from .distributions_multivariate import Continuous
1419
from .continuous import Beta, Normal
1520
from ..internal.distribution_helper import all_not_none
1621
from ..internal.plot_helper_multivariate import plot_dirichlet, plot_mvnormal
17-
22+
from ..internal.plot_helper import check_inside_notebook, get_slider
1823

1924
eps = np.finfo(float).eps
2025

@@ -63,7 +68,7 @@ def __init__(self, alpha=None):
6368
self._parametrization(alpha)
6469

6570
def _parametrization(self, alpha=None):
66-
self.param_names = "alpha"
71+
self.param_names = ("alpha",)
6772
self.params_support = ((eps, np.inf),)
6873

6974
self.alpha = alpha
@@ -205,6 +210,108 @@ def plot_ppf(
205210
self, "ppf", "marginals", pointinterval, interval, levels, None, figsize, ax
206211
)
207212

213+
def plot_interactive(
214+
self,
215+
kind="pdf",
216+
xy_lim="both",
217+
pointinterval=True,
218+
interval="hdi",
219+
levels=None,
220+
figsize=None,
221+
):
222+
"""
223+
Interactive exploration of parameters
224+
225+
Parameters
226+
----------
227+
kind : str:
228+
Type of plot. Available options are `pdf`, `cdf` and `ppf`.
229+
xy_lim : str or tuple
230+
Set the limits of the x-axis and/or y-axis.
231+
Defaults to `"both"`, the limits of both axes are fixed for all subplots.
232+
Use `"auto"` for automatic rescaling of x-axis and y-axis.
233+
Or set them manually by passing a tuple of 4 elements,
234+
the first two for x-axis, the last two for y-axis. The tuple can have `None`.
235+
pointinterval : bool
236+
Whether to include a plot of the quantiles. Defaults to False.
237+
If `True` the default is to plot the median and two inter-quantiles ranges.
238+
interval : str
239+
Type of interval. Available options are the highest density interval `"hdi"` (default),
240+
equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
241+
levels : list
242+
Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
243+
For quantiles the number of elements should be 5, 3, 1 or 0
244+
(in this last case nothing will be plotted).
245+
figsize : tuple
246+
Size of the figure
247+
"""
248+
249+
check_inside_notebook()
250+
251+
args = dict(zip(self.param_names, self.params))
252+
self.__init__(**args) # pylint: disable=unnecessary-dunder-call
253+
if kind == "pdf":
254+
w_checkbox_marginals = widgets.Checkbox(
255+
value=True,
256+
description="marginals",
257+
disabled=False,
258+
indent=False,
259+
)
260+
plot_widgets = {"marginals": w_checkbox_marginals}
261+
else:
262+
plot_widgets = {}
263+
for index, dim in enumerate(self.params[0]):
264+
plot_widgets[f"alpha-{index + 1}"] = get_slider(
265+
f"alpha-{index + 1}", dim, *self.params_support[0]
266+
)
267+
268+
def plot(**args):
269+
if kind == "pdf":
270+
marginals = args.pop("marginals")
271+
params = {"alpha": np.asarray(list(args.values()), dtype=float)}
272+
self.__init__(**params) # pylint: disable=unnecessary-dunder-call
273+
if kind == "pdf":
274+
plot_dirichlet(
275+
self,
276+
"pdf",
277+
marginals,
278+
pointinterval,
279+
interval,
280+
levels,
281+
"full",
282+
figsize,
283+
None,
284+
xy_lim,
285+
)
286+
elif kind == "cdf":
287+
plot_dirichlet(
288+
self,
289+
"cdf",
290+
"marginals",
291+
pointinterval,
292+
interval,
293+
levels,
294+
"full",
295+
figsize,
296+
None,
297+
xy_lim,
298+
)
299+
elif kind == "ppf":
300+
plot_dirichlet(
301+
self,
302+
"cdf",
303+
"marginals",
304+
pointinterval,
305+
interval,
306+
levels,
307+
None,
308+
figsize,
309+
None,
310+
xy_lim,
311+
)
312+
313+
return interactive(plot, **plot_widgets)
314+
208315

209316
class MvNormal(Continuous):
210317
r"""

preliz/internal/plot_helper_multivariate.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def plot_dirichlet(
9595
support,
9696
figsize,
9797
axes,
98+
xy_lim="auto",
9899
):
99100
"""Plot pdf, cdf or ppf of Dirichlet distribution."""
100101

@@ -104,6 +105,10 @@ def plot_dirichlet(
104105
if figsize is None:
105106
figsize = (12, 4)
106107

108+
if isinstance(xy_lim, tuple):
109+
xlim = xy_lim[:2]
110+
ylim = xy_lim[2:]
111+
107112
if marginals:
108113
a_0 = alpha.sum()
109114
cols, rows = get_cols_rows(dim)
@@ -116,8 +121,18 @@ def plot_dirichlet(
116121
ax.remove()
117122

118123
for a_i, ax in zip(alpha, axes):
124+
marginal_dist = dist.marginal(a_i, a_0 - a_i)
125+
if xy_lim == "both":
126+
xlim = marginal_dist._finite_endpoints("full")
127+
xvals = marginal_dist.xvals("restricted")
128+
if representation == "pdf":
129+
max_pdf = np.max(marginal_dist.pdf(xvals))
130+
ylim = (-max_pdf * 0.075, max_pdf * 1.5)
131+
elif representation == "ppf":
132+
max_ppf = marginal_dist.ppf(0.999)
133+
ylim = (-max_ppf * 0.075, max_ppf * 1.5)
119134
if representation == "pdf":
120-
dist.marginal(a_i, a_0 - a_i).plot_pdf(
135+
marginal_dist.plot_pdf(
121136
pointinterval=pointinterval,
122137
interval=interval,
123138
levels=levels,
@@ -126,7 +141,7 @@ def plot_dirichlet(
126141
ax=ax,
127142
)
128143
elif representation == "cdf":
129-
dist.marginal(a_i, a_0 - a_i).plot_cdf(
144+
marginal_dist.plot_cdf(
130145
pointinterval=pointinterval,
131146
interval=interval,
132147
levels=levels,
@@ -135,14 +150,17 @@ def plot_dirichlet(
135150
ax=ax,
136151
)
137152
elif representation == "ppf":
138-
dist.marginal(a_i, a_0 - a_i).plot_ppf(
153+
marginal_dist.plot_ppf(
139154
pointinterval=pointinterval,
140155
interval=interval,
141156
levels=levels,
142157
legend=False,
143158
ax=ax,
144159
)
145-
160+
if xy_lim != "auto" and representation != "ppf":
161+
ax.set_xlim(*xlim)
162+
if xy_lim != "auto" and representation != "cdf":
163+
ax.set_ylim(*ylim)
146164
fig.text(0.5, 1, repr_to_matplotlib(dist), ha="center", va="center")
147165

148166
else:

preliz/tests/test_plots.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,26 @@ def test_dirichlet_plot(kwargs):
6969
a_dist.plot_ppf(**kwargs)
7070

7171

72+
@pytest.mark.parametrize(
73+
"kwargs",
74+
[
75+
{},
76+
{"xy_lim": "auto"},
77+
{"pointinterval": True, "xy_lim": "auto"},
78+
{"pointinterval": True, "levels": [0.1, 0.9], "xy_lim": "both"},
79+
{"pointinterval": True, "interval": "eti", "levels": [0.9], "xy_lim": (0.3, 0.9, 0.6, 1)},
80+
{"pointinterval": True, "interval": "quantiles", "xy_lim": "both"},
81+
{"pointinterval": True, "interval": "quantiles", "levels": [0.1, 0.5, 0.9]},
82+
{"pointinterval": False, "figsize": (4, 4)},
83+
],
84+
)
85+
def test_plot_interactive_dirichlet(kwargs):
86+
a_dist = pz.Dirichlet([2, 1, 2])
87+
a_dist.plot_interactive(kind="pdf", **kwargs)
88+
a_dist.plot_interactive(kind="cdf", **kwargs)
89+
a_dist.plot_interactive(kind="ppf", **kwargs)
90+
91+
7292
@pytest.mark.parametrize(
7393
"kwargs",
7494
[

0 commit comments

Comments
 (0)