@@ -287,13 +287,14 @@ class MSLLMetric(ProbabilisticMetric):
287287 MSLL evaluates the quality of probabilistic predictions by measuring the
288288 log-likelihood of the true values under the predictive distribution,
289289 standardized by the log-likelihood under the trivial model (i.e., predictive
290- distribution parameterized with the data mean and variance).
290+ normal distribution parameterized with the data mean and variance).
291291
292292 If no training data is supplied, the mean log loss is computed.
293293
294294 Lower MSLL values indicate better predictive performance.
295295
296- Note: This metric requires probabilistic predictions.
296+ Note: This metric requires probabilistic predictions. Standardization
297+ assumes that the predictive distribution is Gaussian.
297298
298299 Attributes
299300 ----------
@@ -324,7 +325,7 @@ def __call__(
324325 y_true: TensorLike
325326 True target values.
326327 metric_params: MetricParams
327- Metric parameters including: n_samples, y_train, reduction .
328+ Metric parameters including: y_train.
328329
329330 Returns
330331 -------
@@ -350,37 +351,12 @@ def __call__(
350351 # Ensure 2D y_true for consistent handling
351352 y_true = y_true .unsqueeze (- 1 ) if y_true .ndim == 1 else y_true
352353
353- # Handle distributions without mean/variance attributes
354- try :
355- y_pred_mean , y_pred_var = y_pred .mean , y_pred .variance
356- except Exception :
357- y_pred_samples = y_pred .rsample (torch .Size ([metric_params .n_samples ]))
358- y_pred_mean = y_pred_samples .mean (dim = 0 )
359- y_pred_var = y_pred_samples .var (dim = 0 )
360-
361- if y_pred_mean .shape != y_true .shape :
362- raise ValueError (
363- f"Predictions shape { y_pred_mean .shape } does not match "
364- f"y_true shape { y_true .shape } ."
365- )
366-
367354 # Compute mean log loss
368- mean_log_loss = (
369- 0.5 * torch .log (2 * torch .pi * y_pred_var )
370- + torch .square (y_true - y_pred_mean ) / (2 * y_pred_var )
371- ).mean (dim = 0 )
355+ mean_log_loss = y_pred .log_prob (y_true ).mean (dim = 0 )
372356
373357 # If no training data, return mean log loss
374358 if metric_params .y_train is None :
375- if metric_params .reduction == "none" :
376- return mean_log_loss
377- if metric_params .reduction == "mean" :
378- return mean_log_loss .mean ()
379- msg = (
380- f"Invalid reduction '{ metric_params .reduction } '. "
381- "Expected 'mean' or 'none'."
382- )
383- raise ValueError (msg )
359+ return mean_log_loss
384360
385361 # Ensure 2D y_train for consistent handling
386362 y_train = (
@@ -389,7 +365,8 @@ def __call__(
389365 else metric_params .y_train
390366 )
391367
392- y_train_mean = y_train .mean (dim = 0 )
368+ y_train_mean = y_train .mean ()
369+
393370 # following GPyTorch implementation, use global variance rather than per task
394371 # https://github.com/cornellius-gp/gpytorch/blob/c0fb6c64311fdbef2862fd3ba2bd613fbd081e79/gpytorch/metrics/metrics.py#L60
395372 y_train_var = y_train .var ()
@@ -398,20 +375,16 @@ def __call__(
398375 y_train_var = torch .clamp (y_train_var , min = 1e-6 )
399376
400377 # Compute mean log prob under trivial Gaussian model
401- mean_trivial_log_loss = 0.5 * (
402- torch .log (2 * torch .pi * y_train_var )
403- + torch .square (y_true - y_train_mean ) / (2 * y_train_var )
404- ).mean (dim = 0 )
378+ mean_trivial_log_loss = (
379+ 0.5
380+ * (
381+ torch .log (2 * torch .pi * y_train_var )
382+ + torch .square (y_true - y_train_mean ) / (2 * y_train_var )
383+ ).mean ()
384+ )
405385
406386 # Return mean standardized log loss
407- if metric_params .reduction == "none" :
408- return mean_log_loss - mean_trivial_log_loss
409- if metric_params .reduction == "mean" :
410- return (mean_log_loss - mean_trivial_log_loss ).mean ()
411- msg = (
412- f"Invalid reduction '{ metric_params .reduction } '. Expected 'mean' or 'none'."
413- )
414- raise ValueError (msg )
387+ return mean_log_loss - mean_trivial_log_loss
415388
416389
417390R2 = TorchMetrics (
0 commit comments