Skip to content

Commit cb35677

Browse files
committed
fix metric output reshaping
don't use batch dim
1 parent cdc8111 commit cb35677

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

autoemulate/core/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def __call__(
373373
"distributions."
374374
)
375375
raise ValueError(msg)
376-
return model_nll_output.reshape(*y_true.shape)
376+
return model_nll_output.reshape(*y_true.shape[1:])
377377
msg = (
378378
f"Unknown reduction method: {metric_params.reduction}. "
379379
"Expected 'mean' or 'none'."
@@ -403,7 +403,7 @@ def __call__(
403403
if model_nll_output is None:
404404
msg = "Per-output MLL not available for non-Independent distributions."
405405
raise ValueError(msg)
406-
return (model_nll_output - trivial_nll_output).reshape(*y_true.shape)
406+
return (model_nll_output - trivial_nll_output).reshape(*y_true.shape[1:])
407407
msg = (
408408
f"Unknown reduction method: {metric_params.reduction}. "
409409
"Expected 'mean' or 'none'."

0 commit comments

Comments
 (0)