Skip to content

Commit ad172ae

Browse files
committed
fix sign issue
1 parent def50a4 commit ad172ae

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

autoemulate/core/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ def __call__(
351351
# Ensure 2D y_true for consistent handling
352352
y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
353353

354-
# Compute mean log loss
355-
mean_log_loss = y_pred.log_prob(y_true).mean()
354+
# Compute mean negative log likelihood
355+
mean_log_loss = -y_pred.log_prob(y_true).mean()
356356

357357
# If no training data, return mean log loss
358358
if metric_params.y_train is None:
@@ -368,7 +368,7 @@ def __call__(
368368
# Avoid numerical issues
369369
y_train_var = torch.clamp(y_train_var, min=1e-6)
370370

371-
# Compute mean log prob under trivial Gaussian model
371+
# Compute mean negative log likelihood under trivial Gaussian model
372372
mean_trivial_log_loss = (
373373
0.5
374374
* (

0 commit comments

Comments
 (0)