Commit 067c3b4

Merge pull request #883 from alan-turing-institute/517-add-option-to-choose-metrics
adding general torchmetrics support
2 parents 0f25b70 + 5f71801 commit 067c3b4

File tree: 12 files changed (+1142, -227 lines)

autoemulate/core/compare.py

Lines changed: 66 additions & 39 deletions
@@ -12,7 +12,12 @@
 
 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.core.logging_config import get_configured_logger
-from autoemulate.core.model_selection import bootstrap, evaluate, r2_metric
+from autoemulate.core.metrics import (
+    TorchMetrics,
+    get_metric_config,
+    get_metric_configs,
+)
+from autoemulate.core.model_selection import bootstrap, evaluate
 from autoemulate.core.plotting import (
     calculate_subplot_layout,
     create_and_plot_slice,
@@ -72,6 +77,8 @@ def __init__(
         device: DeviceLike | None = None,
         random_seed: int | None = None,
         log_level: str = "progress_bar",
+        tuning_metric: str | TorchMetrics = "r2",
+        evaluation_metrics: list[str | TorchMetrics] | None = None,
     ):
         """
         Initialize the AutoEmulate class.
@@ -122,13 +129,27 @@ def __init__(
             it will show a progress bar during model comparison. It will set the
             logging level to "error" to avoid cluttering the output
             with debug/info logs.
+        tuning_metric: str | TorchMetrics
+            Metric to use for hyperparameter tuning. Can be a string shortcut
+            ("r2", "rmse", "mse", "mae") or a MetricConfig object. Defaults to "r2".
+        evaluation_metrics: list[str | TorchMetrics] | None
+            Metrics to compute during evaluation.
+            If None, then defaults to ["r2", "rmse"].
+            Each entry can be a string shortcut or a MetricConfig object.
+            IMPORTANT: The first metric in the list is used to
+            determine the best model.
         """
         Results.__init__(self)
         self.random_seed = random_seed
         TorchDeviceMixin.__init__(self, device=device)
         x, y = self._convert_to_tensors(x, y)
         x, y = self._move_tensors_to_device(x, y)
 
+        # Setup metrics. If evaluation_metrics is None, default to ["r2", "rmse"]
+        evaluation_metrics = evaluation_metrics or ["r2", "rmse"]
+        self.evaluation_metrics = get_metric_configs(evaluation_metrics)
+        self.tuning_metric = get_metric_config(tuning_metric)
+
         # Transforms to search over
         self.x_transforms_list = [
             self.get_transforms(transforms)
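As a quick orientation for reviewers, the new constructor arguments can be exercised roughly as follows. This is a minimal sketch, not part of the commit: it assumes AutoEmulate is importable from autoemulate.core.compare, that the remaining constructor arguments keep their defaults, and that the tensors are random placeholder data.

import torch

from autoemulate.core.compare import AutoEmulate

# Placeholder simulator data: 100 samples, 2 inputs, 1 output (illustrative only).
x = torch.rand(100, 2)
y = torch.rand(100, 1)

ae = AutoEmulate(
    x,
    y,
    tuning_metric="rmse",               # minimised during hyperparameter tuning
    evaluation_metrics=["r2", "rmse"],  # the first entry determines the best model
)
ae.compare()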
@@ -323,18 +344,19 @@ def log_compare(
         x_transforms,
         y_transforms,
         best_params_for_this_model,
-        r2_score,
-        rmse_score,
+        test_metrics,
     ):
         """Log the comparison results."""
+        metrics_str = ", ".join(
+            f"{metric}: {mean:.3f}" for metric, (mean, _std) in test_metrics.items()
+        )
         msg = (
             "Comparison results:\n"
             f"Best Model: {best_model_name}, "
             f"x transforms: {x_transforms}, "
-            f"y transforms: {y_transforms}",
+            f"y transforms: {y_transforms}, "
             f"Best params: {best_params_for_this_model}, "
-            f"R2 score: {r2_score:.3f}, "
-            f"RMSE score: {rmse_score:.3f}",
+            f"Metrics: {metrics_str}"
         )
         self.logger.debug(msg)
 
@@ -351,7 +373,13 @@ def compare(self):
         - Log the results.
         - Save the best model and its parameters.
         """
-        tuner = Tuner(self.train_val, y=None, n_iter=self.n_iter, device=self.device)
+        tuner = Tuner(
+            self.train_val,
+            y=None,
+            n_iter=self.n_iter,
+            device=self.device,
+            tuning_metric=self.tuning_metric,
+        )
         self.logger.info(
             "Comparing %s", [model_cls.__name__ for model_cls in self.models]
         )
@@ -393,7 +421,11 @@ def compare(self):
            mean_scores = [
                np.mean(score).item() for score in scores
            ]
-           best_score_idx = np.argmax(mean_scores)
+           # Select best whether we're maximizing or minimizing
+           if self.tuning_metric.maximize:
+               best_score_idx = np.argmax(mean_scores)
+           else:
+               best_score_idx = np.argmin(mean_scores)
            best_params_for_this_model = params_list[best_score_idx]
            self.logger.debug(
                'Tuner found best params for model "%s": '
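The direction-aware selection above can be tried in isolation. A small sketch with made-up candidate scores, using the maximize flag on the metric configs defined in autoemulate/core/metrics.py below; the scores are purely illustrative:

import numpy as np

from autoemulate.core.metrics import R2, RMSE

mean_scores = [0.71, 0.85, 0.79]  # hypothetical mean score per hyperparameter candidate

for metric in (R2, RMSE):
    # argmax when higher is better (r2), argmin when lower is better (rmse)
    best_idx = np.argmax(mean_scores) if metric.maximize else np.argmin(mean_scores)
    print(f"{metric.name}: best candidate index {int(best_idx)}")
# r2: best candidate index 1
# rmse: best candidate index 0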
@@ -445,35 +477,33 @@ def compare(self):
                # This can fail for some model params
                transformed_emulator.fit(train_val_x, train_val_y)
 
-               (
-                   (r2_train_val, r2_train_val_std),
-                   (rmse_train_val, rmse_train_val_std),
-               ) = bootstrap(
+               train_metrics = bootstrap(
                    transformed_emulator,
                    train_val_x,
                    train_val_y,
                    n_bootstraps=self.n_bootstraps,
                    device=self.device,
+                   metrics=self.evaluation_metrics,
                )
-               (r2_test, r2_test_std), (rmse_test, rmse_test_std) = (
-                   bootstrap(
-                       transformed_emulator,
-                       test_x,
-                       test_y,
-                       n_bootstraps=self.n_bootstraps,
-                       device=self.device,
-                   )
+               test_metrics = bootstrap(
+                   transformed_emulator,
+                   test_x,
+                   test_y,
+                   n_bootstraps=self.n_bootstraps,
+                   device=self.device,
+                   metrics=self.evaluation_metrics,
                )
 
+               # Log all test metrics from test_metrics dictionary
+               test_metrics_str = ", ".join(
+                   f"{metric}: {mean:.3f} (std: {std:.3f})"
+                   for metric, (mean, std) in test_metrics.items()
+               )
                self.logger.debug(
-                   'Cross-validation for model "%s"'
-                   " completed with test mean (std) R2 score: %.3f (%.3f),"
-                   " mean (std) RMSE score: %.3f (%.3f)",
+                   'Cross-validation for model "%s" '
+                   "completed with test metrics: %s",
                    model_cls.__name__,
-                   r2_test,
-                   r2_test_std,
-                   rmse_test,
-                   rmse_test_std,
+                   test_metrics_str,
                )
                self.logger.info(
                    "Finished running Model: %s\n", model_cls.__name__
@@ -483,14 +513,8 @@ def compare(self):
                    model_name=transformed_emulator.untransformed_model_name,
                    model=transformed_emulator,
                    params=best_params_for_this_model,
-                   r2_test=r2_test,
-                   rmse_test=rmse_test,
-                   r2_test_std=r2_test_std,
-                   rmse_test_std=rmse_test_std,
-                   r2_train=r2_train_val,
-                   rmse_train=rmse_train_val,
-                   r2_train_std=r2_train_val_std,
-                   rmse_train_std=rmse_train_val_std,
+                   test_metrics=test_metrics,
+                   train_metrics=train_metrics,
                )
                self.add_result(result)
                # if successful, break out of the retry loop
@@ -511,14 +535,17 @@ def compare(self):
         )
 
         # Get the best result and log the comparison
-        best_result = self.best_result()
+        # Use the first evaluation metric to determine the best result
+        first_metric = self.evaluation_metrics[0]
+        best_result = self.best_result(
+            metric_name=first_metric.name,
+        )
         self.log_compare(
             best_model_name=best_result.model_name,
             x_transforms=best_result.x_transforms,
             y_transforms=best_result.y_transforms,
             best_params_for_this_model=best_result.params,
-            r2_score=best_result.r2_test,
-            rmse_score=best_result.rmse_test,
+            test_metrics=best_result.test_metrics,
         )
 
     def fit_from_reinitialized(
@@ -642,7 +669,7 @@ def plot( # noqa: PLR0912, PLR0915
 
         # Re-run prediction with just this model to get the predictions
         y_pred, y_variance = model.predict_mean_and_variance(test_x)
-        r2_score = evaluate(y_pred, test_y, r2_metric())
+        r2_score = evaluate(y_pred, test_y)
 
         # Handle ranges
         input_ranges = input_ranges or {}

autoemulate/core/metrics.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+"""Metrics configuration and utilities for model evaluation and tuning."""
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from collections.abc import Sequence
+from functools import partial
+
+import torchmetrics
+
+from autoemulate.core.types import OutputLike, TensorLike, TorchMetricsLike
+
+
+class Metric:
+    """Configuration for a single metric.
+
+    Parameters
+    ----------
+    name : str
+        Display name for the metric.
+    maximize : bool
+        Whether higher values are better. Defaults to True.
+    """
+
+    name: str
+    maximize: bool
+
+    def __repr__(self) -> str:
+        """Return the string representation of the MetricConfig."""
+        return f"MetricConfig(name={self.name}, maximize={self.maximize})"
+
+    @abstractmethod
+    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
+        """Calculate metric."""
+
+
+class TorchMetrics(Metric):
+    """Configuration for a single torchmetrics metric.
+
+    Parameters
+    ----------
+    metric : MetricLike
+        The torchmetrics metric class or partial.
+    name : str
+        Display name for the metric. If None, uses the class name of the metric.
+    maximize : bool
+        Whether higher values are better.
+    """
+
+    def __init__(
+        self,
+        metric: TorchMetricsLike,
+        name: str,
+        maximize: bool,
+    ):
+        self.metric = metric
+        self.name = name
+        self.maximize = maximize
+
+    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
+        """Calculate metric."""
+        if not isinstance(y_pred, TensorLike):
+            raise ValueError(f"Metric not implemented for y_pred ({type(y_pred)})")
+        if not isinstance(y_true, TensorLike):
+            raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")
+
+        metric = self.metric()
+        metric.to(y_pred.device)
+        # Assume first dim is a batch dim, flatten others for metric calculation
+        metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
+        return metric.compute()
+
+
+R2 = TorchMetrics(
+    metric=torchmetrics.R2Score,
+    name="r2",
+    maximize=True,
+)
+
+RMSE = TorchMetrics(
+    metric=partial(torchmetrics.MeanSquaredError, squared=False),
+    name="rmse",
+    maximize=False,
+)
+
+MSE = TorchMetrics(
+    metric=torchmetrics.MeanSquaredError,
+    name="mse",
+    maximize=False,
+)
+
+MAE = TorchMetrics(
+    metric=torchmetrics.MeanAbsoluteError,
+    name="mae",
+    maximize=False,
+)
+
+AVAILABLE_METRICS = {
+    "r2": R2,
+    "rmse": RMSE,
+    "mse": MSE,
+    "mae": MAE,
+}
+
+
+def get_metric_config(
+    metric: str | TorchMetrics,
+) -> TorchMetrics:
+    """Convert various metric specifications to MetricConfig.
+
+    Parameters
+    ----------
+    metric : str | type[torchmetrics.Metric] | partial[torchmetrics.Metric] | Metric
+        The metric specification. Can be:
+        - A string shortcut like "r2", "rmse", "mse", "mae"
+        - A Metric instance (returned as-is)
+
+    Returns
+    -------
+    TorchMetrics
+        The metric configuration.
+
+    Raises
+    ------
+    ValueError
+        If the metric specification is invalid or name is not provided when required.
+
+
+    """
+    # If already a TorchMetric, return as-is
+    if isinstance(metric, TorchMetrics):
+        return metric
+
+    if isinstance(metric, str):
+        if metric.lower() in AVAILABLE_METRICS:
+            return AVAILABLE_METRICS[metric.lower()]
+        raise ValueError(
+            f"Unknown metric shortcut '{metric}'. "
+            f"Available options: {list(AVAILABLE_METRICS.keys())}"
+        )
+    # Handle unsupported types
+    raise ValueError(
+        f"Unsupported metric type: {type(metric).__name__}. "
+        "Metric must be a string shortcut or a MetricConfig instance."
+    )
+
+
+def get_metric_configs(
+    metrics: Sequence[str | TorchMetrics],
+) -> list[TorchMetrics]:
+    """Convert a list of metric specifications to MetricConfig objects.
+
+    Parameters
+    ----------
+    metrics : Sequence[str | TorchMetrics]
+        Sequence of metric specifications.
+
+    Returns
+    -------
+    list[TorchMetrics]
+        List of metric configurations.
+    """
+    result_metrics = []
+
+    for m in metrics:
+        config = get_metric_config(m) if isinstance(m, (str | TorchMetrics)) else m
+        result_metrics.append(config)
+
+    return result_metrics
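To show how the pieces of the new module fit together, here is a short sketch of resolving the string shortcuts and wrapping an additional torchmetrics metric. It is illustrative only; the MAPE config below is an example, not something this commit registers in AVAILABLE_METRICS:

import torch
import torchmetrics

from autoemulate.core.metrics import TorchMetrics, get_metric_config, get_metric_configs

# String shortcuts resolve to the predefined configs.
r2 = get_metric_config("r2")                  # R2, maximize=True
configs = get_metric_configs(["r2", "rmse"])  # [R2, RMSE]

# Any torchmetrics metric class (or a partial fixing constructor kwargs, as the
# built-in RMSE config does) can be wrapped in a TorchMetrics config.
mape = TorchMetrics(
    metric=torchmetrics.MeanAbsolutePercentageError,
    name="mape",
    maximize=False,
)

# Metric configs are callables: (y_pred, y_true) -> scalar tensor.
y_pred = torch.rand(32, 1)
y_true = torch.rand(32, 1)
print(r2(y_pred, y_true), mape(y_pred, y_true))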
