|
8 | 8 | from copy import copy |
9 | 9 |
|
10 | 10 | import numpy as np |
| 11 | + |
| 12 | +try: |
| 13 | + from ipywidgets import interactive, widgets |
| 14 | +except ImportError: |
| 15 | + pass |
11 | 16 | from scipy import stats |
12 | 17 |
|
13 | 18 | from .distributions_multivariate import Continuous |
14 | 19 | from .continuous import Beta, Normal |
15 | 20 | from ..internal.distribution_helper import all_not_none |
16 | 21 | from ..internal.plot_helper_multivariate import plot_dirichlet, plot_mvnormal |
17 | | - |
| 22 | +from ..internal.plot_helper import check_inside_notebook, get_slider |
18 | 23 |
|
19 | 24 | eps = np.finfo(float).eps |
20 | 25 |
|
@@ -63,7 +68,7 @@ def __init__(self, alpha=None): |
63 | 68 | self._parametrization(alpha) |
64 | 69 |
|
65 | 70 | def _parametrization(self, alpha=None): |
66 | | - self.param_names = "alpha" |
| 71 | + self.param_names = ("alpha",) |
67 | 72 | self.params_support = ((eps, np.inf),) |
68 | 73 |
|
69 | 74 | self.alpha = alpha |
@@ -205,6 +210,108 @@ def plot_ppf( |
205 | 210 | self, "ppf", "marginals", pointinterval, interval, levels, None, figsize, ax |
206 | 211 | ) |
207 | 212 |
|
| 213 | + def plot_interactive( |
| 214 | + self, |
| 215 | + kind="pdf", |
| 216 | + xy_lim="both", |
| 217 | + pointinterval=True, |
| 218 | + interval="hdi", |
| 219 | + levels=None, |
| 220 | + figsize=None, |
| 221 | + ): |
| 222 | + """ |
| 223 | + Interactive exploration of parameters |
| 224 | +
|
| 225 | + Parameters |
| 226 | + ---------- |
| 227 | + kind : str: |
| 228 | + Type of plot. Available options are `pdf`, `cdf` and `ppf`. |
| 229 | + xy_lim : str or tuple |
| 230 | + Set the limits of the x-axis and/or y-axis. |
| 231 | + Defaults to `"both"`, the limits of both axes are fixed for all subplots. |
| 232 | + Use `"auto"` for automatic rescaling of x-axis and y-axis. |
| 233 | + Or set them manually by passing a tuple of 4 elements, |
| 234 | + the first two for x-axis, the last two for y-axis. The tuple can have `None`. |
| 235 | + pointinterval : bool |
| 236 | + Whether to include a plot of the quantiles. Defaults to False. |
| 237 | + If `True` the default is to plot the median and two inter-quantiles ranges. |
| 238 | + interval : str |
| 239 | + Type of interval. Available options are the highest density interval `"hdi"` (default), |
| 240 | + equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`. |
| 241 | + levels : list |
| 242 | + Mass of the intervals. For hdi or eti the number of elements should be 2 or 1. |
| 243 | + For quantiles the number of elements should be 5, 3, 1 or 0 |
| 244 | + (in this last case nothing will be plotted). |
| 245 | + figsize : tuple |
| 246 | + Size of the figure |
| 247 | + """ |
| 248 | + |
| 249 | + check_inside_notebook() |
| 250 | + |
| 251 | + args = dict(zip(self.param_names, self.params)) |
| 252 | + self.__init__(**args) # pylint: disable=unnecessary-dunder-call |
| 253 | + if kind == "pdf": |
| 254 | + w_checkbox_marginals = widgets.Checkbox( |
| 255 | + value=True, |
| 256 | + description="marginals", |
| 257 | + disabled=False, |
| 258 | + indent=False, |
| 259 | + ) |
| 260 | + plot_widgets = {"marginals": w_checkbox_marginals} |
| 261 | + else: |
| 262 | + plot_widgets = {} |
| 263 | + for index, dim in enumerate(self.params[0]): |
| 264 | + plot_widgets[f"alpha-{index + 1}"] = get_slider( |
| 265 | + f"alpha-{index + 1}", dim, *self.params_support[0] |
| 266 | + ) |
| 267 | + |
| 268 | + def plot(**args): |
| 269 | + if kind == "pdf": |
| 270 | + marginals = args.pop("marginals") |
| 271 | + params = {"alpha": np.asarray(list(args.values()), dtype=float)} |
| 272 | + self.__init__(**params) # pylint: disable=unnecessary-dunder-call |
| 273 | + if kind == "pdf": |
| 274 | + plot_dirichlet( |
| 275 | + self, |
| 276 | + "pdf", |
| 277 | + marginals, |
| 278 | + pointinterval, |
| 279 | + interval, |
| 280 | + levels, |
| 281 | + "full", |
| 282 | + figsize, |
| 283 | + None, |
| 284 | + xy_lim, |
| 285 | + ) |
| 286 | + elif kind == "cdf": |
| 287 | + plot_dirichlet( |
| 288 | + self, |
| 289 | + "cdf", |
| 290 | + "marginals", |
| 291 | + pointinterval, |
| 292 | + interval, |
| 293 | + levels, |
| 294 | + "full", |
| 295 | + figsize, |
| 296 | + None, |
| 297 | + xy_lim, |
| 298 | + ) |
| 299 | + elif kind == "ppf": |
| 300 | + plot_dirichlet( |
| 301 | + self, |
| 302 | + "cdf", |
| 303 | + "marginals", |
| 304 | + pointinterval, |
| 305 | + interval, |
| 306 | + levels, |
| 307 | + None, |
| 308 | + figsize, |
| 309 | + None, |
| 310 | + xy_lim, |
| 311 | + ) |
| 312 | + |
| 313 | + return interactive(plot, **plot_widgets) |
| 314 | + |
208 | 315 |
|
209 | 316 | class MvNormal(Continuous): |
210 | 317 | r""" |
|
0 commit comments