1515 CRPSMetric ,
1616 Metric ,
1717 TorchMetrics ,
18- get_metric_config ,
19- get_metric_configs ,
18+ get_metric ,
19+ get_metrics ,
2020)
2121from torch .distributions import Normal
2222
@@ -176,7 +176,7 @@ def test_mae_computation():
176176
177177def test_get_metric_config_with_string_r2 ():
178178 """Test get_metric_config with 'r2' string."""
179- config = get_metric_config ("r2" )
179+ config = get_metric ("r2" )
180180
181181 assert config == R2
182182 assert config .name == "r2"
@@ -185,7 +185,7 @@ def test_get_metric_config_with_string_r2():
185185
186186def test_get_metric_config_with_string_rmse ():
187187 """Test get_metric_config with 'rmse' string."""
188- config = get_metric_config ("rmse" )
188+ config = get_metric ("rmse" )
189189
190190 assert config == RMSE
191191 assert config .name == "rmse"
@@ -194,7 +194,7 @@ def test_get_metric_config_with_string_rmse():
194194
195195def test_get_metric_config_with_string_mse ():
196196 """Test get_metric_config with 'mse' string."""
197- config = get_metric_config ("mse" )
197+ config = get_metric ("mse" )
198198
199199 assert config == MSE
200200 assert config .name == "mse"
@@ -203,7 +203,7 @@ def test_get_metric_config_with_string_mse():
203203
204204def test_get_metric_config_with_string_mae ():
205205 """Test get_metric_config with 'mae' string."""
206- config = get_metric_config ("mae" )
206+ config = get_metric ("mae" )
207207
208208 assert config == MAE
209209 assert config .name == "mae"
@@ -212,9 +212,9 @@ def test_get_metric_config_with_string_mae():
212212
213213def test_get_metric_config_case_insensitive ():
214214 """Test get_metric_config is case insensitive."""
215- config_upper = get_metric_config ("R2" )
216- config_lower = get_metric_config ("r2" )
217- config_mixed = get_metric_config ("R2" )
215+ config_upper = get_metric ("R2" )
216+ config_lower = get_metric ("r2" )
217+ config_mixed = get_metric ("R2" )
218218
219219 assert config_upper == config_lower == config_mixed == R2
220220
@@ -225,7 +225,7 @@ def test_get_metric_config_with_torchmetrics_instance():
225225 metric = torchmetrics .R2Score , name = "custom_r2" , maximize = True
226226 )
227227
228- config = get_metric_config (custom_metric )
228+ config = get_metric (custom_metric )
229229
230230 assert config == custom_metric
231231 assert config .name == "custom_r2"
@@ -234,7 +234,7 @@ def test_get_metric_config_with_torchmetrics_instance():
234234def test_get_metric_config_invalid_string ():
235235 """Test get_metric_config with invalid string raises ValueError."""
236236 with pytest .raises (ValueError , match = "Unknown metric shortcut" ) as excinfo :
237- get_metric_config ("invalid_metric" )
237+ get_metric ("invalid_metric" )
238238
239239 assert "Unknown metric shortcut" in str (excinfo .value )
240240 assert "invalid_metric" in str (excinfo .value )
@@ -244,15 +244,15 @@ def test_get_metric_config_invalid_string():
244244def test_get_metric_config_unsupported_type ():
245245 """Test get_metric_config with unsupported type raises ValueError."""
246246 with pytest .raises (ValueError , match = "Unsupported metric type" ) as excinfo :
247- get_metric_config (123 ) # type: ignore[arg-type]
247+ get_metric (123 ) # type: ignore[arg-type]
248248
249249 assert "Unsupported metric type" in str (excinfo .value )
250250
251251
252252def test_get_metric_config_with_none ():
253253 """Test get_metric_config with None raises ValueError."""
254254 with pytest .raises (ValueError , match = "Unsupported metric type" ) as excinfo :
255- get_metric_config (None ) # type: ignore[arg-type]
255+ get_metric (None ) # type: ignore[arg-type]
256256
257257 assert "Unsupported metric type" in str (excinfo .value )
258258
@@ -263,7 +263,7 @@ def test_get_metric_config_with_none():
263263def test_get_metric_configs_with_strings ():
264264 """Test get_metric_configs with list of strings."""
265265 metrics = ["r2" , "rmse" , "mse" ]
266- configs = get_metric_configs (metrics )
266+ configs = get_metrics (metrics )
267267
268268 assert len (configs ) == 3
269269 assert configs [0 ] == R2
@@ -278,7 +278,7 @@ def test_get_metric_configs_with_mixed_types():
278278 )
279279
280280 metrics = ["r2" , custom_metric , "mse" ]
281- configs = get_metric_configs (metrics )
281+ configs = get_metrics (metrics )
282282
283283 assert len (configs ) == 3
284284 assert configs [0 ] == R2
@@ -288,15 +288,15 @@ def test_get_metric_configs_with_mixed_types():
288288
289289def test_get_metric_configs_with_empty_list ():
290290 """Test get_metric_configs with empty list."""
291- configs = get_metric_configs ([])
291+ configs = get_metrics ([])
292292
293293 assert len (configs ) == 0
294294 assert configs == []
295295
296296
297297def test_get_metric_configs_with_single_metric ():
298298 """Test get_metric_configs with single metric."""
299- configs = get_metric_configs (["r2" ])
299+ configs = get_metrics (["r2" ])
300300
301301 assert len (configs ) == 1
302302 assert configs [0 ] == R2
@@ -305,7 +305,7 @@ def test_get_metric_configs_with_single_metric():
305305def test_get_metric_configs_with_all_available_metrics ():
306306 """Test get_metric_configs with all available metrics."""
307307 metrics = list (AVAILABLE_METRICS .keys ())
308- configs = get_metric_configs (metrics )
308+ configs = get_metrics (metrics )
309309
310310 assert len (configs ) == len (AVAILABLE_METRICS )
311311
@@ -320,7 +320,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
320320 metric = torchmetrics .MeanSquaredError , name = "mse_1" , maximize = False
321321 )
322322
323- configs = get_metric_configs ([metric1 , metric2 ])
323+ configs = get_metrics ([metric1 , metric2 ])
324324
325325 assert len (configs ) == 2
326326 assert configs [0 ] == metric1
@@ -330,7 +330,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
330330def test_get_metric_configs_case_insensitive ():
331331 """Test get_metric_configs is case insensitive for strings."""
332332 metrics = ["R2" , "RMSE" , "mse" , "MaE" , "Crps" ]
333- configs = get_metric_configs (metrics )
333+ configs = get_metrics (metrics )
334334
335335 assert len (configs ) == 5
336336 assert configs [0 ] == R2
@@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():
396396def test_metric_configs_workflow ():
397397 """Test complete workflow of getting and using metric configs."""
398398 # Get configs from strings
399- configs = get_metric_configs (["r2" , "rmse" ])
399+ metrics = get_metrics (["r2" , "rmse" ])
400400
401401 # Use configs to compute metrics
402402 y_pred = torch .tensor ([1.0 , 2.0 , 3.0 ])
403403 y_true = torch .tensor ([1.0 , 2.0 , 3.0 ])
404404
405- results = {}
406- for config in configs :
407- metric = config .metric ()
408- metric .update (y_pred , y_true )
409- results [config .name ] = metric .compute ()
410-
405+ results = {metric .name : metric (y_pred , y_true ) for metric in metrics }
411406 assert "r2" in results
412407 assert "rmse" in results
413408 assert torch .isclose (results ["r2" ], torch .tensor (1.0 )) # Perfect R2
@@ -509,7 +504,7 @@ def test_crps_aggregation_across_batch():
509504
510505def test_get_metric_config_crps ():
511506 """Test get_metric_config with 'crps' string."""
512- config = get_metric_config ("crps" )
507+ config = get_metric ("crps" )
513508
514509 assert config == CRPS
515510 assert isinstance (config , CRPSMetric )
0 commit comments