
AKR-K-Band Scores#119

Open
simon-hirsch wants to merge 2 commits into frazane:main from simon-hirsch:akrkband

Conversation

@simon-hirsch
Contributor

First go at the AKR K-Band Score

A question to be decided now before we start the implementation of CRPS etc: How do we handle the additional $k$ parameter / argument to the top level function?

es_ensemble(fct, obs, ...., estimator="akr_kband", k=1)

Some options I could think of:

  • have k=some_value as default?
  • k as a **kwargs parameter (more flexible, less intuitive)
  • new top level function es_ensemble_kband(...) (not my favorite)

@sallen12 , @frazane what do you think?

Copilot AI review requested due to automatic review settings March 31, 2026 14:26
@simon-hirsch simon-hirsch changed the title from AKR-K-Band Energy score to AKR-K-Band Scores on Mar 31, 2026

Copilot AI left a comment


Pull request overview

This PR adds a new Energy Score estimator, akr_kband, implementing an AKR-based k-band approximation across both the core (backend-agnostic) implementations and the numba gufunc dispatch path, and wires it through the public sr.es_ensemble API.

Changes:

  • Add akr_kband estimator support in core Energy Score implementations (weighted + unweighted).
  • Add numba gufunc implementations for akr_kband (weighted + unweighted) and register them in estimator maps.
  • Update public sr.es_ensemble to pass k through (currently via **kwargs) and add akr_kband to the energy test parametrization.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/test_energy.py Adds akr_kband to the estimator test matrix and threads k through exception checks.
scoringrules/core/energy/_score.py Adds akr_kband to core unweighted ES dispatch + implementation.
scoringrules/core/energy/_score_w.py Adds akr_kband to core weighted ES dispatch + implementation.
scoringrules/core/energy/_gufuncs.py Adds numba gufunc for unweighted akr_kband and registers it.
scoringrules/core/energy/_gufuncs_w.py Adds numba gufunc for weighted akr_kband and registers it.
scoringrules/_energy.py Updates public API to accept extra kwargs and route k to the appropriate implementation.
Comments suppressed due to low confidence (1)

tests/test_energy.py:14

  • akr_kband was added to ESTIMATORS, but the test’s later “correctness” assertions only cover nrg/fair/akr/akr_circperm. As a result, akr_kband gets no numeric regression coverage. Consider extending the correctness checks (and weighted correctness checks) to include an expected value for akr_kband (for a fixed k, and ideally also a small check that non-default k is honored).
ESTIMATORS = ["nrg", "fair", "akr", "akr_circperm", "akr_kband"]


@pytest.mark.parametrize("estimator", ESTIMATORS)
def test_energy_score(estimator, backend):


Comment on lines +131 to +133
ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
spread_norm = B.norm(fct - fct_shift, -1)
E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)

Copilot AI Mar 31, 2026


_es_ensemble_akr_kband_w weights the spread term by ens_w * ens_w_shift, which makes the weighted score disagree with the unweighted score even for uniform weights (unlike the other estimators in this module). This will break the existing invariant tested in tests/test_energy.py that ens_w=np.ones(...) matches the unweighted result. Consider aligning with _es_ensemble_akr_w / _es_ensemble_akr_circperm_w by weighting the k-band spread term with ens_w only (or otherwise adjust the normalization so uniform weights reproduce the unweighted akr_kband result).

Suggested change
-    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
-    spread_norm = B.norm(fct - fct_shift, -1)
-    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
+    spread_norm = B.norm(fct - fct_shift, -1)
+    # Weight the spread term with ens_w only, consistent with other AKR estimators.
+    E_2 += 2 * B.sum(spread_norm * ens_w, -1)

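The scale mismatch the review describes can be checked numerically. The sketch below uses a hypothetical helper (not the scoringrules code) to compare the k-band spread term under ens_w-only weighting versus the product weighting, with uniform weights w_i = 1/M:

```python
import numpy as np


def kband_spread(fct, k, w=None, product=False):
    """Sum the (doubled) k-band spread distances, optionally weighted."""
    M = fct.shape[0]
    total = 0.0
    for i in range(M):
        for j in range(1, k + 1):
            d = 2.0 * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))
            if w is None:
                total += d
            elif product:
                total += d * w[i] * w[(i + j) % M]  # product weighting under review
            else:
                total += d * w[i]  # ens_w-only weighting, as in the other AKR variants
    return total


rng = np.random.default_rng(0)
fct = rng.normal(size=(5, 2))
M, k = 5, 2
w = np.full(M, 1.0 / M)  # uniform weights

unweighted = 0.5 * kband_spread(fct, k) / (M * k)
single = 0.5 * kband_spread(fct, k, w) / k
product = 0.5 * kband_spread(fct, k, w, product=True) / k

assert np.isclose(single, unweighted)       # ens_w-only reproduces the unweighted term
assert np.isclose(product, unweighted / M)  # product weighting is off by a factor 1/M
```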
2
* float(np.linalg.norm(fct[i] - fct[(i + j) % M]))
* ens_w[i]
* ens_w[(i + j) % M]

Copilot AI Mar 31, 2026


The new _energy_score_akr_kband_gufunc_w multiplies by ens_w[i] * ens_w[(i+j)%M], but the existing weighted AKR/circperm estimators in this file weight only by ens_w[i]. With ens_w uniform, this changes the scale by ~1/M compared to the unweighted akr_kband, and will fail the test that uniform weights reproduce the unweighted score. Either adjust the weighting (e.g., weight only by ens_w[i] like the other AKR variants) or adjust the unweighted implementation/normalization so the two agree for uniform weights.

Suggested change
-    * ens_w[(i + j) % M]

Comment on lines 16 to 26
def es_ensemble(
    obs: "Array",
    fct: "Array",
    m_axis: int = -2,
    v_axis: int = -1,
    *,
    ens_w: "Array" = None,
    estimator: str = "nrg",
    backend: "Backend" = None,
    **kwargs,
) -> "Array":

Copilot AI Mar 31, 2026


es_ensemble now accepts arbitrary **kwargs, but the docstring doesn't document any additional keyword arguments and (for non-akr_kband estimators) extra kwargs are silently ignored. This makes it easy for user typos to go unnoticed. Prefer adding an explicit keyword-only k: int = 1 parameter (documented in the Parameters section) and rejecting unexpected kwargs (or, if you keep **kwargs, validate that kwargs is empty when the estimator doesn't consume them).

Comment on lines +76 to +85
if estimator == "akr_kband":
    k = kwargs.get("k", 1)

if ens_w is None:
    if backend == "numba":
        estimator_check(estimator, energy.estimator_gufuncs)
        if estimator == "akr_kband":
            return energy.estimator_gufuncs[estimator](obs, fct, k)
        else:
            return energy.estimator_gufuncs[estimator](obs, fct)

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k is pulled from kwargs and then passed into the numba gufunc path without any validation. If a caller supplies k=0 (or a non-int), the gufunc implementations will hit division-by-zero (1/(M*k) or 1/k) or type errors rather than a clean ValueError like the non-numba path. Add validation here (e.g., ensure k is an int >= 1, and optionally k <= M-1) before dispatching to energy.estimator_gufuncs[...].

    for j in range(1, k + 1):
        e_2 += 2 * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))

out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new numba gufunc divides by k (1/(M*k)), but there is no guard for k <= 0. Because scoringrules/_energy.py currently forwards k without validation on the numba path, k=0 will trigger division-by-zero / invalid output instead of a ValueError. Ensure k is validated in the Python wrapper before calling this gufunc (or add a safe guard if possible).

Suggested change
-out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2
+if k <= 0:
+    # Avoid division by zero / invalid k; propagate a sentinel value.
+    out[0] = np.nan
+else:
+    out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2

Comment on lines +82 to +102
@guvectorize("(d),(m,d),(),(m)->()")
def _energy_score_akr_kband_gufunc_w(
    obs: np.ndarray, fct: np.ndarray, k: int, ens_w: np.ndarray, out: np.ndarray
):
    """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation."""
    M = fct.shape[0]

    e_1 = 0.0
    e_2 = 0.0
    for i in range(M):
        e_1 += float(np.linalg.norm(fct[i] - obs)) * ens_w[i]
        for j in range(1, k + 1):
            e_2 += (
                2
                * float(np.linalg.norm(fct[i] - fct[(i + j) % M]))
                * ens_w[i]
                * ens_w[(i + j) % M]
            )

    out[0] = e_1 - 0.5 * 1 / k * e_2


Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new numba gufunc divides by k (1/k), but there is no guard for k <= 0. Since the public wrapper currently accepts k via **kwargs and doesn't validate it on the numba path, k=0 will trigger division-by-zero / invalid output. Validate k in the Python wrapper before calling this gufunc (or add a safe guard here if feasible).

Comment on lines +108 to +115
err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
E_1 = B.sum(err_norm, -1) / M

E_2 = 0.0
for j in range(1, k + 1):
    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
    E_2 += 2 * B.sum(spread_norm, -1)
E_2 = E_2 / (M * k)

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k is only validated for k >= 1, but values k >= M (ensemble size) will cause repeated offsets (and shifts that effectively compare members with themselves when j is a multiple of M), while still dividing by k. That changes the estimator in a hard-to-interpret way and adds unnecessary work. Consider validating/clamping k to 1 <= k <= M-1 (or to the maximum unique band width you intend) before the loop.

Suggested change
-err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
-E_1 = B.sum(err_norm, -1) / M
-E_2 = 0.0
-for j in range(1, k + 1):
-    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
-    E_2 += 2 * B.sum(spread_norm, -1)
-E_2 = E_2 / (M * k)
+# Clamp k to the maximum meaningful band width (number of unique nontrivial offsets).
+# For M == 1, this keeps k_eff at least 1 to avoid division by zero; spread terms are zero anyway.
+max_bandwidth = max(1, M - 1)
+k_eff = min(k, max_bandwidth)
+err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
+E_1 = B.sum(err_norm, -1) / M
+E_2 = 0.0
+for j in range(1, k_eff + 1):
+    spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1)
+    E_2 += 2 * B.sum(spread_norm, -1)
+E_2 = E_2 / (M * k_eff)

Comment on lines +121 to +135

if k < 1:
    raise ValueError("For estimator='akr_kband', k must be >= 1.")

err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
E_1 = B.sum(err_norm * ens_w, -1)

E_2 = 0.0
for j in range(1, k + 1):
    fct_shift = B.roll(fct, shift=-j, axis=-2)
    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
    spread_norm = B.norm(fct - fct_shift, -1)
    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)

return E_1 - 0.5 * E_2 / k

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like the unweighted implementation, this only checks k >= 1. For k >= M (ensemble size), the rolled offsets repeat and (when j is a multiple of M) compare members with themselves, while still dividing by k. Consider validating/clamping k to 1 <= k <= M-1 (or the maximum unique band width you intend) to avoid duplicated work and hard-to-interpret scaling.

Suggested change
-if k < 1:
-    raise ValueError("For estimator='akr_kband', k must be >= 1.")
-err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
-E_1 = B.sum(err_norm * ens_w, -1)
-E_2 = 0.0
-for j in range(1, k + 1):
-    fct_shift = B.roll(fct, shift=-j, axis=-2)
-    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
-    spread_norm = B.norm(fct - fct_shift, -1)
-    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
-return E_1 - 0.5 * E_2 / k
+M: int = fct.shape[-2]
+if M < 2:
+    raise ValueError("For estimator='akr_kband', ensemble size M must be >= 2.")
+if k < 1:
+    raise ValueError("For estimator='akr_kband', k must be >= 1.")
+# Clamp k to the maximum unique band width (M - 1) to avoid
+# repeated cyclic permutations and self-comparisons when k >= M.
+k_eff = min(k, M - 1)
+err_norm = B.norm(fct - B.expand_dims(obs, -2), -1)
+E_1 = B.sum(err_norm * ens_w, -1)
+E_2 = 0.0
+for j in range(1, k_eff + 1):
+    fct_shift = B.roll(fct, shift=-j, axis=-2)
+    ens_w_shift = B.roll(ens_w, shift=-j, axis=-1)
+    spread_norm = B.norm(fct - fct_shift, -1)
+    E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1)
+return E_1 - 0.5 * E_2 / k_eff

@simon-hirsch
Contributor Author

I'm also generally confused about the test that checks the equivalence of the weighted and unweighted scores when the weights are all one. Did this pass before for scores like the normal AKR energy score? For me it does not:

akr = _energy_score_akr_gufunc(
    obs,
    fct,
    np.zeros(N),
)
akr_w = _energy_score_akr_gufunc_w(
    obs,
    fct,
    w,
    np.zeros(N),
)
np.allclose(akr, akr_w)

For the test to pass, the weights should be $1/M$, where $M$ is the ensemble size.
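This can be illustrated with an AKR-style reference (a sketch, not the actual scoringrules implementation): when the weighted form uses plain weighted sums, uniform weights of 1/M reproduce the unweighted score, while weights of 1 do not.

```python
import numpy as np

rng = np.random.default_rng(1)
M, d = 7, 3
obs = rng.normal(size=d)
fct = rng.normal(size=(M, d))


def akr_ref(obs, fct, w=None):
    """Illustrative AKR-style energy score: weighted form uses sum(err * w)."""
    M = fct.shape[0]
    err = np.array([np.linalg.norm(fct[i] - obs) for i in range(M)])
    spread = np.array([np.linalg.norm(fct[i] - fct[(i + 1) % M]) for i in range(M)])
    if w is None:
        return err.mean() - 0.5 * spread.mean()
    return np.sum(err * w) - 0.5 * np.sum(spread * w)


uniform_mean = np.full(M, 1.0 / M)
assert np.isclose(akr_ref(obs, fct), akr_ref(obs, fct, uniform_mean))   # w = 1/M matches
assert not np.isclose(akr_ref(obs, fct), akr_ref(obs, fct, np.ones(M)))  # w = 1 does not
```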

@sallen12
Collaborator

First go at the AKR K-Band Score

A question to be decided now before we start the implementation of CRPS etc: How do we handle the additional k parameter / argument to the top level function?

es_ensemble(fct, obs, ...., estimator="akr_kband", k=1)

Some options I could think of:

  • have k=some_value as default?
  • k as a **kwargs parameter (more flexible, less intuitive)
  • new top level function es_ensemble_kband(...) (not my favorite)

@sallen12 , @frazane what do you think?

I would avoid a new top level function. I think my preference is to add k as an extra argument, after the "estimator" argument (and after "sorted_ensemble" for the CRPS). The k-band estimator should simplify to the akr estimator when k = 1, right? So perhaps this is the most sensible default? This would prevent people from using it with an arbitrary k without understanding what k is exactly.

Maybe @frazane disagrees though.
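Concretely, the suggestion of an explicit, documented k argument (defaulting to 1, since the k-band estimator should then simplify to the akr estimator) might look like the sketch below. This is purely illustrative: the argument order, the guard, and the placeholder return value are assumptions, not the agreed final API.

```python
def es_ensemble(
    obs,
    fct,
    m_axis=-2,
    v_axis=-1,
    *,
    ens_w=None,
    estimator="nrg",
    k=1,
    backend=None,
):
    """Sketch: explicit keyword-only `k` (default 1) instead of **kwargs."""
    if estimator != "akr_kband" and k != 1:
        raise ValueError("k is only used by estimator='akr_kband'")
    # ... dispatch to the estimator implementations as before ...
    return estimator, k  # placeholder so the sketch is runnable


assert es_ensemble(None, None, estimator="akr_kband", k=3) == ("akr_kband", 3)
```

An explicit parameter keeps typos visible: passing an unexpected keyword now fails immediately instead of being silently swallowed by **kwargs.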

@simon-hirsch
Contributor Author

simon-hirsch commented Apr 3, 2026

I would avoid a new top level function. I think my preference is to add k as an extra argument, after the "estimator" argument (and after "sorted_ensemble" for the CRPS). The k-band estimator should simplify to the akr estimator when k = 1, right? So perhaps this is the most sensible default? This would prevent people from using it with an arbitrary k without understanding what k is exactly.

Yes, it should simplify to the AKR. I have to check, though; I think it uses different pairs of samples.

I think the route you suggested is sensible, so I'd go with it. It keeps the defaults easy to understand and the user-facing functions rather clean.



3 participants