Skip to content

Commit 1a499b6

Browse files
Add more arguments to psense_summary (#34)
* add more arguments to psense_summary
* rST and api formatting
* use alphas instead of delta
* fix test due to small numerical difference in lower_alpha computation
* fix bug delta

---------

Co-authored-by: Oriol Abril-Pla <[email protected]>
1 parent c5a4f5f commit 1a499b6

File tree

8 files changed

+128
-40
lines changed

8 files changed

+128
-40
lines changed

docs/source/api/index.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# API reference
22

3+
## Functions
4+
5+
```{eval-rst}
6+
.. autosummary::
7+
:toctree: generated/
8+
9+
arviz_stats.psense
10+
arviz_stats.psense_summary
11+
```
12+
313
## Accessors
414
Currently, using accessors is the recommended way to call functions from `arviz_stats`.
515

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
"sphinx_design",
4141
"jupyter_sphinx",
4242
"sphinx_autosummary_accessors",
43+
"IPython.sphinxext.ipython_directive",
44+
"IPython.sphinxext.ipython_console_highlighting",
4345
]
4446

4547
templates_path = ["_templates", sphinx_autosummary_accessors.templates_path]

src/arviz_stats/accessors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,17 @@ def power_scale_lw(self, alpha=1, dims=None):
9696
"""Compute log weights for power-scaling of the DataTree."""
9797
return get_function("power_scale_lw")(self._obj, alpha=alpha, dims=dims)
9898

99-
def power_scale_sense(self, lower_w=None, upper_w=None, delta=None, dims=None):
99+
def power_scale_sense(
100+
self, lower_w=None, upper_w=None, lower_alpha=None, upper_alpha=None, dims=None
101+
):
100102
"""Compute power-scaling sensitivity."""
101103
return get_function("power_scale_sense")(
102-
self._obj, lower_w=lower_w, upper_w=upper_w, delta=delta, dims=dims
104+
self._obj,
105+
lower_w=lower_w,
106+
upper_w=upper_w,
107+
lower_alpha=lower_alpha,
108+
upper_alpha=upper_alpha,
109+
dims=dims,
103110
)
104111

105112

src/arviz_stats/base/array.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def power_scale_lw(self, ary, alpha=0, axes=-1):
168168
)
169169
return psl_ufunc(ary, out_shape=(ary.shape[i] for i in axes), alpha=alpha)
170170

171-
def power_scale_sense(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
171+
def power_scale_sense(
172+
self, ary, lower_w, upper_w, lower_alpha, upper_alpha, chain_axis=-2, draw_axis=-1
173+
):
172174
"""Compute power-scaling sensitivity."""
173175
if chain_axis is None:
174176
ary = np.expand_dims(ary, axis=0)
@@ -181,7 +183,7 @@ def power_scale_sense(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_ax
181183
pss_array = make_ufunc(
182184
self._power_scale_sense, n_output=1, n_input=3, n_dims=2, ravel=False
183185
)
184-
return pss_array(ary, lower_w, upper_w, delta=delta)
186+
return pss_array(ary, lower_w, upper_w, lower_alpha=lower_alpha, upper_alpha=upper_alpha)
185187

186188
def compute_ranks(self, ary, axes=-1, relative=False):
187189
"""Compute ranks of MCMC samples."""

src/arviz_stats/base/dataarray.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,15 @@ def power_scale_lw(self, da, alpha=0, dims=None):
299299
kwargs={"axes": np.arange(-len(dims), 0, 1)},
300300
)
301301

302-
def power_scale_sense(self, da, lower_w, upper_w, delta, dims=None):
302+
def power_scale_sense(self, da, lower_w, upper_w, lower_alpha, upper_alpha, dims=None):
303303
"""Compute power-scaling sensitivity."""
304304
dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(dims)
305305
return apply_ufunc(
306306
self.array_class.power_scale_sense,
307307
*broadcast(da, lower_w, upper_w),
308-
delta,
309-
input_core_dims=[dims, dims, dims, []],
308+
lower_alpha,
309+
upper_alpha,
310+
input_core_dims=[dims, dims, dims, [], []],
310311
output_core_dims=[[]],
311312
kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
312313
)

src/arviz_stats/base/diagnostics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,15 +551,16 @@ def _gpinv(probs, kappa, sigma, mu):
551551

552552
return q
553553

554-
def _power_scale_sense(self, ary, lower_w, upper_w, delta=0.01):
554+
def _power_scale_sense(self, ary, lower_w, upper_w, lower_alpha, upper_alpha):
555555
"""Compute power-scaling sensitivity by finite difference second derivative of CJS."""
556556
ary = np.ravel(ary)
557557
lower_w = np.ravel(lower_w)
558558
upper_w = np.ravel(upper_w)
559559
lower_cjs = max(self._cjs_dist(ary, lower_w), self._cjs_dist(-1 * ary, lower_w))
560560
upper_cjs = max(self._cjs_dist(ary, upper_w), self._cjs_dist(-1 * ary, upper_w))
561-
grad = (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))
562-
return grad
561+
lower_grad = -1 * lower_cjs / np.log2(lower_alpha)
562+
upper_grad = upper_cjs / np.log2(upper_alpha)
563+
return (lower_grad + upper_grad) / 2
563564

