@@ -557,14 +557,9 @@ def test_msll_with_1d_inputs():
557557 )
558558
559559 msll = MSLL (
560- y_pred , y_true , metric_params = MetricParams (y_train = y_train , reduction = "none" )
561- )
562-
563- assert isinstance (msll , torch .Tensor )
564- assert msll .shape == torch .Size ([1 ])
565-
566- msll = MSLL (
567- y_pred , y_true , metric_params = MetricParams (y_train = y_train , reduction = "mean" )
560+ y_pred ,
561+ y_true ,
562+ metric_params = MetricParams (y_train = y_train ),
568563 )
569564
570565 assert isinstance (msll , torch .Tensor )
@@ -579,14 +574,9 @@ def test_msll_with_2d_inputs():
579574 y_pred = Normal (loc = y_true , scale = torch .ones_like (y_true ))
580575
581576 msll = MSLL (
582- y_pred , y_true , metric_params = MetricParams (y_train = y_train , reduction = "none" )
583- )
584-
585- assert isinstance (msll , torch .Tensor )
586- assert msll .shape == torch .Size ([n_outputs ])
587-
588- msll = MSLL (
589- y_pred , y_true , metric_params = MetricParams (y_train = y_train , reduction = "mean" )
577+ y_pred ,
578+ y_true ,
579+ metric_params = MetricParams (y_train = y_train ),
590580 )
591581
592582 assert isinstance (msll , torch .Tensor )
@@ -607,7 +597,7 @@ def test_msll_raises_for_non_distribution():
607597def test_msll_perfect_prediction_with_training_data ():
608598 """Test MSLL when predictions perfectly match true values."""
609599 y_true = torch .tensor ([1.0 , 2.0 , 3.0 ]).view (- 1 , 1 )
610- y_train = torch .tensor ([0.5 , 1.5 , 2.5 ]). view ( - 1 , 1 )
600+ y_train = torch .tensor ([0.5 , 1.5 , 2.5 ])
611601
612602 # Perfect prediction: distribution centered at true values with small variance
613603 y_pred = Normal (loc = y_true , scale = torch .ones_like (y_true ) * 0.01 )
@@ -621,7 +611,7 @@ def test_msll_perfect_prediction_with_training_data():
621611def test_msll_poor_prediction_with_training_data ():
622612 """Test MSLL when predictions are far from true values."""
623613 y_true = torch .tensor ([1.0 , 2.0 , 3.0 ]).view (- 1 , 1 )
624- y_train = torch .tensor ([1.0 , 2.0 , 3.0 ]). view ( - 1 , 1 )
614+ y_train = torch .tensor ([1.0 , 2.0 , 3.0 ])
625615
626616 # Poor prediction: distribution centered far from true values
627617 y_pred = Normal (
0 commit comments