File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -373,7 +373,7 @@ def __call__(
373373 "distributions."
374374 )
375375 raise ValueError (msg )
376- return model_nll_output .reshape (* y_true .shape )
376+ return model_nll_output .reshape (* y_true .shape [ 1 :] )
377377 msg = (
378378 f"Unknown reduction method: { metric_params .reduction } . "
379379 "Expected 'mean' or 'none'."
@@ -403,7 +403,7 @@ def __call__(
403403 if model_nll_output is None :
404404 msg = "Per-output MLL not available for non-Independent distributions."
405405 raise ValueError (msg )
406- return (model_nll_output - trivial_nll_output ).reshape (* y_true .shape )
406+ return (model_nll_output - trivial_nll_output ).reshape (* y_true .shape [ 1 :] )
407407 msg = (
408408 f"Unknown reduction method: { metric_params .reduction } . "
409409 "Expected 'mean' or 'none'."
You can’t perform that action at this time.
0 commit comments