564565
def _power_scale_lw(self, ary, alpha):
565566
"""Compute log weights for power-scaling component by alpha."""

src/arviz_stats/psense.py

Lines changed: 92 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
def psense(
2121
dt,
22+
var_names=None,
23+
filter_vars=None,
2224
group="prior",
25+
coords=None,
2326
sample_dims=None,
27+
alphas=(0.99, 1.01),
2428
group_var_names=None,
2529
group_coords=None,
26-
var_names=None,
27-
coords=None,
28-
filter_vars=None,
29-
delta=0.01,
3030
):
3131
"""
3232
Compute power-scaling sensitivity values.
@@ -38,28 +38,30 @@ def psense(
3838
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
3939
For ndarray: shape = (chain, draw).
4040
For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
41+
var_names : list of str, optional
42+
Names of posterior variables to include in the power scaling sensitivity diagnostic
43+
filter_vars: {None, "like", "regex"}, default None
44+
Used for `var_names` only.
45+
If ``None`` (default), interpret var_names as the real variable names.
46+
If "like", interpret var_names as substrings of the real variable names.
47+
If "regex", interpret var_names as regular expressions on the real variable names.
4148
group : {"prior", "likelihood"}, default "prior"
4249
If "likelihood", the pointwise log likelihood values are retrieved
4350
from the ``log_likelihood`` group and added together.
4451
If "prior", the log prior values are retrieved from the ``log_prior`` group.
52+
coords : dict, optional
53+
Coordinates defining a subset over the posterior. Only these variables will
54+
be used when computing the prior sensitivity.
55+
sample_dims : str or sequence of hashable, optional
56+
Dimensions to treat as sample dimensions; these are reduced when computing the diagnostic.
57+
Defaults to ``rcParams["data.sample_dims"]``
58+
alphas : tuple
59+
Lower and upper alpha values for gradient calculation. Defaults to (0.99, 1.01).
4560
group_var_names : str, optional
4661
Name of the prior or log likelihood variables to use
4762
group_coords : dict, optional
4863
Coordinates defining a subset over the group element for which to
4964
compute the prior sensitivity diagnostic.
50-
var_names : list of str, optional
51-
Names of posterior variables to include in the power scaling sensitivity diagnostic
52-
coords : dict, optional
53-
Coordinates defining a subset over the posterior. Only these variables will
54-
be used when computing the prior sensitivity.
55-
filter_vars: {None, "like", "regex"}, default None
56-
Used for `var_names` only.
57-
If ``None`` (default), interpret var_names as the real variables names.
58-
If "like", interpret var_names as substrings of the real variables names.
59-
If "regex", interpret var_names as regular expressions on the real variables names.
60-
delta : float
61-
Value for finite difference derivative calculation.
62-
6365
6466
Returns
6567
-------
@@ -78,20 +80,22 @@ def psense(
7880
References
7981
----------
8082
.. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
81-
power-scaling*, 2022, https://arxiv.org/abs/2107.14054
83+
power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
8284
"""
8385
dataset = extract(
84-
dt, var_names=var_names, filter_vars=filter_vars, group="posterior", combined=False
86+
dt,
87+
var_names=var_names,
88+
filter_vars=filter_vars,
89+
group="posterior",
90+
combined=False,
91+
keep_dataset=True,
8592
)
8693
if coords is not None:
8794
dataset = dataset.sel(coords)
8895

