Conversation
Pull request overview
This PR adds a new Energy Score estimator, `akr_kband`, implementing an AKR-based k-band approximation across both the core (backend-agnostic) implementations and the numba gufunc dispatch path, and wires it through the public `sr.es_ensemble` API.
Changes:
- Add `akr_kband` estimator support in core Energy Score implementations (weighted + unweighted).
- Add numba gufunc implementations for `akr_kband` (weighted + unweighted) and register them in estimator maps.
- Update public `sr.es_ensemble` to pass `k` through (currently via `**kwargs`) and add `akr_kband` to the energy test parametrization.
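For orientation, here is the quantity the new unweighted estimator computes, as read from the gufunc in this diff (a sketch of the formula, not notation taken from the package): for an observation $y$ and ensemble members $x_1, \dots, x_M$,

$$
\mathrm{ES}_{\mathrm{akr\_kband}}(y, x_{1:M}) \;=\; \frac{1}{M}\sum_{i=1}^{M}\lVert x_i - y\rVert \;-\; \frac{1}{Mk}\sum_{i=1}^{M}\sum_{j=1}^{k}\lVert x_i - x_{(i+j) \bmod M}\rVert,
$$

i.e. each member is compared only with its next $k$ cyclic neighbours instead of all $M - 1$ other members, reducing the spread term from $O(M^2)$ to $O(Mk)$ distance evaluations. This matches the code's `e_1 / M - 0.5 * 1 / (M * k) * e_2`, where `e_2` already carries a factor of 2 per pair.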
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `tests/test_energy.py` | Adds `akr_kband` to the estimator test matrix and threads `k` through exception checks. |
| `scoringrules/core/energy/_score.py` | Adds `akr_kband` to core unweighted ES dispatch + implementation. |
| `scoringrules/core/energy/_score_w.py` | Adds `akr_kband` to core weighted ES dispatch + implementation. |
| `scoringrules/core/energy/_gufuncs.py` | Adds numba gufunc for unweighted `akr_kband` and registers it. |
| `scoringrules/core/energy/_gufuncs_w.py` | Adds numba gufunc for weighted `akr_kband` and registers it. |
| `scoringrules/_energy.py` | Updates public API to accept extra kwargs and route `k` to the appropriate implementation. |
Comments suppressed due to low confidence (1)
tests/test_energy.py:14
`akr_kband` was added to `ESTIMATORS`, but the test's later "correctness" assertions only cover nrg/fair/akr/akr_circperm. As a result, `akr_kband` gets no numeric regression coverage. Consider extending the correctness checks (and weighted correctness checks) to include an expected value for `akr_kband` (for a fixed `k`, and ideally also a small check that non-default `k` is honored).
```python
ESTIMATORS = ["nrg", "fair", "akr", "akr_circperm", "akr_kband"]


@pytest.mark.parametrize("estimator", ESTIMATORS)
def test_energy_score(estimator, backend):
```
```python
ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
spread_norm = B.norm(fct - fct_shift, -1)
E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
```
_es_ensemble_akr_kband_w weights the spread term by ens_w * ens_w_shift, which makes the weighted score disagree with the unweighted score even for uniform weights (unlike the other estimators in this module). This will break the existing invariant tested in tests/test_energy.py that ens_w=np.ones(...) matches the unweighted result. Consider aligning with _es_ensemble_akr_w / _es_ensemble_akr_circperm_w by weighting the k-band spread term with ens_w only (or otherwise adjust the normalization so uniform weights reproduce the unweighted akr_kband result).
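To make the scale discrepancy concrete, here is a small self-contained NumPy check (variable names are mine, not from the PR) comparing the unweighted k-band spread term against the two weighting choices under uniform weights `w_i = 1/M`:

```python
import numpy as np

rng = np.random.default_rng(0)
M, d, k = 8, 3, 2
fct = rng.normal(size=(M, d))
w = np.full(M, 1.0 / M)  # uniform weights

# Unweighted k-band spread term, as in the core implementation.
spread_unw = 0.0
for j in range(1, k + 1):
    spread_unw += 2 * np.linalg.norm(fct - np.roll(fct, -j, axis=0), axis=-1).sum()
spread_unw /= M * k

# Weighted spread as written in the PR: ens_w * ens_w_shift.
spread_ww = 0.0
for j in range(1, k + 1):
    norms = np.linalg.norm(fct - np.roll(fct, -j, axis=0), axis=-1)
    spread_ww += 2 * (norms * w * np.roll(w, -j)).sum()
spread_ww /= k

# Weighted spread with ens_w only, as the review suggests.
spread_w = 0.0
for j in range(1, k + 1):
    norms = np.linalg.norm(fct - np.roll(fct, -j, axis=0), axis=-1)
    spread_w += 2 * (norms * w).sum()
spread_w /= k

# spread_w matches spread_unw exactly, while spread_ww is smaller
# by a factor of M, so uniform weights do not reproduce the
# unweighted score under the product weighting.
print(spread_unw, spread_ww, spread_w)
```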
Suggested change:

```diff
-        ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
-        spread_norm = B.norm(fct - fct_shift, -1)
-        E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
+        spread_norm = B.norm(fct - fct_shift, -1)
+        # Weight the spread term with ens_w only, consistent with other AKR estimators.
+        E_2 += 2 * B.sum(spread_norm * ens_w, -1)
```
```python
                2
                * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))
                * ens_w[i]
                * ens_w[(i + j) % M]
```
The new _energy_score_akr_kband_gufunc_w multiplies by ens_w[i] * ens_w[(i+j)%M], but the existing weighted AKR/circperm estimators in this file weight only by ens_w[i]. With ens_w uniform, this changes the scale by ~1/M compared to the unweighted akr_kband, and will fail the test that uniform weights reproduce the unweighted score. Either adjust the weighting (e.g., weight only by ens_w[i] like the other AKR variants) or adjust the unweighted implementation/normalization so the two agree for uniform weights.
Suggested change:

```diff
-                * ens_w[(i + j) % M]
```
```python
def es_ensemble(
    obs: "Array",
    fct: "Array",
    m_axis: int = -2,
    v_axis: int = -1,
    *,
    ens_w: "Array" = None,
    estimator: str = "nrg",
    backend: "Backend" = None,
    **kwargs,
) -> "Array":
```
es_ensemble now accepts arbitrary **kwargs, but the docstring doesn't document any additional keyword arguments and (for non-akr_kband estimators) extra kwargs are silently ignored. This makes it easy for user typos to go unnoticed. Prefer adding an explicit keyword-only k: int = 1 parameter (documented in the Parameters section) and rejecting unexpected kwargs (or, if you keep **kwargs, validate that kwargs is empty when the estimator doesn't consume them).
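A minimal sketch of that suggestion (hypothetical function name, and trimmed down to the relevant parameters): make `k` explicit and keyword-only, and fail loudly when it is supplied for an estimator that ignores it:

```python
def es_ensemble_sketch(obs, fct, *, estimator: str = "nrg", k: int = 1):
    """Hypothetical trimmed-down signature: `k` is an explicit, documented
    keyword-only argument instead of arriving via **kwargs, so typos and
    misuse fail loudly rather than being silently ignored."""
    if not isinstance(k, int) or k < 1:
        raise ValueError(f"k must be an integer >= 1, got {k!r}")
    if estimator != "akr_kband" and k != 1:
        raise ValueError(
            f"k is only used by estimator='akr_kband', got estimator={estimator!r}"
        )
    # ... dispatch to the core / gufunc implementations as before ...
    return None
```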
```diff
     if estimator == "akr_kband":
         k = kwargs.get("k", 1)

     if ens_w is None:
         if backend == "numba":
             estimator_check(estimator, energy.estimator_gufuncs)
-            return energy.estimator_gufuncs[estimator](obs, fct)
+            if estimator == "akr_kband":
+                return energy.estimator_gufuncs[estimator](obs, fct, k)
+            else:
+                return energy.estimator_gufuncs[estimator](obs, fct)
```
k is pulled from kwargs and then passed into the numba gufunc path without any validation. If a caller supplies k=0 (or a non-int), the gufunc implementations will hit division-by-zero (1/(M*k) or 1/k) or type errors rather than a clean ValueError like the non-numba path. Add validation here (e.g., ensure k is an int >= 1, and optionally k <= M-1) before dispatching to energy.estimator_gufuncs[...].
```python
        for j in range(1, k + 1):
            e_2 += 2 * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))

    out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2
```
The new numba gufunc divides by k (1/(M*k)), but there is no guard for k <= 0. Because scoringrules/_energy.py currently forwards k without validation on the numba path, k=0 will trigger division-by-zero / invalid output instead of a ValueError. Ensure k is validated in the Python wrapper before calling this gufunc (or add a safe guard if possible).
Suggested change:

```diff
-    out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2
+    if k <= 0:
+        # Avoid division by zero / invalid k; propagate a sentinel value.
+        out[0] = np.nan
+    else:
+        out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2
```
```python
@guvectorize("(d),(m,d),(),(m)->()")
def _energy_score_akr_kband_gufunc_w(
    obs: np.ndarray, fct: np.ndarray, k: int, ens_w: np.ndarray, out: np.ndarray
):
    """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation."""
    M = fct.shape[0]

    e_1 = 0.0
    e_2 = 0.0
    for i in range(M):
        e_1 += float(np.linalg.norm(fct[i] - obs)) * ens_w[i]
        for j in range(1, k + 1):
            e_2 += (
                2
                * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))
                * ens_w[i]
                * ens_w[(i + j) % M]
            )

    out[0] = e_1 - 0.5 * 1 / k * e_2
```
The new numba gufunc divides by k (1/k), but there is no guard for k <= 0. Since the public wrapper currently accepts k via **kwargs and doesn't validate it on the numba path, k=0 will trigger division-by-zero / invalid output. Validate k in the Python wrapper before calling this gufunc (or add a safe guard here if feasible).
```python
err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
E_1 = B.sum(err_norm, -1) / M

E_2 = 0.0
for j in range(1, k + 1):
    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
    E_2 += 2 * B.sum(spread_norm, -1)
E_2 = E_2 / (M * k)
```
k is only validated for k >= 1, but values k >= M (ensemble size) will cause repeated offsets (and shifts that effectively compare members with themselves when j is a multiple of M), while still dividing by k. That changes the estimator in a hard-to-interpret way and adds unnecessary work. Consider validating/clamping k to 1 <= k <= M-1 (or to the maximum unique band width you intend) before the loop.
Suggested change:

```diff
-err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
-E_1 = B.sum(err_norm, -1) / M
-E_2 = 0.0
-for j in range(1, k + 1):
-    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
-    E_2 += 2 * B.sum(spread_norm, -1)
-E_2 = E_2 / (M * k)
+# Clamp k to the maximum meaningful band width (number of unique nontrivial offsets).
+# For M == 1, this keeps k_eff at least 1 to avoid division by zero; spread terms are zero anyway.
+max_bandwidth = max(1, M - 1)
+k_eff = min(k, max_bandwidth)
+err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
+E_1 = B.sum(err_norm, -1) / M
+E_2 = 0.0
+for j in range(1, k_eff + 1):
+    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
+    E_2 += 2 * B.sum(spread_norm, -1)
+E_2 = E_2 / (M * k_eff)
```
```python
if k < 1:
    raise ValueError("For estimator='akr_kband', k must be >= 1.")

err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
E_1 = B.sum(err_norm * ens_w, -1)

E_2 = 0.0
for j in range(1, k + 1):
    fct_shift = B.roll(fct, shift=-j, axis=-2)
    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
    spread_norm = B.norm(fct - fct_shift, -1)
    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)

return E_1 - 0.5 * E_2 / k
```
Like the unweighted implementation, this only checks k >= 1. For k >= M (ensemble size), the rolled offsets repeat and (when j is a multiple of M) compare members with themselves, while still dividing by k. Consider validating/clamping k to 1 <= k <= M-1 (or the maximum unique band width you intend) to avoid duplicated work and hard-to-interpret scaling.
Suggested change:

```diff
-if k < 1:
-    raise ValueError("For estimator='akr_kband', k must be >= 1.")
-err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
-E_1 = B.sum(err_norm * ens_w, -1)
-E_2 = 0.0
-for j in range(1, k + 1):
-    fct_shift = B.roll(fct, shift=-j, axis=-2)
-    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
-    spread_norm = B.norm(fct - fct_shift, -1)
-    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
-return E_1 - 0.5 * E_2 / k
+M: int = fct.shape[-2]
+if M < 2:
+    raise ValueError(
+        "For estimator='akr_kband', ensemble size M must be >= 2."
+    )
+if k < 1:
+    raise ValueError("For estimator='akr_kband', k must be >= 1.")
+# Clamp k to the maximum unique band width (M - 1) to avoid
+# repeated cyclic permutations and self-comparisons when k >= M.
+k_eff = min(k, M - 1)
+err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
+E_1 = B.sum(err_norm * ens_w, -1)
+E_2 = 0.0
+for j in range(1, k_eff + 1):
+    fct_shift = B.roll(fct, shift=-j, axis=-2)
+    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
+    spread_norm = B.norm(fct - fct_shift, -1)
+    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
+return E_1 - 0.5 * E_2 / k_eff
```
I'm also generally confused about the test that checks the equivalence of the weighted and unweighted scores for weights all zero. Did this pass before for scores like the normal AKR energy score? For me it does not: since the weights should be …
I would avoid a new top-level function. My preference is to add `k` as an extra argument, after the `estimator` argument (and after `sorted_ensemble` for the CRPS). The k-band estimator should simplify to the akr estimator when `k = 1`, right? So perhaps this is the most sensible default? This would prevent people from using it with an arbitrary `k` without understanding what `k` actually is. Maybe @frazane disagrees though.
Yes, it should simplify to the AKR, though I have to check, since I think it uses different pairs of samples. The route you suggested is sensible, so I'd go with this. It keeps the defaults easy to understand and the user-facing functions rather clean.
First go at the AKR K-Band Score
A question to be decided now, before we start the implementation of CRPS etc.: how do we handle the additional $k$ parameter / argument to the top-level function?

```python
es_ensemble(fct, obs, ...., estimator="akr_kband", k=1)
```

Some options I could think of:

- `k=some_value` as default?
- A separate top-level function `es_ensemble_kband(...)` (not my favorite)

@sallen12, @frazane what do you think?