
Commit b7c5aa6

Merge pull request #913 from alan-turing-institute/882-custom-metrics
Enable custom metrics in AutoEmulate compare loop (#882)
2 parents da6eaf3 + 2b52df3 commit b7c5aa6

8 files changed: +222 additions, -81 deletions


autoemulate/core/compare.py

Lines changed: 4 additions & 6 deletions
@@ -15,7 +15,7 @@
 from autoemulate.core.logging_config import get_configured_logger
 from autoemulate.core.metrics import (
     R2,
-    TorchMetrics,
+    Metric,
     get_metric,
     get_metrics,
 )
@@ -74,8 +74,8 @@ def __init__(
         device: DeviceLike | None = None,
         random_seed: int | None = None,
         log_level: str = "progress_bar",
-        tuning_metric: str | TorchMetrics = "r2",
-        evaluation_metrics: list[str | TorchMetrics] | None = None,
+        tuning_metric: str | Metric = "r2",
+        evaluation_metrics: list[str | Metric] | None = None,
     ):
         """
         Initialize the AutoEmulate class.
@@ -542,9 +542,7 @@ def compare(self):
         # Get the best result and log the comparison
         # Use the first evaluation metric to determine the best result
         first_metric = self.evaluation_metrics[0]
-        best_result = self.best_result(
-            metric_name=first_metric.name,
-        )
+        best_result = self.best_result(first_metric)
         self.log_compare(
             best_model_name=best_result.model_name,
             x_transforms=best_result.x_transforms,
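With this change, AutoEmulate accepts either metric names or Metric objects for tuning_metric and evaluation_metrics, and compare() passes the first evaluation metric straight to best_result. The sketch below is a rough, hypothetical illustration of how "str | Metric" inputs might be normalised; the real conversion is presumably handled by get_metric/get_metrics from autoemulate.core.metrics, and SimpleMetric, METRIC_REGISTRY, and resolve are stand-ins rather than the library's API.

# Hypothetical sketch: normalising "str | Metric" inputs to Metric objects.
# SimpleMetric, METRIC_REGISTRY, and resolve() are stand-ins, not autoemulate's API.
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleMetric:
    name: str
    maximize: bool


METRIC_REGISTRY = {
    "r2": SimpleMetric("r2", maximize=True),
    "rmse": SimpleMetric("rmse", maximize=False),
}


def resolve(metric: "str | SimpleMetric") -> SimpleMetric:
    """Accept either a registered name or a Metric-like object."""
    if isinstance(metric, SimpleMetric):
        return metric
    return METRIC_REGISTRY[metric]


# Mixed string/object input, as the updated signature now allows.
evaluation_metrics = [resolve(m) for m in ["r2", SimpleMetric("mae", maximize=False)]]
print(evaluation_metrics[0])  # SimpleMetric(name='r2', maximize=True)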

autoemulate/core/metrics.py

Lines changed: 21 additions & 2 deletions
@@ -4,7 +4,7 @@

 from abc import abstractmethod
 from collections.abc import Sequence
-from functools import partial
+from functools import partial, total_ordering

 import torchmetrics
 from einops import rearrange
@@ -18,6 +18,7 @@
 )


+@total_ordering
 class Metric:
     """Configuration for a single metric.

@@ -33,9 +34,27 @@ class Metric:
     maximize: bool

     def __repr__(self) -> str:
-        """Return the string representation of the Metric."""
+        """Representation of the Metric."""
         return f"Metric(name={self.name}, maximize={self.maximize})"

+    def __str__(self):
+        """Metric when formatted as a string."""
+        return self.name
+
+    def __eq__(self, other: object) -> bool:
+        """Check equality based on metric name."""
+        if not isinstance(other, Metric):
+            return NotImplemented
+        return self.name == other.name
+
+    def __hash__(self) -> int:
+        """Return hash based on metric name."""
+        return hash(self.name)
+
+    def __lt__(self, other: Metric) -> bool:
+        """Compare metrics based on their str name."""
+        return self.name < other.name
+
     @abstractmethod
     def __call__(
         self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
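These added dunder methods are what let a Metric replace plain strings as a dictionary key and be sorted: equality and hashing are based on name, and @total_ordering derives the remaining comparisons from __lt__. Below is a minimal, self-contained stand-in that demonstrates the behaviour; DemoMetric is not the real Metric base class (which is abstract and takes prediction tensors in __call__).

# Self-contained sketch of the name-based ordering/equality added to Metric.
# DemoMetric is a stand-in; the real class lives in autoemulate.core.metrics.
from functools import total_ordering


@total_ordering
class DemoMetric:
    def __init__(self, name: str, maximize: bool):
        self.name = name
        self.maximize = maximize

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DemoMetric):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __lt__(self, other: "DemoMetric") -> bool:
        return self.name < other.name


r2, rmse = DemoMetric("r2", True), DemoMetric("rmse", False)

# Usable as dict keys, and deduplicated in sets by name.
scores = {r2: (0.93, 0.01), rmse: (0.12, 0.02)}
assert DemoMetric("r2", True) in scores
assert len({r2, DemoMetric("r2", True)}) == 1

# Sortable, so metric columns can be ordered deterministically.
print([str(m) for m in sorted(scores)])  # ['r2', 'rmse']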

autoemulate/core/model_selection.py

Lines changed: 2 additions & 2 deletions
@@ -159,7 +159,7 @@ def bootstrap(
     n_samples: int = 1000,
     device: str | torch.device = "cpu",
     metrics: list[Metric] | None = None,
-) -> dict[str, tuple[float, float]]:
+) -> dict[Metric, tuple[float, float]]:
     """
     Get bootstrap estimates of metrics.

@@ -228,7 +228,7 @@ def bootstrap(

     # Return mean and std for each metric
     return {
-        metric.name: (
+        metric: (
             metric_scores[metric.name].mean().item(),
             metric_scores[metric.name].std().item(),
         )
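With bootstrap now returning dict[Metric, tuple[float, float]], callers that used to look results up by name can look them up by an equal Metric (equality is name-based in the diff above) or print them directly, since __str__ yields the bare name. The sketch below shows consuming such a dictionary; Stat is a stand-in for a Metric and the scores are made up.

# Sketch of consuming the new dict[Metric, (mean, std)] return shape of bootstrap().
# Stat mimics a Metric: it hashes/compares by name only, and str() gives the name.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Stat:
    name: str
    maximize: bool = field(default=True, compare=False)

    def __str__(self) -> str:
        return self.name


bootstrap_scores = {Stat("r2"): (0.94, 0.012), Stat("rmse", False): (0.11, 0.02)}

# Keys are Metric-like objects, but name-based equality keeps string-style
# lookups one construction away:
mean, std = bootstrap_scores[Stat("r2")]
print(f"r2: {mean:.3f} ± {std:.3f}")

# And str(metric) reproduces the old name-based labels:
for metric, (m, s) in bootstrap_scores.items():
    print(f"{metric}: {m:.3f} ± {s:.3f}")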

autoemulate/core/results.py

Lines changed: 31 additions & 38 deletions
@@ -2,7 +2,7 @@

 import pandas as pd

-from autoemulate.core.metrics import AVAILABLE_METRICS
+from autoemulate.core.metrics import Metric, get_metric
 from autoemulate.core.types import ModelParams
 from autoemulate.emulators.transformed.base import TransformedEmulator

@@ -18,8 +18,8 @@ def __init__(
         model_name: str,
         model: TransformedEmulator,
         params: ModelParams,
-        test_metrics: dict[str, tuple[float, float]],
-        train_metrics: dict[str, tuple[float, float]],
+        test_metrics: dict[Metric, tuple[float, float]],
+        train_metrics: dict[Metric, tuple[float, float]],
     ):
         """Initialize a Result object.

@@ -141,32 +141,32 @@ def summarize(self) -> pd.DataFrame:
             "params": [result.params for result in self.results],
         }

-        # Collect all unique metric names from all results
+        # Collect all unique metrics from all results
         all_test_metrics = set()
         all_train_metrics = set()
         for result in self.results:
             all_test_metrics.update(result.test_metrics.keys())
             all_train_metrics.update(result.train_metrics.keys())

         # Add test metrics columns
-        for metric_name in sorted(all_test_metrics):
-            data[f"{metric_name}_test"] = [
-                result.test_metrics.get(metric_name, (float("nan"), float("nan")))[0]
+        for metric in sorted(all_test_metrics):
+            data[f"{metric}_test"] = [
+                result.test_metrics.get(metric, (float("nan"), float("nan")))[0]
                 for result in self.results
             ]
-            data[f"{metric_name}_test_std"] = [
-                result.test_metrics.get(metric_name, (float("nan"), float("nan")))[1]
+            data[f"{metric}_test_std"] = [
+                result.test_metrics.get(metric, (float("nan"), float("nan")))[1]
                 for result in self.results
             ]

         # Add train metrics columns
-        for metric_name in sorted(all_train_metrics):
-            data[f"{metric_name}_train"] = [
-                result.train_metrics.get(metric_name, (float("nan"), float("nan")))[0]
+        for metric in sorted(all_train_metrics):
+            data[f"{metric}_train"] = [
+                result.train_metrics.get(metric, (float("nan"), float("nan")))[0]
                 for result in self.results
             ]
-            data[f"{metric_name}_train_std"] = [
-                result.train_metrics.get(metric_name, (float("nan"), float("nan")))[1]
+            data[f"{metric}_train_std"] = [
+                result.train_metrics.get(metric, (float("nan"), float("nan")))[1]
                 for result in self.results
             ]

@@ -177,13 +177,13 @@ def summarize(self) -> pd.DataFrame:

     summarise = summarize

-    def best_result(self, metric_name: str | None = None) -> Result:
+    def best_result(self, metric: str | Metric | None = None) -> Result:
        """
        Get the model with the best result based on the given metric.

        Parameters
        ----------
-        metric_name: str | None
+        metric: str | Metric | None
            The name of the metric to use for comparison. If None, uses the first
            available metric found in the results. The metric should exist in the
            test_metrics of the results.
@@ -202,51 +202,44 @@ def best_result(self, metric_name: str | None = None) -> Result:
            raise ValueError(msg)

        # If metric_name is None, use the first available metric
-        if metric_name is None:
+        if metric is None:
            # Collect all available metrics
-            available_metrics = set()
-            for result in self.results:
-                available_metrics.update(result.test_metrics.keys())
+            available_metrics = [
+                m for result in self.results for m in result.test_metrics
+            ]

            if not available_metrics:
                msg = "No metrics available in results."
                raise ValueError(msg)

            # Use the first metric
-            metric_name = next(iter(available_metrics))
-            logger.info("Using metric '%s' to determine best result.", metric_name)
+            metric_selected = available_metrics[0]
+            logger.info("Using metric '%s' to determine best result.", metric_selected)
        else:
            # Check if the specified metric exists in at least one result
-            if not any(metric_name in result.test_metrics for result in self.results):
+            if not any(metric in result.test_metrics for result in self.results):
                available_metrics = set()
                for result in self.results:
                    available_metrics.update(result.test_metrics.keys())
                msg = (
-                    f"Metric '{metric_name}' not found in any results. "
+                    f"Metric '{metric}' not found in any results. "
                    f"Available metrics: {sorted(available_metrics)}"
                )
                raise ValueError(msg)
-
-            logger.info("Using metric '%s' to determine best result.", metric_name)
-
-            # Determine if we are maximizing or minimizing the metric
-            # from the metric name
-            assert metric_name is not None  # for pyright
-            metric_config = AVAILABLE_METRICS.get(metric_name)
-            if metric_config is None:
-                msg = f"Metric '{metric_name}' not found in AVAILABLE_METRICS."
-                raise ValueError(msg)
-            metric_maximize = metric_config.maximize
+            metric_selected = get_metric(metric)
+            logger.info("Using metric '%s' to determine best result.", metric_selected)

        # Select best result based on whether we're maximizing or minimizing
-        if metric_maximize:
+        if metric_selected.maximize:
            return max(
                self.results,
-                key=lambda r: r.test_metrics.get(metric_name, (float("-inf"), 0))[0],
+                key=lambda r: r.test_metrics.get(metric_selected, (float("-inf"), 0))[
+                    0
+                ],
            )
        return min(
            self.results,
-            key=lambda r: r.test_metrics.get(metric_name, (float("inf"), 0))[0],
+            key=lambda r: r.test_metrics.get(metric_selected, (float("inf"), 0))[0],
        )

    def get_result(self, result_id: int) -> Result:
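best_result now accepts str | Metric | None, resolves either form through get_metric, and reads the maximize flag off the resolved metric instead of consulting AVAILABLE_METRICS, so results.best_result("rmse") and results.best_result(some_rmse_metric) should presumably both work. Because Metric.__str__ returns the bare name, the f"{metric}_test" column names produced by summarize() come out unchanged. A small stand-in sketch of that column-naming step (FakeMetric is hypothetical, not the real class):

# Sketch: summarize() column names are unchanged because str(metric) == metric.name.
class FakeMetric:
    def __init__(self, name: str):
        self.name = name

    def __str__(self) -> str:
        return self.name

    def __lt__(self, other: "FakeMetric") -> bool:
        return self.name < other.name


all_test_metrics = {FakeMetric("rmse"), FakeMetric("r2")}
columns = [f"{metric}_test" for metric in sorted(all_test_metrics)]
print(columns)  # ['r2_test', 'rmse_test']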

autoemulate/core/save.py

Lines changed: 7 additions & 2 deletions
@@ -4,6 +4,7 @@
 import joblib
 import pandas as pd

+from autoemulate.core.metrics import get_metric
 from autoemulate.core.results import Result  # , Results
 from autoemulate.emulators.base import Emulator

@@ -150,13 +151,17 @@ def _load_result(self, path: str | Path) -> Result | Emulator:
                 metric_name = col[:-5]  # Remove "_test" suffix
                 mean = row[col]
                 std = row.get(f"{metric_name}_test_std", float("nan"))
-                test_metrics[metric_name] = (mean, std)
+                # Convert metric name string back to Metric object
+                metric = get_metric(metric_name)
+                test_metrics[metric] = (mean, std)
             elif col.endswith("_train") and not col.endswith("_train_std"):
                 # Extract metric name (e.g., "r2" from "r2_train")
                 metric_name = col[:-6]  # Remove "_train" suffix
                 mean = row[col]
                 std = row.get(f"{metric_name}_train_std", float("nan"))
-                train_metrics[metric_name] = (mean, std)
+                # Convert metric name string back to Metric object
+                metric = get_metric(metric_name)
+                train_metrics[metric] = (mean, std)

         return Result(
             id=row["id"],
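Loading reverses the naming convention used by summarize(): a column such as "r2_test" is stripped of its suffix and the remaining name is converted back into a Metric via get_metric, so the deserialised Result is keyed by metric objects again. The sketch below isolates that suffix-parsing round trip; the row values are made up, and the lookup dict stands in for get_metric (here plain tuples play the role of Metric keys).

# Sketch of the suffix parsing used when loading saved results.
# lookup mimics get_metric(name); tuples stand in for Metric objects.
row = {"r2_test": 0.93, "r2_test_std": 0.01, "rmse_train": 0.12, "rmse_train_std": 0.02}
lookup = {"r2": ("r2", True), "rmse": ("rmse", False)}  # name -> (name, maximize)

test_metrics, train_metrics = {}, {}
for col, mean in row.items():
    if col.endswith("_test") and not col.endswith("_test_std"):
        name = col[:-5]  # strip "_test"
        std = row.get(f"{name}_test_std", float("nan"))
        test_metrics[lookup[name]] = (mean, std)
    elif col.endswith("_train") and not col.endswith("_train_std"):
        name = col[:-6]  # strip "_train"
        std = row.get(f"{name}_train_std", float("nan"))
        train_metrics[lookup[name]] = (mean, std)

print(test_metrics)   # {('r2', True): (0.93, 0.01)}
print(train_metrics)  # {('rmse', False): (0.12, 0.02)}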
