Skip to content

Commit 20093a1

Browse files
authored
Merge pull request #912 from alan-turing-institute/911-update-get-metric-config
Update API for getting metrics (#911)
2 parents bf210f5 + 66d0677 commit 20093a1

File tree

5 files changed

+59
-70
lines changed

5 files changed

+59
-70
lines changed

autoemulate/core/compare.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from autoemulate.core.metrics import (
1717
R2,
1818
TorchMetrics,
19-
get_metric_config,
20-
get_metric_configs,
19+
get_metric,
20+
get_metrics,
2121
)
2222
from autoemulate.core.model_selection import bootstrap, evaluate
2323
from autoemulate.core.plotting import (
@@ -144,8 +144,8 @@ def __init__(
144144

145145
# Setup metrics. If evaluation_metrics is None, default to ["r2", "rmse"]
146146
evaluation_metrics = evaluation_metrics or ["r2", "rmse"]
147-
self.evaluation_metrics = get_metric_configs(evaluation_metrics)
148-
self.tuning_metric = get_metric_config(tuning_metric)
147+
self.evaluation_metrics = get_metrics(evaluation_metrics)
148+
self.tuning_metric = get_metric(tuning_metric)
149149

150150
# Transforms to search over
151151
self.x_transforms_list = [

autoemulate/core/metrics.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ class TorchMetrics(Metric):
4848
4949
Parameters
5050
----------
51-
metric : MetricLike
51+
metric: MetricLike
5252
The torchmetrics metric class or partial.
53-
name : str
53+
name: str
5454
Display name for the metric. If None, uses the class name of the metric.
55-
maximize : bool
55+
maximize: bool
5656
Whether higher values are better.
5757
"""
5858

@@ -83,8 +83,12 @@ def __call__(
8383
y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
8484
metric = self.metric()
8585
metric.to(y_pred.device)
86-
# Assume first dim is a batch dim, flatten others for metric calculation
87-
metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
86+
87+
# Assume first dim is a batch dim if >=2D, flatten others for metric calculation
88+
metric.update(
89+
y_pred.flatten(start_dim=1) if y_pred.ndim > 1 else y_pred,
90+
y_true.flatten(start_dim=1) if y_true.ndim > 1 else y_true,
91+
)
8892
return metric.compute()
8993

9094

@@ -244,32 +248,30 @@ def __call__(
244248
}
245249

246250

247-
def get_metric_config(
248-
metric: str | TorchMetrics,
249-
) -> TorchMetrics:
250-
"""Convert various metric specifications to MetricConfig.
251+
def get_metric(metric: str | Metric) -> Metric:
252+
"""Convert metric specification to a `Metric`.
251253
252254
Parameters
253255
----------
254-
metric : str | type[torchmetrics.Metric] | partial[torchmetrics.Metric] | Metric
256+
metric: str | Metric
255257
The metric specification. Can be:
256258
- A string shortcut like "r2", "rmse", "mse", "mae"
257259
- A Metric instance (returned as-is)
258260
259261
Returns
260262
-------
261-
TorchMetrics
262-
The metric configuration.
263+
Metric
264+
The metric.
263265
264266
Raises
265267
------
266268
ValueError
267-
If the metric specification is invalid or name is not provided when required.
268-
269+
If the metric specification is not a string (and registered in
270+
AVAILABLE_METRICS) or Metric instance.
269271
270272
"""
271-
# If already a TorchMetric, return as-is
272-
if isinstance(metric, TorchMetrics):
273+
# If already a Metric, return as-is
274+
if isinstance(metric, Metric):
273275
return metric
274276

275277
if isinstance(metric, str):
@@ -286,25 +288,17 @@ def get_metric_config(
286288
)
287289

288290

289-
def get_metric_configs(
290-
metrics: Sequence[str | TorchMetrics],
291-
) -> list[TorchMetrics]:
292-
"""Convert a list of metric specifications to MetricConfig objects.
291+
def get_metrics(metrics: Sequence[str | Metric]) -> list[Metric]:
292+
"""Convert a list of metric specifications to list of `Metric`s.
293293
294294
Parameters
295295
----------
296-
metrics : Sequence[str | TorchMetrics]
296+
metrics: Sequence[str | Metric]
297297
Sequence of metric specifications.
298298
299299
Returns
300300
-------
301-
list[TorchMetrics]
302-
List of metric configurations.
301+
list[Metric]
302+
List of metrics.
303303
"""
304-
result_metrics = []
305-
306-
for m in metrics:
307-
config = get_metric_config(m) if isinstance(m, (str | TorchMetrics)) else m
308-
result_metrics.append(config)
309-
310-
return result_metrics
304+
return [get_metric(m) for m in metrics]

autoemulate/core/model_selection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
get_torch_device,
1111
move_tensors_to_device,
1212
)
13-
from autoemulate.core.metrics import R2, Metric, TorchMetrics, get_metric_configs
13+
from autoemulate.core.metrics import R2, Metric, get_metrics
1414
from autoemulate.core.types import (
1515
DeviceLike,
1616
ModelParams,
@@ -63,7 +63,7 @@ def cross_validate(
6363
y_transforms: list[Transform] | None = None,
6464
device: DeviceLike = "cpu",
6565
random_seed: int | None = None,
66-
metrics: list[TorchMetrics] | None = None,
66+
metrics: list[Metric] | None = None,
6767
):
6868
"""
6969
Cross validate model performance using the given `cv` strategy.
@@ -100,7 +100,7 @@ def cross_validate(
100100

101101
# Setup metrics
102102
if metrics is None:
103-
metrics = get_metric_configs(["r2", "rmse"])
103+
metrics = get_metrics(["r2", "rmse"])
104104

105105
cv_results = {metric.name: [] for metric in metrics}
106106
device = get_torch_device(device)
@@ -158,7 +158,7 @@ def bootstrap(
158158
n_bootstraps: int | None = 100,
159159
n_samples: int = 1000,
160160
device: str | torch.device = "cpu",
161-
metrics: list[TorchMetrics] | None = None,
161+
metrics: list[Metric] | None = None,
162162
) -> dict[str, tuple[float, float]]:
163163
"""
164164
Get bootstrap estimates of metrics.
@@ -193,7 +193,7 @@ def bootstrap(
193193

194194
# Setup metrics
195195
if metrics is None:
196-
metrics = get_metric_configs(["r2", "rmse"])
196+
metrics = get_metrics(["r2", "rmse"])
197197

198198
# If no bootstraps are specified, fall back to a single evaluation on given data
199199
if n_bootstraps is None:

autoemulate/core/tuner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.distributions import Transform
77

88
from autoemulate.core.device import TorchDeviceMixin
9-
from autoemulate.core.metrics import TorchMetrics, get_metric_config
9+
from autoemulate.core.metrics import Metric, get_metric
1010
from autoemulate.core.model_selection import cross_validate
1111
from autoemulate.core.types import (
1212
DeviceLike,
@@ -46,7 +46,7 @@ def __init__(
4646
n_iter: int = 10,
4747
device: DeviceLike | None = None,
4848
random_seed: int | None = None,
49-
tuning_metric: str | TorchMetrics = "r2",
49+
tuning_metric: str | Metric = "r2",
5050
):
5151
TorchDeviceMixin.__init__(self, device=device)
5252
self.n_iter = n_iter
@@ -60,7 +60,7 @@ def __init__(
6060
self.dataset = self._convert_to_dataset(x_tensor, y_tensor)
6161

6262
# Setup tuning metric
63-
self.tuning_metric = get_metric_config(tuning_metric)
63+
self.tuning_metric = get_metric(tuning_metric)
6464

6565
if random_seed is not None:
6666
set_random_seed(seed=random_seed)

tests/core/test_metrics.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
CRPSMetric,
1616
Metric,
1717
TorchMetrics,
18-
get_metric_config,
19-
get_metric_configs,
18+
get_metric,
19+
get_metrics,
2020
)
2121
from torch.distributions import Normal
2222

@@ -176,7 +176,7 @@ def test_mae_computation():
176176

177177
def test_get_metric_config_with_string_r2():
178178
"""Test get_metric_config with 'r2' string."""
179-
config = get_metric_config("r2")
179+
config = get_metric("r2")
180180

181181
assert config == R2
182182
assert config.name == "r2"
@@ -185,7 +185,7 @@ def test_get_metric_config_with_string_r2():
185185

186186
def test_get_metric_config_with_string_rmse():
187187
"""Test get_metric_config with 'rmse' string."""
188-
config = get_metric_config("rmse")
188+
config = get_metric("rmse")
189189

190190
assert config == RMSE
191191
assert config.name == "rmse"
@@ -194,7 +194,7 @@ def test_get_metric_config_with_string_rmse():
194194

195195
def test_get_metric_config_with_string_mse():
196196
"""Test get_metric_config with 'mse' string."""
197-
config = get_metric_config("mse")
197+
config = get_metric("mse")
198198

199199
assert config == MSE
200200
assert config.name == "mse"
@@ -203,7 +203,7 @@ def test_get_metric_config_with_string_mse():
203203

204204
def test_get_metric_config_with_string_mae():
205205
"""Test get_metric_config with 'mae' string."""
206-
config = get_metric_config("mae")
206+
config = get_metric("mae")
207207

208208
assert config == MAE
209209
assert config.name == "mae"
@@ -212,9 +212,9 @@ def test_get_metric_config_with_string_mae():
212212

213213
def test_get_metric_config_case_insensitive():
214214
"""Test get_metric_config is case insensitive."""
215-
config_upper = get_metric_config("R2")
216-
config_lower = get_metric_config("r2")
217-
config_mixed = get_metric_config("R2")
215+
config_upper = get_metric("R2")
216+
config_lower = get_metric("r2")
217+
config_mixed = get_metric("R2")
218218

219219
assert config_upper == config_lower == config_mixed == R2
220220

@@ -225,7 +225,7 @@ def test_get_metric_config_with_torchmetrics_instance():
225225
metric=torchmetrics.R2Score, name="custom_r2", maximize=True
226226
)
227227

228-
config = get_metric_config(custom_metric)
228+
config = get_metric(custom_metric)
229229

230230
assert config == custom_metric
231231
assert config.name == "custom_r2"
@@ -234,7 +234,7 @@ def test_get_metric_config_with_torchmetrics_instance():
234234
def test_get_metric_config_invalid_string():
235235
"""Test get_metric_config with invalid string raises ValueError."""
236236
with pytest.raises(ValueError, match="Unknown metric shortcut") as excinfo:
237-
get_metric_config("invalid_metric")
237+
get_metric("invalid_metric")
238238

239239
assert "Unknown metric shortcut" in str(excinfo.value)
240240
assert "invalid_metric" in str(excinfo.value)
@@ -244,15 +244,15 @@ def test_get_metric_config_invalid_string():
244244
def test_get_metric_config_unsupported_type():
245245
"""Test get_metric_config with unsupported type raises ValueError."""
246246
with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
247-
get_metric_config(123) # type: ignore[arg-type]
247+
get_metric(123) # type: ignore[arg-type]
248248

249249
assert "Unsupported metric type" in str(excinfo.value)
250250

251251

252252
def test_get_metric_config_with_none():
253253
"""Test get_metric_config with None raises ValueError."""
254254
with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
255-
get_metric_config(None) # type: ignore[arg-type]
255+
get_metric(None) # type: ignore[arg-type]
256256

257257
assert "Unsupported metric type" in str(excinfo.value)
258258

@@ -263,7 +263,7 @@ def test_get_metric_config_with_none():
263263
def test_get_metric_configs_with_strings():
264264
"""Test get_metric_configs with list of strings."""
265265
metrics = ["r2", "rmse", "mse"]
266-
configs = get_metric_configs(metrics)
266+
configs = get_metrics(metrics)
267267

268268
assert len(configs) == 3
269269
assert configs[0] == R2
@@ -278,7 +278,7 @@ def test_get_metric_configs_with_mixed_types():
278278
)
279279

280280
metrics = ["r2", custom_metric, "mse"]
281-
configs = get_metric_configs(metrics)
281+
configs = get_metrics(metrics)
282282

283283
assert len(configs) == 3
284284
assert configs[0] == R2
@@ -288,15 +288,15 @@ def test_get_metric_configs_with_mixed_types():
288288

289289
def test_get_metric_configs_with_empty_list():
290290
"""Test get_metric_configs with empty list."""
291-
configs = get_metric_configs([])
291+
configs = get_metrics([])
292292

293293
assert len(configs) == 0
294294
assert configs == []
295295

296296

297297
def test_get_metric_configs_with_single_metric():
298298
"""Test get_metric_configs with single metric."""
299-
configs = get_metric_configs(["r2"])
299+
configs = get_metrics(["r2"])
300300

301301
assert len(configs) == 1
302302
assert configs[0] == R2
@@ -305,7 +305,7 @@ def test_get_metric_configs_with_single_metric():
305305
def test_get_metric_configs_with_all_available_metrics():
306306
"""Test get_metric_configs with all available metrics."""
307307
metrics = list(AVAILABLE_METRICS.keys())
308-
configs = get_metric_configs(metrics)
308+
configs = get_metrics(metrics)
309309

310310
assert len(configs) == len(AVAILABLE_METRICS)
311311

@@ -320,7 +320,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
320320
metric=torchmetrics.MeanSquaredError, name="mse_1", maximize=False
321321
)
322322

323-
configs = get_metric_configs([metric1, metric2])
323+
configs = get_metrics([metric1, metric2])
324324

325325
assert len(configs) == 2
326326
assert configs[0] == metric1
@@ -330,7 +330,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
330330
def test_get_metric_configs_case_insensitive():
331331
"""Test get_metric_configs is case insensitive for strings."""
332332
metrics = ["R2", "RMSE", "mse", "MaE", "Crps"]
333-
configs = get_metric_configs(metrics)
333+
configs = get_metrics(metrics)
334334

335335
assert len(configs) == 5
336336
assert configs[0] == R2
@@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():
396396
def test_metric_configs_workflow():
397397
"""Test complete workflow of getting and using metric configs."""
398398
# Get configs from strings
399-
configs = get_metric_configs(["r2", "rmse"])
399+
metrics = get_metrics(["r2", "rmse"])
400400

401401
# Use configs to compute metrics
402402
y_pred = torch.tensor([1.0, 2.0, 3.0])
403403
y_true = torch.tensor([1.0, 2.0, 3.0])
404404

405-
results = {}
406-
for config in configs:
407-
metric = config.metric()
408-
metric.update(y_pred, y_true)
409-
results[config.name] = metric.compute()
410-
405+
results = {metric.name: metric(y_pred, y_true) for metric in metrics}
411406
assert "r2" in results
412407
assert "rmse" in results
413408
assert torch.isclose(results["r2"], torch.tensor(1.0)) # Perfect R2
@@ -509,7 +504,7 @@ def test_crps_aggregation_across_batch():
509504

510505
def test_get_metric_config_crps():
511506
"""Test get_metric_config with 'crps' string."""
512-
config = get_metric_config("crps")
507+
config = get_metric("crps")
513508

514509
assert config == CRPS
515510
assert isinstance(config, CRPSMetric)

0 commit comments

Comments
 (0)