Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
uv_weighted_score_weights,
uv_weighted_score_chain,
univariate_sort_ens,
nan_policy_check,
apply_nan_policy_ens_uv,
)

if tp.TYPE_CHECKING:
from scoringrules.core.typing import Array, ArrayLike, Backend
from scoringrules.core.typing import Array, ArrayLike, Backend, NanPolicy


def crps_ensemble(
Expand All @@ -21,6 +23,7 @@ def crps_ensemble(
*,
sorted_ensemble: bool = False,
estimator: str = "qd",
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the Continuous Ranked Probability Score (CRPS) for a finite ensemble.
Expand Down Expand Up @@ -61,6 +64,15 @@ def crps_ensemble(
Default is False.
estimator : str
Indicates the CRPS estimator to be used.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like these three options

Defines how to handle NaN values in the ensemble members:

- ``'propagate'``: return NaN if any ensemble member is NaN.
- ``'omit'``: ignore NaN ensemble members during computation.
- ``'raise'``: raise a ValueError if NaN values are encountered.

Note: this applies to ensemble members (fct) only. NaN values in
observations (obs) always result in NaN output.
backend : str, optional
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.

Expand Down Expand Up @@ -102,13 +114,27 @@ def crps_ensemble(
>>> sr.crps_ensemble(obs, pred)
array([0.69605316, 0.32865417, 0.39048665])
"""
nan_policy_check(nan_policy)
obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend)
# Sort before NaN handling so that NaN members end up at the tail of sorted arrays,
# which is required for rank-based estimators ('pwm', 'qd', 'int').
fct = univariate_sort_ens(fct, estimator, sorted_ensemble, backend=backend)
if backend == "numba":
estimator_check(estimator, crps.estimator_gufuncs)
if nan_policy == "raise":
# Validate only — raises if NaN found; NaN values are left in-place.
apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend)
if nan_policy == "omit":
# NaN values are left in-place; the nanomit gufuncs skip them internally.
return crps.estimator_gufuncs_nanomit[estimator](obs, fct)
return crps.estimator_gufuncs[estimator](obs, fct)
else:
return crps.ensemble(obs, fct, estimator, backend=backend)
obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend)
if nan_policy == "omit" and estimator == "int":
raise NotImplementedError(
"nan_policy='omit' is not supported with estimator='int' on non-numba backends. "
"Use a different estimator such as 'nrg', 'fair', 'pwm', or 'qd'."
)
return crps.ensemble(obs, fct, estimator, nan_mask=nan_mask, backend=backend)


def twcrps_ensemble(
Expand All @@ -121,6 +147,7 @@ def twcrps_ensemble(
v_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
estimator: str = "qd",
sorted_ensemble: bool = False,
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the threshold-weighted CRPS (twCRPS) for a finite ensemble.
Expand Down Expand Up @@ -155,6 +182,9 @@ def twcrps_ensemble(
Chaining function used to emphasise particular outcomes. For example, a function that
only considers values above a certain threshold :math:`t` by projecting forecasts and observations
to :math:`[t, \infty)`.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
Defines how to handle NaN values in the ensemble members. Forwarded to
:func:`crps_ensemble`. See its documentation for details.
backend : str, optional
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.

Expand Down Expand Up @@ -202,6 +232,7 @@ def twcrps_ensemble(
m_axis=m_axis,
sorted_ensemble=sorted_ensemble,
estimator=estimator,
nan_policy=nan_policy,
backend=backend,
)

Expand All @@ -214,6 +245,7 @@ def owcrps_ensemble(
m_axis: int = -1,
*,
w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the outcome-weighted CRPS (owCRPS) for a finite ensemble.
Expand Down Expand Up @@ -252,6 +284,10 @@ def owcrps_ensemble(
The axis corresponding to the ensemble. Default is the last axis.
w_func : callable, array_like -> array_like
Weight function used to emphasise particular outcomes.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
Defines how to handle NaN values in the ensemble members. Applied before
weight computation so NaN members do not contribute to the mean weight.
See :func:`crps_ensemble` for details.
backend : str, optional
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.

Expand Down Expand Up @@ -286,12 +322,22 @@ def owcrps_ensemble(
>>> sr.owcrps_ensemble(obs, fct, w_func=w_func)
array([0.91103733, 0.45212402, 0.35686667])
"""
nan_policy_check(nan_policy)
obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend)
obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend)
if backend == "numba":
if nan_policy == "raise":
apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend)
obs_w, fct_w = uv_weighted_score_weights(
obs, fct, a, b, w_func, backend=backend
)
if nan_policy == "omit":
# NaN values are left in-place; the nanomit gufunc skips them and
# recomputes wbar from valid members only.
return crps.estimator_gufuncs_nanomit["ownrg"](obs, fct, obs_w, fct_w)
return crps.estimator_gufuncs["ownrg"](obs, fct, obs_w, fct_w)
else:
return crps.ow_ensemble(obs, fct, obs_w, fct_w, backend=backend)
obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend)
obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend)
return crps.ow_ensemble(obs, fct, obs_w, fct_w, nan_mask=nan_mask, backend=backend)


def vrcrps_ensemble(
Expand All @@ -302,6 +348,7 @@ def vrcrps_ensemble(
m_axis: int = -1,
*,
w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the vertically re-scaled CRPS (vrCRPS) for a finite ensemble.
Expand Down Expand Up @@ -338,6 +385,10 @@ def vrcrps_ensemble(
The axis corresponding to the ensemble. Default is the last axis.
w_func : callable, array_like -> array_like
Weight function used to emphasise particular outcomes.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
Defines how to handle NaN values in the ensemble members. Applied before
weight computation so NaN members do not contribute to the mean weight.
See :func:`crps_ensemble` for details.
backend : str, optional
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.

Expand Down Expand Up @@ -372,12 +423,20 @@ def vrcrps_ensemble(
>>> sr.vrcrps_ensemble(obs, fct, w_func)
array([0.90036433, 0.41515255, 0.41653833])
"""
nan_policy_check(nan_policy)
obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend)
obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend)
if backend == "numba":
if nan_policy == "raise":
apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend)
obs_w, fct_w = uv_weighted_score_weights(
obs, fct, a, b, w_func, backend=backend
)
if nan_policy == "omit":
return crps.estimator_gufuncs_nanomit["vrnrg"](obs, fct, obs_w, fct_w)
return crps.estimator_gufuncs["vrnrg"](obs, fct, obs_w, fct_w)
else:
return crps.vr_ensemble(obs, fct, obs_w, fct_w, backend=backend)
obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend)
obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend)
return crps.vr_ensemble(obs, fct, obs_w, fct_w, nan_mask=nan_mask, backend=backend)


def crps_quantile(
Expand Down
8 changes: 7 additions & 1 deletion scoringrules/core/crps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@
)

try:
from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc
from ._gufuncs import (
estimator_gufuncs,
estimator_gufuncs_nanomit,
quantile_pinball_gufunc,
)
except ImportError:
estimator_gufuncs = None
estimator_gufuncs_nanomit = None
quantile_pinball_gufunc = None

__all__ = [
Expand Down Expand Up @@ -61,6 +66,7 @@
"t",
"uniform",
"estimator_gufuncs",
"estimator_gufuncs_nanomit",
"quantile_pinball",
"quantile_pinball_gufunc",
]
Loading
Loading