89-
lower_alpha = 1 / (1 + delta)
90-
upper_alpha = 1 + delta
91-
9296
lower_w, upper_w = _get_power_scale_weights(
9397
dt,
94-
alphas=(lower_alpha, upper_alpha),
98+
alphas=alphas,
9599
group=group,
96100
sample_dims=sample_dims,
97101
group_var_names=group_var_names,
@@ -101,20 +105,52 @@ def psense(
101105
return dataset.azstats.power_scale_sense(
102106
lower_w=lower_w,
103107
upper_w=upper_w,
104-
delta=delta,
108+
lower_alpha=alphas[0],
109+
upper_alpha=alphas[1],
105110
dims=sample_dims,
106111
)
107112

108113

109-
def psense_summary(data, threshold=0.05, round_to=3):
114+
def psense_summary(
115+
data,
116+
var_names=None,
117+
filter_vars=None,
118+
coords=None,
119+
sample_dims=None,
120+
threshold=0.05,
121+
alphas=(0.99, 1.01),
122+
group_var_names=None,
123+
group_coords=None,
124+
round_to=3,
125+
):
110126
"""
111127
Compute the prior/likelihood sensitivity based on power-scaling perturbations.
112128
113129
Parameters
114130
----------
115131
data : DataTree
132+
var_names : list of str, optional
133+
Names of posterior variables to include in the power scaling sensitivity diagnostic
134+
filter_vars: {None, "like", "regex"}, default None
135+
Used for `var_names` only.
136+
If ``None`` (default), interpret var_names as the real variable names.
137+
If "like", interpret var_names as substrings of the real variable names.
138+
If "regex", interpret var_names as regular expressions on the real variable names.
139+
coords : dict, optional
140+
Coordinates defining a subset over the posterior. Only these variables will
141+
be used when computing the prior sensitivity.
142+
sample_dims : str or sequence of hashable, optional
143+
Dimensions to treat as sample dimensions; these are reduced when computing the diagnostic.
144+
Defaults to ``rcParams["data.sample_dims"]``
116145
threshold : float, optional
117146
Threshold value to determine the sensitivity diagnosis. Default is 0.05.
147+
alphas : tuple
148+
Lower and upper alpha values for gradient calculation. Defaults to (0.99, 1.01).
149+
group_var_names : str, optional
150+
Name of the prior or log likelihood variables to use
151+
group_coords : dict, optional
152+
Coordinates defining a subset over the group element for which to
153+
compute the prior sensitivity diagnostic
118154
round_to : int, optional
119155
Number of decimal places to round the sensitivity values. Default is 3.
120156
@@ -127,9 +163,38 @@ def psense_summary(data, threshold=0.05, round_to=3):
127163
- "strong prior / weak likelihood" if the prior sensitivity is above threshold
128164
and the likelihood sensitivity is below the threshold
129165
- "-" otherwise
166+
167+
Examples
168+
--------
169+
.. ipython::
170+
171+
In [1]: from arviz_base import load_arviz_data
172+
...: from arviz_stats import psense_summary
173+
...: rugby = load_arviz_data("rugby")
174+
...: psense_summary(rugby, var_names="atts")
130175
"""
131-
pssdp = psense(data, group="prior")
132-
pssdl = psense(data, group="likelihood")
176+
pssdp = psense(
177+
data,
178+
var_names=var_names,
179+
filter_vars=filter_vars,
180+
group="prior",
181+
sample_dims=sample_dims,
182+
coords=coords,
183+
alphas=alphas,
184+
group_var_names=group_var_names,
185+
group_coords=group_coords,
186+
)
187+
pssdl = psense(
188+
data,
189+
var_names=var_names,
190+
filter_vars=filter_vars,
191+
group="likelihood",
192+
coords=coords,
193+
sample_dims=sample_dims,
194+
alphas=alphas,
195+
group_var_names=group_var_names,
196+
group_coords=group_coords,
197+
)
133198

134199
joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
135200
component=["prior", "likelihood"]

tests/test_psense.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def test_psense_var_names():
1919
result_0 = psense(uni_dt, group="prior", group_var_names=["mu"], var_names=["mu"])
2020
result_1 = psense(uni_dt, group="prior", var_names=["mu"])
2121
for result in (result_0, result_1):
22-
assert "sigma" != result.name
23-
assert "mu" == result.name
24-
assert not isclose(result_0, result_1)
22+
assert "sigma" not in result.data_vars
23+
assert "mu" in result.data_vars
24+
assert not isclose(result_0["mu"], result_1["mu"])
2525

2626

2727
def test_psense_summary():

0 commit comments

Comments
 (0)