Skip to content

Commit def50a4

Browse files
committed
fix shape issue
1 parent 0abb6f2 commit def50a4

File tree

2 files changed

+12
-28
lines changed

2 files changed

+12
-28
lines changed

autoemulate/core/metrics.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def __call__(
325325
y_true: TensorLike
326326
True target values.
327327
metric_params: MetricParams
328-
Metric parameters including: y_train.
328+
Metric parameters including: y_train and reduction.
329329
330330
Returns
331331
-------
@@ -352,24 +352,18 @@ def __call__(
352352
y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
353353

354354
# Compute mean log loss
355-
mean_log_loss = y_pred.log_prob(y_true).mean(dim=0)
355+
mean_log_loss = y_pred.log_prob(y_true).mean()
356356

357357
# If no training data, return mean log loss
358358
if metric_params.y_train is None:
359359
return mean_log_loss
360360

361361
# Ensure 2D y_train for consistent handling
362-
y_train = (
363-
metric_params.y_train.unsqueeze(-1)
364-
if metric_params.y_train.ndim == 1
365-
else metric_params.y_train
366-
)
367-
368-
y_train_mean = y_train.mean()
362+
y_train_mean = metric_params.y_train.mean(dim=0, keepdim=True).view(1, -1)
369363

370364
# following GPyTorch implementation, use global variance rather than per task
371365
# https://github.com/cornellius-gp/gpytorch/blob/c0fb6c64311fdbef2862fd3ba2bd613fbd081e79/gpytorch/metrics/metrics.py#L60
372-
y_train_var = y_train.var()
366+
y_train_var = metric_params.y_train.var()
373367

374368
# Avoid numerical issues
375369
y_train_var = torch.clamp(y_train_var, min=1e-6)

tests/core/test_metrics.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -557,14 +557,9 @@ def test_msll_with_1d_inputs():
557557
)
558558

559559
msll = MSLL(
560-
y_pred, y_true, metric_params=MetricParams(y_train=y_train, reduction="none")
561-
)
562-
563-
assert isinstance(msll, torch.Tensor)
564-
assert msll.shape == torch.Size([1])
565-
566-
msll = MSLL(
567-
y_pred, y_true, metric_params=MetricParams(y_train=y_train, reduction="mean")
560+
y_pred,
561+
y_true,
562+
metric_params=MetricParams(y_train=y_train),
568563
)
569564

570565
assert isinstance(msll, torch.Tensor)
@@ -579,14 +574,9 @@ def test_msll_with_2d_inputs():
579574
y_pred = Normal(loc=y_true, scale=torch.ones_like(y_true))
580575

581576
msll = MSLL(
582-
y_pred, y_true, metric_params=MetricParams(y_train=y_train, reduction="none")
583-
)
584-
585-
assert isinstance(msll, torch.Tensor)
586-
assert msll.shape == torch.Size([n_outputs])
587-
588-
msll = MSLL(
589-
y_pred, y_true, metric_params=MetricParams(y_train=y_train, reduction="mean")
577+
y_pred,
578+
y_true,
579+
metric_params=MetricParams(y_train=y_train),
590580
)
591581

592582
assert isinstance(msll, torch.Tensor)
@@ -607,7 +597,7 @@ def test_msll_raises_for_non_distribution():
607597
def test_msll_perfect_prediction_with_training_data():
608598
"""Test MSLL when predictions perfectly match true values."""
609599
y_true = torch.tensor([1.0, 2.0, 3.0]).view(-1, 1)
610-
y_train = torch.tensor([0.5, 1.5, 2.5]).view(-1, 1)
600+
y_train = torch.tensor([0.5, 1.5, 2.5])
611601

612602
# Perfect prediction: distribution centered at true values with small variance
613603
y_pred = Normal(loc=y_true, scale=torch.ones_like(y_true) * 0.01)
@@ -621,7 +611,7 @@ def test_msll_perfect_prediction_with_training_data():
621611
def test_msll_poor_prediction_with_training_data():
622612
"""Test MSLL when predictions are far from true values."""
623613
y_true = torch.tensor([1.0, 2.0, 3.0]).view(-1, 1)
624-
y_train = torch.tensor([1.0, 2.0, 3.0]).view(-1, 1)
614+
y_train = torch.tensor([1.0, 2.0, 3.0])
625615

626616
# Poor prediction: distribution centered far from true values
627617
y_pred = Normal(

0 commit comments

Comments
 (0)