diff --git a/scoringrules/_crps.py b/scoringrules/_crps.py index 0800520..7b65a96 100644 --- a/scoringrules/_crps.py +++ b/scoringrules/_crps.py @@ -8,10 +8,12 @@ uv_weighted_score_weights, uv_weighted_score_chain, univariate_sort_ens, + nan_policy_check, + apply_nan_policy_ens_uv, ) if tp.TYPE_CHECKING: - from scoringrules.core.typing import Array, ArrayLike, Backend + from scoringrules.core.typing import Array, ArrayLike, Backend, NanPolicy def crps_ensemble( @@ -21,6 +23,7 @@ def crps_ensemble( *, sorted_ensemble: bool = False, estimator: str = "qd", + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Estimate the Continuous Ranked Probability Score (CRPS) for a finite ensemble. @@ -61,6 +64,15 @@ def crps_ensemble( Default is False. estimator : str Indicates the CRPS estimator to be used. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + Defines how to handle NaN values in the ensemble members: + + - ``'propagate'``: return NaN if any ensemble member is NaN. + - ``'omit'``: ignore NaN ensemble members during computation. + - ``'raise'``: raise a ValueError if NaN values are encountered. + + Note: this applies to ensemble members (fct) only. NaN values in + observations (obs) always result in NaN output. backend : str, optional The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``. @@ -102,13 +114,27 @@ def crps_ensemble( >>> sr.crps_ensemble(obs, pred) array([0.69605316, 0.32865417, 0.39048665]) """ + nan_policy_check(nan_policy) obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend) + # Sort before NaN handling so that NaN members end up at the tail of sorted arrays, + # which is required for rank-based estimators ('pwm', 'qd', 'int'). fct = univariate_sort_ens(fct, estimator, sorted_ensemble, backend=backend) if backend == "numba": estimator_check(estimator, crps.estimator_gufuncs) + if nan_policy == "raise": + # Validate only — raises if NaN found; NaN values are left in-place. + apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend) + if nan_policy == "omit": + # NaN values are left in-place; the nanomit gufuncs skip them internally. + return crps.estimator_gufuncs_nanomit[estimator](obs, fct) return crps.estimator_gufuncs[estimator](obs, fct) - else: - return crps.ensemble(obs, fct, estimator, backend=backend) + obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend) + if nan_policy == "omit" and estimator == "int": + raise NotImplementedError( + "nan_policy='omit' is not supported with estimator='int' on non-numba backends. " + "Use a different estimator such as 'nrg', 'fair', 'pwm', or 'qd'." + ) + return crps.ensemble(obs, fct, estimator, nan_mask=nan_mask, backend=backend) def twcrps_ensemble( @@ -121,6 +147,7 @@ def twcrps_ensemble( v_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None, estimator: str = "qd", sorted_ensemble: bool = False, + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Estimate the threshold-weighted CRPS (twCRPS) for a finite ensemble. @@ -155,6 +182,9 @@ def twcrps_ensemble( Chaining function used to emphasise particular outcomes. For example, a function that only considers values above a certain threshold :math:`t` by projecting forecasts and observations to :math:`[t, \inf)`. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + Defines how to handle NaN values in the ensemble members. Forwarded to + :func:`crps_ensemble`. See its documentation for details. backend : str, optional The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``. @@ -202,6 +232,7 @@ def twcrps_ensemble( m_axis=m_axis, sorted_ensemble=sorted_ensemble, estimator=estimator, + nan_policy=nan_policy, backend=backend, ) @@ -214,6 +245,7 @@ def owcrps_ensemble( m_axis: int = -1, *, w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None, + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Estimate the outcome-weighted CRPS (owCRPS) for a finite ensemble. @@ -252,6 +284,10 @@ def owcrps_ensemble( The axis corresponding to the ensemble. Default is the last axis. w_func : callable, array_like -> array_like Weight function used to emphasise particular outcomes. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + Defines how to handle NaN values in the ensemble members. Applied before + weight computation so NaN members do not contribute to the mean weight. + See :func:`crps_ensemble` for details. backend : str, optional The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``. @@ -286,12 +322,22 @@ def owcrps_ensemble( >>> sr.owcrps_ensemble(obs, fct, w_func=w_func) array([0.91103733, 0.45212402, 0.35686667]) """ + nan_policy_check(nan_policy) obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend) - obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend) if backend == "numba": + if nan_policy == "raise": + apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend) + obs_w, fct_w = uv_weighted_score_weights( + obs, fct, a, b, w_func, backend=backend + ) + if nan_policy == "omit": + # NaN values are left in-place; the nanomit gufunc skips them and + # recomputes wbar from valid members only. + return crps.estimator_gufuncs_nanomit["ownrg"](obs, fct, obs_w, fct_w) return crps.estimator_gufuncs["ownrg"](obs, fct, obs_w, fct_w) - else: - return crps.ow_ensemble(obs, fct, obs_w, fct_w, backend=backend) + obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend) + obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend) + return crps.ow_ensemble(obs, fct, obs_w, fct_w, nan_mask=nan_mask, backend=backend) def vrcrps_ensemble( @@ -302,6 +348,7 @@ def vrcrps_ensemble( m_axis: int = -1, *, w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None, + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Estimate the vertically re-scaled CRPS (vrCRPS) for a finite ensemble. @@ -338,6 +385,10 @@ def vrcrps_ensemble( The axis corresponding to the ensemble. Default is the last axis. w_func : callable, array_like -> array_like Weight function used to emphasise particular outcomes. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + Defines how to handle NaN values in the ensemble members. Applied before + weight computation so NaN members do not contribute to the mean weight. + See :func:`crps_ensemble` for details. backend : str, optional The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``. @@ -372,12 +423,20 @@ def vrcrps_ensemble( >>> sr.vrcrps_ensemble(obs, fct, w_func) array([0.90036433, 0.41515255, 0.41653833]) """ + nan_policy_check(nan_policy) obs, fct = univariate_array_check(obs, fct, m_axis, backend=backend) - obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend) if backend == "numba": + if nan_policy == "raise": + apply_nan_policy_ens_uv(obs, fct, "raise", backend=backend) + obs_w, fct_w = uv_weighted_score_weights( + obs, fct, a, b, w_func, backend=backend + ) + if nan_policy == "omit": + return crps.estimator_gufuncs_nanomit["vrnrg"](obs, fct, obs_w, fct_w) return crps.estimator_gufuncs["vrnrg"](obs, fct, obs_w, fct_w) - else: - return crps.vr_ensemble(obs, fct, obs_w, fct_w, backend=backend) + obs, fct, nan_mask = apply_nan_policy_ens_uv(obs, fct, nan_policy, backend=backend) + obs_w, fct_w = uv_weighted_score_weights(obs, fct, a, b, w_func, backend=backend) + return crps.vr_ensemble(obs, fct, obs_w, fct_w, nan_mask=nan_mask, backend=backend) def crps_quantile( diff --git a/scoringrules/core/crps/__init__.py b/scoringrules/core/crps/__init__.py index ddeae70..7ab8970 100644 --- a/scoringrules/core/crps/__init__.py +++ b/scoringrules/core/crps/__init__.py @@ -27,9 +27,14 @@ ) try: - from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc + from ._gufuncs import ( + estimator_gufuncs, + estimator_gufuncs_nanomit, + quantile_pinball_gufunc, + ) except ImportError: estimator_gufuncs = None + estimator_gufuncs_nanomit = None quantile_pinball_gufunc = None __all__ = [ @@ -61,6 +66,7 @@ "t", "uniform", "estimator_gufuncs", + "estimator_gufuncs_nanomit", "quantile_pinball", "quantile_pinball_gufunc", ] diff --git a/scoringrules/core/crps/_approx.py b/scoringrules/core/crps/_approx.py index 3f282a2..ae5f767 100644 --- a/scoringrules/core/crps/_approx.py +++ b/scoringrules/core/crps/_approx.py @@ -10,21 +10,22 @@ def ensemble( obs: "ArrayLike", fct: "Array", estimator: str = "pwm", + nan_mask=None, backend: "Backend" = None, ) -> "Array": """Compute the CRPS for a finite ensemble.""" if estimator == "nrg": - out = _crps_ensemble_nrg(obs, fct, backend=backend) + out = _crps_ensemble_nrg(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "pwm": - out = _crps_ensemble_pwm(obs, fct, backend=backend) + out = _crps_ensemble_pwm(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "fair": - out = _crps_ensemble_fair(obs, fct, backend=backend) + out = _crps_ensemble_fair(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "qd": - out = _crps_ensemble_qd(obs, fct, backend=backend) + out = _crps_ensemble_qd(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "akr": - out = _crps_ensemble_akr(obs, fct, backend=backend) + out = _crps_ensemble_akr(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "akr_circperm": - out = _crps_ensemble_akr_circperm(obs, fct, backend=backend) + out = _crps_ensemble_akr_circperm(obs, fct, nan_mask=nan_mask, backend=backend) elif estimator == "int": out = _crps_ensemble_int(obs, fct, backend=backend) else: @@ -35,74 +36,195 @@ def ensemble( def _crps_ensemble_fair( - obs: "Array", fct: "Array", backend: "Backend" = None + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None ) -> "Array": """Fair version of the CRPS estimator based on the energy form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M - e_2 = B.sum( - B.abs(fct[..., None] - fct[..., None, :]), - axis=(-1, -2), - ) / (M * (M - 1)) - return e_1 - 0.5 * e_2 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct)), axis=-1 + ) + / M_eff + ) + pair_mask = nan_mask[..., :, None] | nan_mask[..., None, :] + e_2 = B.sum( + B.where( + pair_mask, + B.asarray(0.0), + B.abs(fct[..., None] - fct[..., None, :]), + ), + axis=(-1, -2), + ) / (M_eff * (M_eff - 1)) + result = e_1 - 0.5 * e_2 + return B.where(M_eff <= 1, B.asarray(float("nan")), result) + else: + e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M + e_2 = B.sum( + B.abs(fct[..., None] - fct[..., None, :]), + axis=(-1, -2), + ) / (M * (M - 1)) + return e_1 - 0.5 * e_2 def _crps_ensemble_nrg( - obs: "Array", fct: "Array", backend: "Backend" = None + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None ) -> "Array": """CRPS estimator based on the energy form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M - e_2 = B.sum(B.abs(fct[..., None] - fct[..., None, :]), (-1, -2)) / (M**2) - return e_1 - 0.5 * e_2 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct)), axis=-1 + ) + / M_eff + ) + pair_mask = nan_mask[..., :, None] | nan_mask[..., None, :] + e_2 = ( + B.sum( + B.where( + pair_mask, + B.asarray(0.0), + B.abs(fct[..., None] - fct[..., None, :]), + ), + (-1, -2), + ) + / M_eff**2 + ) + result = e_1 - 0.5 * e_2 + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M + e_2 = B.sum(B.abs(fct[..., None] - fct[..., None, :]), (-1, -2)) / (M**2) + return e_1 - 0.5 * e_2 def _crps_ensemble_pwm( - obs: "Array", fct: "Array", backend: "Backend" = None + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None ) -> "Array": """CRPS estimator based on the probability weighted moment (PWM) form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - expected_diff = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M - β_0 = B.sum(fct, axis=-1) / M - β_1 = B.sum(fct * B.arange(0, M), axis=-1) / (M * (M - 1.0)) - return expected_diff + β_0 - 2.0 * β_1 + if nan_mask is not None: + # Assumes the ensemble is sorted with NaN members zeroed at the end. + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + expected_diff = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct)), axis=-1 + ) + / M_eff + ) + # NaN members are zeroed so their contributions to β_0 and β_1 are 0. + β_0 = B.sum(fct, axis=-1) / M_eff + β_1 = B.sum(fct * B.arange(0, M), axis=-1) / (M_eff * (M_eff - 1.0)) + result = expected_diff + β_0 - 2.0 * β_1 + return B.where(M_eff <= 1, B.asarray(float("nan")), result) + else: + expected_diff = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M + β_0 = B.sum(fct, axis=-1) / M + β_1 = B.sum(fct * B.arange(0, M), axis=-1) / (M * (M - 1.0)) + return expected_diff + β_0 - 2.0 * β_1 def _crps_ensemble_akr( - obs: "Array", fct: "Array", backend: "Backend" = None + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None ) -> "Array": """CRPS estimator based on the approximate kernel representation.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M - e_2 = B.sum(B.abs(fct - B.roll(fct, shift=1, axis=-1)), -1) / M - return e_1 - 0.5 * e_2 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct)), axis=-1 + ) + / M_eff + ) + roll_mask = nan_mask | B.roll(nan_mask, shift=1, axis=-1) + e_2 = ( + B.sum( + B.where( + roll_mask, + B.asarray(0.0), + B.abs(fct - B.roll(fct, shift=1, axis=-1)), + ), + -1, + ) + / M_eff + ) + result = e_1 - 0.5 * e_2 + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M + e_2 = B.sum(B.abs(fct - B.roll(fct, shift=1, axis=-1)), -1) / M + return e_1 - 0.5 * e_2 def _crps_ensemble_akr_circperm( - obs: "Array", fct: "Array", backend: "Backend" = None + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None ) -> "Array": """CRPS estimator based on the AKR with cyclic permutation.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M - shift = M // 2 - e_2 = B.sum(B.abs(fct - B.roll(fct, shift=shift, axis=-1)), -1) / M - return e_1 - 0.5 * e_2 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct)), axis=-1 + ) + / M_eff + ) + shift = M // 2 + roll_mask = nan_mask | B.roll(nan_mask, shift=shift, axis=-1) + e_2 = ( + B.sum( + B.where( + roll_mask, + B.asarray(0.0), + B.abs(fct - B.roll(fct, shift=shift, axis=-1)), + ), + -1, + ) + / M_eff + ) + result = e_1 - 0.5 * e_2 + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + e_1 = B.sum(B.abs(obs[..., None] - fct), axis=-1) / M + shift = M // 2 + e_2 = B.sum(B.abs(fct - B.roll(fct, shift=shift, axis=-1)), -1) / M + return e_1 - 0.5 * e_2 -def _crps_ensemble_qd(obs: "Array", fct: "Array", backend: "Backend" = None) -> "Array": +def _crps_ensemble_qd( + obs: "Array", fct: "Array", nan_mask=None, backend: "Backend" = None +) -> "Array": """CRPS estimator based on the quantile decomposition form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - alpha = B.arange(1, M + 1) - 0.5 - below = (fct <= obs[..., None]) * alpha * (obs[..., None] - fct) - above = (fct > obs[..., None]) * (M - alpha) * (fct - obs[..., None]) - out = B.sum(below + above, axis=-1) / (M**2) - return 2 * out + if nan_mask is not None: + # Assumes the ensemble is sorted with NaN members zeroed at the end. + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + alpha = B.arange(1, M + 1) - 0.5 # shape (M,) + below = (fct <= obs[..., None]) * alpha * (obs[..., None] - fct) + above = ( + (fct > obs[..., None]) * (M_eff[..., None] - alpha) * (fct - obs[..., None]) + ) + out = ( + B.sum(B.where(nan_mask, B.asarray(0.0), below + above), axis=-1) / M_eff**2 + ) + result = 2 * out + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + alpha = B.arange(1, M + 1) - 0.5 + below = (fct <= obs[..., None]) * alpha * (obs[..., None] - fct) + above = (fct > obs[..., None]) * (M - alpha) * (fct - obs[..., None]) + out = B.sum(below + above, axis=-1) / (M**2) + return 2 * out def _crps_ensemble_int( @@ -138,19 +260,49 @@ def ow_ensemble( fct: "Array", ow: "Array", fw: "Array", + nan_mask=None, backend: "Backend" = None, ) -> "Array": """Outcome-Weighted CRPS estimator based on the energy form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - wbar = B.mean(fw, axis=-1) - e_1 = B.sum(B.abs(obs[..., None] - fct) * fw, axis=-1) * ow / (M * wbar) - e_2 = B.sum( - B.abs(fct[..., None] - fct[..., None, :]) * fw[..., None] * fw[..., None, :], - axis=(-1, -2), - ) - e_2 *= ow / (M**2 * wbar**2) - return e_1 - 0.5 * e_2 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + fw = B.where(nan_mask, B.asarray(0.0), fw) + wbar = B.sum(fw, axis=-1) / M_eff + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct) * fw), + axis=-1, + ) + * ow + / (M_eff * wbar) + ) + pair_mask = nan_mask[..., :, None] | nan_mask[..., None, :] + e_2 = B.sum( + B.where( + pair_mask, + B.asarray(0.0), + B.abs(fct[..., None] - fct[..., None, :]) + * fw[..., None] + * fw[..., None, :], + ), + axis=(-1, -2), + ) + e_2 *= ow / (M_eff**2 * wbar**2) + result = e_1 - 0.5 * e_2 + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + wbar = B.mean(fw, axis=-1) + e_1 = B.sum(B.abs(obs[..., None] - fct) * fw, axis=-1) * ow / (M * wbar) + e_2 = B.sum( + B.abs(fct[..., None] - fct[..., None, :]) + * fw[..., None] + * fw[..., None, :], + axis=(-1, -2), + ) + e_2 *= ow / (M**2 * wbar**2) + return e_1 - 0.5 * e_2 def vr_ensemble( @@ -158,17 +310,50 @@ def vr_ensemble( fct: "Array", ow: "Array", fw: "Array", + nan_mask=None, backend: "Backend" = None, ) -> "Array": """Vertically Re-scaled CRPS estimator based on the energy form.""" B = backends.active if backend is None else backends[backend] M: int = fct.shape[-1] - e_1 = B.sum(B.abs(obs[..., None] - fct) * fw, axis=-1) * ow / M - e_2 = B.sum( - B.abs(B.expand_dims(fct, axis=-1) - B.expand_dims(fct, axis=-2)) - * (B.expand_dims(fw, axis=-1) * B.expand_dims(fw, axis=-2)), - axis=(-1, -2), - ) / (M**2) - e_3 = B.mean(B.abs(fct) * fw, axis=-1) - B.abs(obs) * ow - e_3 *= B.mean(fw, axis=1) - ow - return e_1 - 0.5 * e_2 + e_3 + if nan_mask is not None: + M_eff = B.sum(B.where(nan_mask, B.asarray(0.0), B.asarray(1.0)), axis=-1) + fw = B.where(nan_mask, B.asarray(0.0), fw) + e_1 = ( + B.sum( + B.where(nan_mask, B.asarray(0.0), B.abs(obs[..., None] - fct) * fw), + axis=-1, + ) + * ow + / M_eff + ) + pair_mask = nan_mask[..., :, None] | nan_mask[..., None, :] + e_2 = ( + B.sum( + B.where( + pair_mask, + B.asarray(0.0), + B.abs(B.expand_dims(fct, axis=-1) - B.expand_dims(fct, axis=-2)) + * (B.expand_dims(fw, axis=-1) * B.expand_dims(fw, axis=-2)), + ), + axis=(-1, -2), + ) + / M_eff**2 + ) + e_3 = ( + B.sum(B.where(nan_mask, B.asarray(0.0), B.abs(fct) * fw), axis=-1) / M_eff + - B.abs(obs) * ow + ) + e_3 *= B.sum(fw, axis=-1) / M_eff - ow + result = e_1 - 0.5 * e_2 + e_3 + return B.where(M_eff == 0, B.asarray(float("nan")), result) + else: + e_1 = B.sum(B.abs(obs[..., None] - fct) * fw, axis=-1) * ow / M + e_2 = B.sum( + B.abs(B.expand_dims(fct, axis=-1) - B.expand_dims(fct, axis=-2)) + * (B.expand_dims(fw, axis=-1) * B.expand_dims(fw, axis=-2)), + axis=(-1, -2), + ) / (M**2) + e_3 = B.mean(B.abs(fct) * fw, axis=-1) - B.abs(obs) * ow + e_3 *= B.mean(fw, axis=1) - ow + return e_1 - 0.5 * e_2 + e_3 diff --git a/scoringrules/core/crps/_gufuncs.py b/scoringrules/core/crps/_gufuncs.py index e3ecc2b..b42e776 100644 --- a/scoringrules/core/crps/_gufuncs.py +++ b/scoringrules/core/crps/_gufuncs.py @@ -169,7 +169,7 @@ def _crps_ensemble_akr_gufunc(obs: np.ndarray, fct: np.ndarray, out: np.ndarray) e_2 = 0.0 for i, forecast in enumerate(fct): if i == 0: - i = M - 1 + i = M e_1 += abs(forecast - obs) e_2 += abs(forecast - fct[i - 1]) out[0] = e_1 / M - 0.5 * 1 / M * e_2 @@ -260,6 +260,396 @@ def _vrcrps_ensemble_nrg_gufunc( "vrnrg": lazy_gufunc_wrapper_uv(_vrcrps_ensemble_nrg_gufunc), } + +# --- NaN-omit variants --- +# Each gufunc below checks np.isnan inline and skips invalid members. NaN +# values are left in-place by the public API (the zeroing done by +# apply_nan_policy_ens_uv is bypassed for the numba path), so the gufuncs must +# detect them themselves. +# +# For the sorted-ensemble estimators (qd, pwm, int) the public API layer +# pre-sorts the ensemble, which places NaN members at the tail (IEEE 754 +# behaviour). Those gufuncs exploit this with an early-exit loop. +# All other estimators (nrg, fair, akr, akr_circperm, ownrg, vrnrg) make no +# assumption about member order and scan the full array. + + +@guvectorize( + [ + "void(float32[:], float32[:], float32[:])", + "void(float64[:], float64[:], float64[:])", + ], + "(),(n)->()", +) +def _crps_ensemble_nrg_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit CRPS estimator based on the energy form.""" + obs = obs[0] + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + e_1 = 0.0 + e_2 = 0.0 + M_eff = 0 + + for i in range(M): + if np.isnan(fct[i]): + continue + M_eff += 1 + e_1 += abs(fct[i] - obs) + for j in range(i + 1, M): + if np.isnan(fct[j]): + continue + e_2 += 2 * abs(fct[j] - fct[i]) + + if M_eff == 0: + out[0] = np.nan + else: + out[0] = e_1 / M_eff - 0.5 * e_2 / (M_eff**2) + + +@guvectorize( + [ + "void(float32[:], float32[:], float32[:])", + "void(float64[:], float64[:], float64[:])", + ], + "(),(n)->()", +) +def _crps_ensemble_fair_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit fair CRPS estimator based on the energy form.""" + obs = obs[0] + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + e_1 = 0.0 + e_2 = 0.0 + M_eff = 0 + + for i in range(M): + if np.isnan(fct[i]): + continue + M_eff += 1 + e_1 += abs(fct[i] - obs) + for j in range(i + 1, M): + if np.isnan(fct[j]): + continue + e_2 += 2 * abs(fct[j] - fct[i]) + + if M_eff <= 1: + out[0] = np.nan + else: + out[0] = e_1 / M_eff - 0.5 * e_2 / (M_eff * (M_eff - 1)) + + +@guvectorize( + [ + "void(float32[:], float32[:], float32[:])", + "void(float64[:], float64[:], float64[:])", + ], + "(),(n)->()", +) +def _crps_ensemble_qd_nanomit_gufunc(obs: np.ndarray, fct: np.ndarray, out: np.ndarray): + """NaN-omit CRPS estimator based on the quantile decomposition form. + + Assumes the ensemble is pre-sorted with NaN members at the tail. + """ + obs = obs[0] + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + # Count valid members — NaNs are at the tail after sorting. + M_eff = 0 + for i in range(M): + if np.isnan(fct[i]): + break + M_eff += 1 + + if M_eff == 0: + out[0] = np.nan + return + + obs_cdf = 0.0 + integral = 0.0 + + for i in range(M_eff): + forecast = fct[i] + if obs < forecast: + obs_cdf = 1.0 + integral += (forecast - obs) * (M_eff * obs_cdf - (i + 1) + 0.5) + + out[0] = (2 / M_eff**2) * integral + + +@guvectorize("(),(n)->()") +def _crps_ensemble_pwm_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit CRPS estimator based on the probability weighted moment (PWM) form. + + Assumes the ensemble is pre-sorted with NaN members at the tail. + """ + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + # Count valid members — NaNs are at the tail after sorting. + M_eff = 0 + for i in range(M): + if np.isnan(fct[i]): + break + M_eff += 1 + + if M_eff == 0: + out[0] = np.nan + return + + expected_diff = 0.0 + β_0 = 0.0 + β_1 = 0.0 + + for i in range(M_eff): + forecast = fct[i] + expected_diff += np.abs(forecast - obs) + β_0 += forecast + β_1 += forecast * i + + if M_eff == 1: + out[0] = expected_diff / M_eff + β_0 / M_eff + else: + out[0] = expected_diff / M_eff + β_0 / M_eff - 2 * β_1 / (M_eff * (M_eff - 1)) + + +@guvectorize("(),(n)->()") +def _crps_ensemble_int_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit CRPS estimator based on the integral form. + + Assumes the ensemble is pre-sorted with NaN members at the tail. + Unlike the base variant, this skips NaN members rather than stopping at them. + """ + M = fct.shape[0] + + if np.isnan(obs): + out[0] = np.nan + return + + # Count valid members — NaNs are at the tail after sorting. + M_eff = 0 + for i in range(M): + if np.isnan(fct[i]): + break + M_eff += 1 + + if M_eff == 0: + out[0] = np.nan + return + + obs_cdf = 0 + forecast_cdf = 0.0 + prev_forecast = 0.0 + integral = 0.0 + + for n in range(M_eff): + forecast = fct[n] + + if obs_cdf == 0 and obs < forecast: + integral += (obs - prev_forecast) * forecast_cdf**2 + integral += (forecast - obs) * (forecast_cdf - 1) ** 2 + obs_cdf = 1 + else: + integral += (forecast_cdf - obs_cdf) ** 2 * (forecast - prev_forecast) + + forecast_cdf += 1 / M_eff + prev_forecast = forecast + + if obs_cdf == 0: + integral += obs - fct[M_eff - 1] + + out[0] = integral + + +@guvectorize("(),(n)->()") +def _crps_ensemble_akr_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit CRPS estimator based on the approximate kernel representation.""" + M = fct.shape[-1] + e_1 = 0.0 + e_2 = 0.0 + M_eff = 0 + first_valid_val = 0.0 + prev_valid_val = 0.0 + + for i in range(M): + if np.isnan(fct[i]): + continue + M_eff += 1 + e_1 += abs(fct[i] - obs) + if M_eff == 1: + first_valid_val = fct[i] + else: + e_2 += abs(fct[i] - prev_valid_val) + prev_valid_val = fct[i] + + # Circular wrap-around: pair last valid with first valid. + if M_eff >= 2: + e_2 += abs(first_valid_val - prev_valid_val) + + if M_eff == 0: + out[0] = np.nan + else: + out[0] = e_1 / M_eff - 0.5 / M_eff * e_2 + + +@guvectorize("(),(n)->()") +def _crps_ensemble_akr_circperm_nanomit_gufunc( + obs: np.ndarray, fct: np.ndarray, out: np.ndarray +): + """NaN-omit CRPS estimator based on the AKR with cyclic permutation.""" + M = fct.shape[-1] + e_1 = 0.0 + e_2 = 0.0 + M_eff = 0 + + # First pass: count valid members. + for i in range(M): + if not np.isnan(fct[i]): + M_eff += 1 + + if M_eff == 0: + out[0] = np.nan + return + + # Second pass: compute e_1 and e_2 using rank within valid members. + i_eff = 0 + for i in range(M): + if np.isnan(fct[i]): + continue + e_1 += abs(fct[i] - obs) + sigma_i_eff = int((i_eff + 1 + ((M_eff - 1) / 2)) % M_eff) + # Find the sigma_i_eff-th valid member. + count = 0 + for j in range(M): + if np.isnan(fct[j]): + continue + if count == sigma_i_eff: + e_2 += abs(fct[i] - fct[j]) + break + count += 1 + i_eff += 1 + + out[0] = e_1 / M_eff - 0.5 / M_eff * e_2 + + +@guvectorize("(),(n),(),(n)->()") +def _owcrps_ensemble_nrg_nanomit_gufunc( + obs: np.ndarray, + fct: np.ndarray, + ow: np.ndarray, + fw: np.ndarray, + out: np.ndarray, +): + """NaN-omit outcome-weighted CRPS estimator based on the energy form.""" + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + e_1 = 0.0 + e_2 = 0.0 + sum_fw = 0.0 + M_eff = 0 + + for i in range(M): + if np.isnan(fct[i]): + continue + M_eff += 1 + e_1 += abs(fct[i] - obs) * fw[i] * ow + sum_fw += fw[i] + for j in range(i + 1, M): + if np.isnan(fct[j]): + continue + e_2 += 2 * abs(fct[i] - fct[j]) * fw[i] * fw[j] * ow + + if M_eff == 0: + out[0] = np.nan + return + + wbar = sum_fw / M_eff + out[0] = e_1 / (M_eff * wbar) - 0.5 * e_2 / ((M_eff * wbar) ** 2) + + +@guvectorize("(),(n),(),(n)->()") +def _vrcrps_ensemble_nrg_nanomit_gufunc( + obs: np.ndarray, + fct: np.ndarray, + ow: np.ndarray, + fw: np.ndarray, + out: np.ndarray, +): + """NaN-omit vertically re-scaled CRPS estimator based on the energy form.""" + M = fct.shape[-1] + + if np.isnan(obs): + out[0] = np.nan + return + + e_1 = 0.0 + e_2 = 0.0 + sum_fw = 0.0 + sum_abs_fw = 0.0 + M_eff = 0 + + for i in range(M): + if np.isnan(fct[i]): + continue + M_eff += 1 + e_1 += abs(fct[i] - obs) * fw[i] * ow + sum_fw += fw[i] + sum_abs_fw += abs(fct[i]) * fw[i] + for j in range(i + 1, M): + if np.isnan(fct[j]): + continue + e_2 += 2 * abs(fct[i] - fct[j]) * fw[i] * fw[j] + + if M_eff == 0: + out[0] = np.nan + return + + wbar = sum_fw / M_eff + wabs_x = sum_abs_fw / M_eff + wabs_y = abs(obs) * ow + out[0] = e_1 / M_eff - 0.5 * e_2 / (M_eff**2) + (wabs_x - wabs_y) * (wbar - ow) + + +estimator_gufuncs_nanomit = { + "akr_circperm": lazy_gufunc_wrapper_uv(_crps_ensemble_akr_circperm_nanomit_gufunc), + "akr": lazy_gufunc_wrapper_uv(_crps_ensemble_akr_nanomit_gufunc), + "fair": _crps_ensemble_fair_nanomit_gufunc, + "int": lazy_gufunc_wrapper_uv(_crps_ensemble_int_nanomit_gufunc), + "nrg": _crps_ensemble_nrg_nanomit_gufunc, + "pwm": lazy_gufunc_wrapper_uv(_crps_ensemble_pwm_nanomit_gufunc), + "qd": _crps_ensemble_qd_nanomit_gufunc, + "ownrg": lazy_gufunc_wrapper_uv(_owcrps_ensemble_nrg_nanomit_gufunc), + "vrnrg": lazy_gufunc_wrapper_uv(_vrcrps_ensemble_nrg_nanomit_gufunc), +} + __all__ = [ "_crps_ensemble_akr_circperm_gufunc", "_crps_ensemble_akr_gufunc", diff --git a/scoringrules/core/typing.py b/scoringrules/core/typing.py index c87c6eb..dd5f53a 100644 --- a/scoringrules/core/typing.py +++ b/scoringrules/core/typing.py @@ -11,3 +11,4 @@ ArrayLike = tp.TypeVar("ArrayLike", bound=_array | float | int) Backend = tp.Literal["numpy", "numba", "jax", "torch", "tensorflow"] | None + NanPolicy = tp.Literal["propagate", "omit", "raise"] diff --git a/scoringrules/core/utils.py b/scoringrules/core/utils.py index 47450f9..ebd4d40 100644 --- a/scoringrules/core/utils.py +++ b/scoringrules/core/utils.py @@ -155,6 +155,69 @@ def univariate_sort_ens(fct, estimator=None, sorted_ensemble=False, backend=None return fct +def nan_policy_check(nan_policy: str) -> None: + """Validate the nan_policy argument.""" + valid = ("propagate", "omit", "raise") + if nan_policy not in valid: + raise ValueError(f"Invalid nan_policy '{nan_policy}'. Must be one of {valid}.") + + +def apply_nan_policy_ens_uv(obs, fct, nan_policy="propagate", backend=None): + """Apply NaN policy to univariate ensemble forecasts (fct shape: ..., M). + + For 'propagate': no-op, returns (obs, fct, None). + For 'raise': raises ValueError if any NaN in fct or obs. + For 'omit': returns (obs, fct_zeroed, nan_mask) where nan_mask is a boolean + array (True where fct member is NaN) and NaN members are replaced with 0.0. + """ + B = backends.active if backend is None else backends[backend] + + if nan_policy == "raise": + if B.any(B.isnan(fct)) or B.any(B.isnan(obs)): + raise ValueError( + "NaN values encountered in input. " + "Use nan_policy='propagate' or nan_policy='omit' to handle NaN values." + ) + return obs, fct, None + + if nan_policy == "omit": + nan_mask = B.isnan(fct) + fct = B.where(nan_mask, B.asarray(0.0), fct) + return obs, fct, nan_mask + + # propagate + return obs, fct, None + + +def apply_nan_policy_ens_mv(obs, fct, nan_policy="propagate", backend=None): + """Apply NaN policy to multivariate ensemble forecasts (fct shape: ..., M, D). + + A NaN in any variable of an ensemble member marks the entire member as invalid. + + For 'propagate': no-op, returns (obs, fct, None). + For 'raise': raises ValueError if any NaN in fct or obs. + For 'omit': returns (obs, fct_zeroed, nan_mask) where nan_mask has shape + (..., M) — True for invalid members — and NaN members are replaced with 0.0. + """ + B = backends.active if backend is None else backends[backend] + + if nan_policy == "raise": + if B.any(B.isnan(fct)) or B.any(B.isnan(obs)): + raise ValueError( + "NaN values encountered in input. " + "Use nan_policy='propagate' or nan_policy='omit' to handle NaN values." + ) + return obs, fct, None + + if nan_policy == "omit": + nan_mask = B.any(B.isnan(fct), axis=-1) # shape (..., M) + fct = B.where(nan_mask[..., None], B.asarray(0.0), fct) + return obs, fct, nan_mask + + # propagate + return obs, fct, None + + def lazy_gufunc_wrapper_uv(func): """ Wrapper for lazy/dynamic generalized universal functions so diff --git a/tests/test_crps_nan_policy.py b/tests/test_crps_nan_policy.py new file mode 100644 index 0000000..d1e9b1a --- /dev/null +++ b/tests/test_crps_nan_policy.py @@ -0,0 +1,390 @@ +"""Tests for nan_policy parameter in CRPS ensemble scoring rules. + +Covers all three policies ('propagate', 'omit', 'raise') across all CRPS estimators +and the weighted variants (twcrps, owcrps, vrcrps). + +Note on akr/akr_circperm with nan_policy='omit' on array backends: + The array backend uses a roll-based masking approach for the circular kernel, + which masks out pairs involving NaN members but also loses the wrap-around pair + when NaN sits at the tail. This means the array-backend result does NOT match + the clean-ensemble result (a known limitation). The numba gufuncs correctly + reconnect valid neighbours and DO match the clean ensemble. Tests that require + a match against the clean ensemble therefore only run for the numba backend on + akr/akr_circperm. +""" + +import numpy as np +import pytest +import scoringrules as sr + +# All estimators available for crps_ensemble +ESTIMATORS = ["nrg", "fair", "pwm", "qd", "int", "akr", "akr_circperm"] + +# Estimators whose omit behaviour matches the clean-ensemble result on all backends. +# 'int' is excluded because nan_policy='omit' raises NotImplementedError on non-numba. +# 'akr'/'akr_circperm' are excluded because the array backend uses a different +# approximation (see module docstring above). +ESTIMATORS_OMIT_ALL_BACKENDS = ["nrg", "fair", "pwm", "qd"] + +# Same but also excludes 'fair' for edge-case tests where M_eff may equal 1 +# (fair's denominator M_eff*(M_eff-1) is 0 when M_eff==1 → returns NaN by design). +ESTIMATORS_OMIT_MEFF_GE2 = ["nrg", "pwm", "qd"] + +# Reference inputs: five valid members, with and without NaN interspersed. +# The valid members are the same in both arrays (and already sorted). +_OBS = np.float64(0.5) +_FCT_CLEAN = np.array([0.1, 0.3, 0.7, 0.9, 1.1]) +_FCT_WITH_NAN = np.array([0.1, np.nan, 0.3, np.nan, 0.7, 0.9, 1.1]) +# NaN members of _FCT_WITH_NAN replaced by _OBS — same shape, no NaN, used for +# building 2-row batches where one row is clean and the other contains NaN. +_FCT_WITH_NAN_FILLED = np.where(np.isnan(_FCT_WITH_NAN), _OBS, _FCT_WITH_NAN) + + +# --------------------------------------------------------------------------- +# propagate policy +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("estimator", ESTIMATORS) +def test_propagate_returns_nan(estimator, backend): + """nan_policy='propagate' returns NaN when the ensemble contains NaN members.""" + if backend == "numba" and estimator == "int": + # The existing int gufunc stops at the first NaN (sorted to the tail) and + # returns a partial integral rather than NaN — pre-existing behaviour outside + # the scope of this PR. + pytest.skip( + "int gufunc stops early at NaN (pre-existing behaviour) rather than " + "propagating NaN; propagation for int/numba is not a goal of this PR." + ) + res = sr.crps_ensemble( + _OBS, + _FCT_WITH_NAN, + estimator=estimator, + nan_policy="propagate", + backend=backend, + ) + assert np.isnan(float(np.asarray(res))) + + +@pytest.mark.parametrize("estimator", ESTIMATORS) +def test_propagate_is_default(estimator, backend): + """Calling without nan_policy is identical to nan_policy='propagate'.""" + if backend == "numba" and estimator == "int": + pytest.skip("int/numba pre-existing NaN propagation behaviour — see above.") + res_explicit = sr.crps_ensemble( + _OBS, + _FCT_WITH_NAN, + estimator=estimator, + nan_policy="propagate", + backend=backend, + ) + res_default = sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator=estimator, backend=backend + ) + assert np.array_equal( + np.asarray(res_explicit), np.asarray(res_default), equal_nan=True + ) + + +# --------------------------------------------------------------------------- +# omit policy — basic sanity +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "estimator", ESTIMATORS_OMIT_ALL_BACKENDS + ["akr", "akr_circperm"] +) +def test_omit_ignores_nans(estimator, backend): + """nan_policy='omit' returns a finite score when the ensemble has NaN members.""" + res = sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator=estimator, nan_policy="omit", backend=backend + ) + assert np.isfinite(float(np.asarray(res))) + + +@pytest.mark.parametrize( + "estimator", ESTIMATORS_OMIT_ALL_BACKENDS + ["akr", "akr_circperm"] +) +def test_omit_all_nan_returns_nan(estimator, backend): + """nan_policy='omit' returns NaN when every ensemble member is NaN.""" + fct_all_nan = np.array([np.nan, np.nan, np.nan]) + # silence warnings (divide by zero, invalid value encountered) + with np.errstate(divide="ignore", invalid="ignore"): + res = sr.crps_ensemble( + _OBS, fct_all_nan, estimator=estimator, nan_policy="omit", backend=backend + ) + assert np.isnan(float(np.asarray(res))) + + +# --------------------------------------------------------------------------- +# omit policy — correctness: result must match the clean ensemble +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("estimator", ESTIMATORS_OMIT_ALL_BACKENDS) +def test_omit_matches_manual(estimator, backend): + """nan_policy='omit' matches computing the score on the clean (NaN-free) ensemble.""" + res_omit = sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator=estimator, nan_policy="omit", backend=backend + ) + # sorted_ensemble=True avoids re-sorting the already-sorted clean array. + res_clean = sr.crps_ensemble( + _OBS, _FCT_CLEAN, estimator=estimator, sorted_ensemble=True, backend=backend + ) + assert np.isclose( + float(np.asarray(res_omit)), float(np.asarray(res_clean)), rtol=1e-5 + ) + + +@pytest.mark.parametrize("estimator", ["akr", "akr_circperm"]) +def test_omit_matches_manual_akr_numba(estimator, backend): + """akr/akr_circperm nanomit gufuncs match the clean ensemble on numba.""" + if backend != "numba": + pytest.skip( + "Array-backend akr/akr_circperm uses a roll-based masking that loses " + "valid consecutive pairs when NaN members are interspersed, so the result " + "differs from the clean ensemble — known limitation." + ) + res_omit = sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator=estimator, nan_policy="omit", backend=backend + ) + res_clean = sr.crps_ensemble(_OBS, _FCT_CLEAN, estimator=estimator, backend=backend) + assert np.isclose( + float(np.asarray(res_omit)), float(np.asarray(res_clean)), rtol=1e-5 + ) + + +def test_omit_int_numba_matches_manual(backend): + """int estimator with nan_policy='omit' works on numba and matches the clean ensemble.""" + if backend != "numba": + pytest.skip("int + omit only supported on numba") + res_omit = sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator="int", nan_policy="omit", backend=backend + ) + res_clean = sr.crps_ensemble( + _OBS, _FCT_CLEAN, estimator="int", sorted_ensemble=True, backend=backend + ) + assert np.isclose( + float(np.asarray(res_omit)), float(np.asarray(res_clean)), rtol=1e-5 + ) + + +def test_omit_int_non_numba_raises(backend): + """int estimator with nan_policy='omit' raises NotImplementedError on non-numba.""" + if backend == "numba": + pytest.skip("int + omit is supported on numba") + with pytest.raises(NotImplementedError): + sr.crps_ensemble( + _OBS, _FCT_WITH_NAN, estimator="int", nan_policy="omit", backend=backend + ) + + +# --------------------------------------------------------------------------- +# omit policy — edge cases +# --------------------------------------------------------------------------- + + +def test_omit_nan_in_obs(backend): + """nan_policy='omit' still returns NaN when obs itself is NaN.""" + res = sr.crps_ensemble( + np.float64(np.nan), _FCT_CLEAN, nan_policy="omit", backend=backend + ) + assert np.isnan(float(np.asarray(res))) + + +def test_omit_fair_meff_one_returns_nan(backend): + """fair estimator with nan_policy='omit' returns NaN when only one valid member remains.""" + fct_one_valid = np.array([_FCT_CLEAN[0], np.nan, np.nan]) + res = sr.crps_ensemble( + _OBS, fct_one_valid, estimator="fair", nan_policy="omit", backend=backend + ) + assert np.isnan(float(np.asarray(res))) + + +@pytest.mark.parametrize("estimator", ESTIMATORS_OMIT_ALL_BACKENDS) +def test_omit_varying_nan_counts(estimator, backend): + """Batched forecasts each with a different number of NaN members.""" + obs = np.zeros(4) + fct = np.array( + [ + [1.0, 2.0, 3.0, np.nan, np.nan], # 3 valid + [0.5, 1.5, 2.5, np.nan, np.nan], # 3 valid (different values) + [0.5, 1.5, np.nan, np.nan, np.nan], # 2 valid + [1.0, 2.0, 3.0, 4.0, 5.0], # 5 valid, no NaN + ] + ) + res = np.asarray( + sr.crps_ensemble( + obs, fct, estimator=estimator, nan_policy="omit", backend=backend + ) + ) + assert res.shape == (4,) + # Rows with ≥ 2 valid members should always be finite (even for 'fair'). + assert np.all(np.isfinite(res)) + + # Row 3 (no NaN) should match the clean computation on the same backend. + res_clean = float( + np.asarray( + sr.crps_ensemble( + 0.0, + np.array([1.0, 2.0, 3.0, 4.0, 5.0]), + estimator=estimator, + backend=backend, + ) + ) + ) + assert np.isclose(float(res[3]), res_clean, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# raise policy +# --------------------------------------------------------------------------- + + +def test_raise_with_nans(backend): + """nan_policy='raise' raises ValueError when NaN values are present.""" + with pytest.raises(ValueError): + sr.crps_ensemble(_OBS, _FCT_WITH_NAN, nan_policy="raise", backend=backend) + + +def test_raise_without_nans(backend): + """nan_policy='raise' computes normally when no NaN values are present.""" + res = sr.crps_ensemble(_OBS, _FCT_CLEAN, nan_policy="raise", backend=backend) + assert np.isfinite(float(np.asarray(res))) + + +def test_raise_with_nan_in_obs(backend): + """nan_policy='raise' raises ValueError when obs itself is NaN.""" + with pytest.raises(ValueError): + sr.crps_ensemble( + np.float64(np.nan), _FCT_CLEAN, nan_policy="raise", backend=backend + ) + + +# --------------------------------------------------------------------------- +# policy validation +# --------------------------------------------------------------------------- + + +def test_invalid_nan_policy(backend): + """An unrecognised nan_policy string raises ValueError.""" + with pytest.raises(ValueError): + sr.crps_ensemble(_OBS, _FCT_CLEAN, nan_policy="invalid_policy", backend=backend) + + +# --------------------------------------------------------------------------- +# all three policies agree when there are no NaN values +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "estimator", ESTIMATORS_OMIT_ALL_BACKENDS + ["akr", "akr_circperm"] +) +def test_no_nans_all_policies_equal(estimator, backend): + """All three policies produce identical results when the ensemble has no NaN values.""" + res_prop = sr.crps_ensemble( + _OBS, _FCT_CLEAN, estimator=estimator, nan_policy="propagate", backend=backend + ) + res_omit = sr.crps_ensemble( + _OBS, _FCT_CLEAN, estimator=estimator, nan_policy="omit", backend=backend + ) + res_raise = sr.crps_ensemble( + _OBS, _FCT_CLEAN, estimator=estimator, nan_policy="raise", backend=backend + ) + assert np.isclose(float(np.asarray(res_prop)), float(np.asarray(res_omit))) + assert np.isclose(float(np.asarray(res_prop)), float(np.asarray(res_raise))) + + +# --------------------------------------------------------------------------- +# Weighted variants +# --------------------------------------------------------------------------- + + +def test_twcrps_propagate_returns_nan(backend): + """twcrps_ensemble with nan_policy='propagate' returns NaN for NaN ensemble.""" + res = sr.twcrps_ensemble( + _OBS, _FCT_WITH_NAN, nan_policy="propagate", backend=backend + ) + assert np.isnan(float(np.asarray(res))) + + +def test_twcrps_omit_finite(backend): + """twcrps_ensemble with nan_policy='omit' returns a finite score.""" + res = sr.twcrps_ensemble(_OBS, _FCT_WITH_NAN, nan_policy="omit", backend=backend) + assert np.isfinite(float(np.asarray(res))) + + +def test_twcrps_no_nans_policies_equal(backend): + """twcrps_ensemble: all policies agree when no NaN values are present.""" + res_prop = sr.twcrps_ensemble( + _OBS, _FCT_CLEAN, nan_policy="propagate", backend=backend + ) + res_omit = sr.twcrps_ensemble(_OBS, _FCT_CLEAN, nan_policy="omit", backend=backend) + assert np.isclose(float(np.asarray(res_prop)), float(np.asarray(res_omit))) + + +def test_twcrps_raise_with_nans(backend): + """twcrps_ensemble with nan_policy='raise' raises ValueError when NaN is present.""" + with pytest.raises(ValueError): + sr.twcrps_ensemble(_OBS, _FCT_WITH_NAN, nan_policy="raise", backend=backend) + + +def test_owcrps_propagate_returns_nan(backend): + """owcrps_ensemble with nan_policy='propagate' returns NaN for NaN ensemble.""" + res = sr.owcrps_ensemble( + _OBS, _FCT_WITH_NAN, nan_policy="propagate", backend=backend + ) + assert np.isnan(float(np.asarray(res))) + + +def test_owcrps_omit_matches_manual(backend): + """owcrps_ensemble with nan_policy='omit' matches computing on the clean ensemble.""" + res_omit = sr.owcrps_ensemble( + _OBS, _FCT_WITH_NAN, nan_policy="omit", backend=backend + ) + res_clean = sr.owcrps_ensemble(_OBS, _FCT_CLEAN, backend=backend) + assert np.isclose( + float(np.asarray(res_omit)), float(np.asarray(res_clean)), rtol=1e-5 + ) + + +def test_owcrps_raise_with_nans(backend): + """owcrps_ensemble with nan_policy='raise' raises ValueError.""" + with pytest.raises(ValueError): + sr.owcrps_ensemble(_OBS, _FCT_WITH_NAN, nan_policy="raise", backend=backend) + + +def test_vrcrps_propagate_returns_nan(backend): + """vrcrps_ensemble with nan_policy='propagate' returns NaN for NaN ensemble.""" + # Use batched inputs — vr_ensemble has a pre-existing axis=1 bug with scalar obs. + obs = np.array([float(_OBS), float(_OBS)]) + fct = np.array([_FCT_WITH_NAN, _FCT_WITH_NAN_FILLED]) + res = np.asarray( + sr.vrcrps_ensemble(obs, fct, nan_policy="propagate", backend=backend) + ) + assert np.isnan(res[0]) # NaN ensemble → NaN + assert np.isfinite(res[1]) # clean ensemble → finite + + +def test_vrcrps_omit_matches_manual(backend): + """vrcrps_ensemble with nan_policy='omit' matches computing on the clean ensemble.""" + # Use batched inputs — vr_ensemble has a pre-existing axis=1 bug with scalar obs. + obs = np.array([float(_OBS)]) + res_omit = np.asarray( + sr.vrcrps_ensemble( + obs, _FCT_WITH_NAN[np.newaxis, :], nan_policy="omit", backend=backend + ) + ) + res_clean = np.asarray( + sr.vrcrps_ensemble(obs, _FCT_CLEAN[np.newaxis, :], backend=backend) + ) + assert np.isclose(float(res_omit[0]), float(res_clean[0]), rtol=1e-5) + + +def test_vrcrps_raise_with_nans(backend): + """vrcrps_ensemble with nan_policy='raise' raises ValueError.""" + obs = np.array([float(_OBS)]) + with pytest.raises(ValueError): + sr.vrcrps_ensemble( + obs, _FCT_WITH_NAN[np.newaxis, :], nan_policy="raise", backend=backend + )