Skip to content

Commit c565b0a

Browse files
Revert change in aggregate_all function
1 parent 19c9147 commit c565b0a

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

mapie/aggregation_functions.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,8 @@ def aggregate_all(agg_function: Optional[str], X: NDArray) -> NDArray:
112112
array([14.5, 14.5])
113113
114114
"""
115-
row_nan_mask = np.isnan(X).all(axis=1)
116-
result = np.full(X.shape[0], np.nan)
117-
if agg_function == "mean":
118-
result[~row_nan_mask] = np.nanmean(X[~row_nan_mask], axis=1)
119-
return result
120-
elif agg_function == "median":
121-
result[~row_nan_mask] = np.nanmedian(X[~row_nan_mask], axis=1)
122-
return result
115+
if agg_function == "median":
116+
return np.nanmedian(X, axis=1)
117+
elif agg_function == "mean":
118+
return np.nanmean(X, axis=1)
123119
raise ValueError("Aggregation function called but not defined.")

0 commit comments

Comments
 (0)