1111import torch
1212import torchmetrics
1313from einops import rearrange
14+ from torch .distributions import Independent
1415from torchmetrics .regression .crps import ContinuousRankedProbabilityScore
1516
1617from autoemulate .core .types import (
@@ -351,12 +352,32 @@ def __call__(
351352 # Ensure 2D y_true for consistent handling
352353 y_true = y_true .unsqueeze (- 1 ) if y_true .ndim == 1 else y_true
353354
354- # Compute mean negative log likelihood
355- mean_log_loss = - y_pred .log_prob (y_true ).mean ()
355+ # Compute mean negative log likelihood (also by output dimension if have
356+ # Independent distribution)
357+ if isinstance (y_pred , Independent ):
358+ model_nll_output = - y_pred .base_dist .log_prob (y_true ).mean (dim = 0 )
359+ model_nll_total = model_nll_output .mean ()
360+ else :
361+ model_nll_output = None
362+ model_nll_total = - y_pred .log_prob (y_true ).mean ()
356363
357364 # If no training data, return mean log loss
358365 if metric_params .y_train is None :
359- return mean_log_loss
366+ if metric_params .reduction == "mean" :
367+ return model_nll_total
368+ if metric_params .reduction == "none" :
369+ if model_nll_output is None :
370+ msg = (
371+ "Per-output MLL not available for non-Independent "
372+ "distributions."
373+ )
374+ raise ValueError (msg )
375+ return model_nll_output
376+ msg = (
377+ f"Unknown reduction method: { metric_params .reduction } . "
378+ "Expected 'mean' or 'none'."
379+ )
380+ raise ValueError (msg )
360381
361382 # Ensure 2D y_train for consistent handling
362383 y_train_mean = metric_params .y_train .mean (dim = 0 , keepdim = True ).view (1 , - 1 )
@@ -369,16 +390,24 @@ def __call__(
369390 y_train_var = torch .clamp (y_train_var , min = 1e-6 )
370391
371392 # Compute mean negative log likelihood under trivial Gaussian model
372- mean_trivial_log_loss = (
373- 0.5
374- * (
375- torch .log (2 * torch .pi * y_train_var )
376- + torch .square (y_true - y_train_mean ) / (2 * y_train_var )
377- ).mean ()
378- )
393+ trivial_nll_output = 0.5 * (
394+ torch .log (2 * torch .pi * y_train_var )
395+ + torch .square (y_true - y_train_mean ) / (2 * y_train_var )
396+ ).mean (dim = 0 )
379397
380398 # Return mean standardized log loss
381- return mean_log_loss - mean_trivial_log_loss
399+ if metric_params .reduction == "mean" :
400+ return model_nll_total - trivial_nll_output .mean ()
401+ if metric_params .reduction == "none" :
402+ if model_nll_output is None :
403+ msg = "Per-output MLL not available for non-Independent distributions."
404+ raise ValueError (msg )
405+ return model_nll_output - trivial_nll_output
406+ msg = (
407+ f"Unknown reduction method: { metric_params .reduction } . "
408+ "Expected 'mean' or 'none'."
409+ )
410+ raise ValueError (msg )
382411
383412
384413R2 = TorchMetrics (
0 commit comments