@@ -12,7 +12,12 @@
 
 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.core.logging_config import get_configured_logger
-from autoemulate.core.model_selection import bootstrap, evaluate, r2_metric
+from autoemulate.core.metrics import (
+    TorchMetrics,
+    get_metric_config,
+    get_metric_configs,
+)
+from autoemulate.core.model_selection import bootstrap, evaluate
 from autoemulate.core.plotting import (
     calculate_subplot_layout,
     create_and_plot_slice,
@@ -72,6 +77,8 @@ def __init__(
         device: DeviceLike | None = None,
         random_seed: int | None = None,
         log_level: str = "progress_bar",
+        tuning_metric: str | TorchMetrics = "r2",
+        evaluation_metrics: list[str | TorchMetrics] | None = None,
     ):
         """
         Initialize the AutoEmulate class.
@@ -122,13 +129,27 @@ def __init__(
             it will show a progress bar during model comparison. It will set the
             logging level to "error" to avoid cluttering the output
             with debug/info logs.
+        tuning_metric: str | TorchMetrics
+            Metric to use for hyperparameter tuning. Can be a string shortcut
+            ("r2", "rmse", "mse", "mae") or a MetricConfig object. Defaults to "r2".
+        evaluation_metrics: list[str | TorchMetrics] | None
+            Metrics to compute during evaluation.
+            If None, defaults to ["r2", "rmse"].
+            Each entry can be a string shortcut or a MetricConfig object.
+            IMPORTANT: the first metric in the list is used to determine the
+            best model.
         """
         Results.__init__(self)
         self.random_seed = random_seed
         TorchDeviceMixin.__init__(self, device=device)
         x, y = self._convert_to_tensors(x, y)
         x, y = self._move_tensors_to_device(x, y)
 
+        # Setup metrics. If evaluation_metrics is None, default to ["r2", "rmse"]
+        evaluation_metrics = evaluation_metrics or ["r2", "rmse"]
+        self.evaluation_metrics = get_metric_configs(evaluation_metrics)
+        self.tuning_metric = get_metric_config(tuning_metric)
+
         # Transforms to search over
         self.x_transforms_list = [
             self.get_transforms(transforms)
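
The new `tuning_metric` and `evaluation_metrics` arguments are the user-facing surface of this change. A minimal usage sketch, assuming `x` and `y` are the constructor's leading positional arguments, that NumPy arrays are accepted, and that `AutoEmulate` is importable from the package root (none of which is shown in this diff):

```python
import numpy as np

from autoemulate import AutoEmulate  # assumed import path

# Toy data: 200 samples, 3 inputs, 2 outputs.
rng = np.random.default_rng(0)
x = rng.uniform(size=(200, 3))
y = np.stack([x.sum(axis=1), np.sin(x[:, 0])], axis=1)

ae = AutoEmulate(
    x,
    y,
    tuning_metric="rmse",               # minimised during hyperparameter tuning
    evaluation_metrics=["r2", "rmse"],  # the first entry decides the best model
)
ae.compare()  # run the comparison (if the constructor has not already done so)
```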
@@ -323,18 +344,19 @@ def log_compare(
         x_transforms,
         y_transforms,
         best_params_for_this_model,
-        r2_score,
-        rmse_score,
+        test_metrics,
     ):
         """Log the comparison results."""
+        metrics_str = ", ".join(
+            f"{metric}: {mean:.3f}" for metric, (mean, _std) in test_metrics.items()
+        )
         msg = (
             "Comparison results:\n"
             f"Best Model: {best_model_name}, "
             f"x transforms: {x_transforms}, "
-            f"y transforms: {y_transforms}",
+            f"y transforms: {y_transforms}, "
             f"Best params: {best_params_for_this_model}, "
-            f"R2 score: {r2_score:.3f}, "
-            f"RMSE score: {rmse_score:.3f}",
+            f"Metrics: {metrics_str}"
         )
         self.logger.debug(msg)
 
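
As the generator expression above implies, `test_metrics` is expected to map each metric name to a `(mean, std)` pair. A small illustration of the string `log_compare` would build (the values here are made up):

```python
test_metrics = {"r2": (0.952, 0.011), "rmse": (0.137, 0.024)}

metrics_str = ", ".join(
    f"{metric}: {mean:.3f}" for metric, (mean, _std) in test_metrics.items()
)
print(metrics_str)  # r2: 0.952, rmse: 0.137
```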
@@ -351,7 +373,13 @@ def compare(self):
         - Log the results.
         - Save the best model and its parameters.
         """
-        tuner = Tuner(self.train_val, y=None, n_iter=self.n_iter, device=self.device)
+        tuner = Tuner(
+            self.train_val,
+            y=None,
+            n_iter=self.n_iter,
+            device=self.device,
+            tuning_metric=self.tuning_metric,
+        )
         self.logger.info(
             "Comparing %s", [model_cls.__name__ for model_cls in self.models]
         )
@@ -393,7 +421,11 @@ def compare(self):
                 mean_scores = [
                     np.mean(score).item() for score in scores
                 ]
-                best_score_idx = np.argmax(mean_scores)
+                # Select best whether we're maximizing or minimizing
+                if self.tuning_metric.maximize:
+                    best_score_idx = np.argmax(mean_scores)
+                else:
+                    best_score_idx = np.argmin(mean_scores)
                 best_params_for_this_model = params_list[best_score_idx]
                 self.logger.debug(
                     'Tuner found best params for model "%s": '
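
The branch above encodes the metric's optimisation direction. A standalone sketch of the same selection rule, assuming only that the metric config exposes a boolean `maximize` attribute as used in the diff:

```python
import numpy as np

def select_best(mean_scores: list[float], maximize: bool) -> int:
    """Index of the best score given the metric's optimisation direction."""
    return int(np.argmax(mean_scores)) if maximize else int(np.argmin(mean_scores))

scores = [0.41, 0.18, 0.27]
print(select_best(scores, maximize=True))   # 0  (e.g. r2: higher is better)
print(select_best(scores, maximize=False))  # 1  (e.g. rmse: lower is better)
```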
@@ -445,35 +477,33 @@ def compare(self):
                         # This can fail for some model params
                         transformed_emulator.fit(train_val_x, train_val_y)
 
-                        (
-                            (r2_train_val, r2_train_val_std),
-                            (rmse_train_val, rmse_train_val_std),
-                        ) = bootstrap(
+                        train_metrics = bootstrap(
                             transformed_emulator,
                             train_val_x,
                             train_val_y,
                             n_bootstraps=self.n_bootstraps,
                             device=self.device,
+                            metrics=self.evaluation_metrics,
                         )
-                        (r2_test, r2_test_std), (rmse_test, rmse_test_std) = (
-                            bootstrap(
-                                transformed_emulator,
-                                test_x,
-                                test_y,
-                                n_bootstraps=self.n_bootstraps,
-                                device=self.device,
-                            )
+                        test_metrics = bootstrap(
+                            transformed_emulator,
+                            test_x,
+                            test_y,
+                            n_bootstraps=self.n_bootstraps,
+                            device=self.device,
+                            metrics=self.evaluation_metrics,
                         )
 
+                        # Log all test metrics from test_metrics dictionary
+                        test_metrics_str = ", ".join(
+                            f"{metric}: {mean:.3f} (std: {std:.3f})"
+                            for metric, (mean, std) in test_metrics.items()
+                        )
                         self.logger.debug(
-                            'Cross-validation for model "%s"'
-                            " completed with test mean (std) R2 score: %.3f (%.3f),"
-                            " mean (std) RMSE score: %.3f (%.3f)",
+                            'Cross-validation for model "%s" '
+                            "completed with test metrics: %s",
                             model_cls.__name__,
-                            r2_test,
-                            r2_test_std,
-                            rmse_test,
-                            rmse_test_std,
+                            test_metrics_str,
                         )
                         self.logger.info(
                             "Finished running Model: %s\n", model_cls.__name__
@@ -483,14 +513,8 @@ def compare(self):
                             model_name=transformed_emulator.untransformed_model_name,
                             model=transformed_emulator,
                             params=best_params_for_this_model,
-                            r2_test=r2_test,
-                            rmse_test=rmse_test,
-                            r2_test_std=r2_test_std,
-                            rmse_test_std=rmse_test_std,
-                            r2_train=r2_train_val,
-                            rmse_train=rmse_train_val,
-                            r2_train_std=r2_train_val_std,
-                            rmse_train_std=rmse_train_val_std,
+                            test_metrics=test_metrics,
+                            train_metrics=train_metrics,
                         )
                         self.add_result(result)
                         # if successful, break out of the retry loop
@@ -511,14 +535,17 @@ def compare(self):
                     )
 
         # Get the best result and log the comparison
-        best_result = self.best_result()
+        # Use the first evaluation metric to determine the best result
+        first_metric = self.evaluation_metrics[0]
+        best_result = self.best_result(
+            metric_name=first_metric.name,
+        )
         self.log_compare(
             best_model_name=best_result.model_name,
             x_transforms=best_result.x_transforms,
             y_transforms=best_result.y_transforms,
             best_params_for_this_model=best_result.params,
-            r2_score=best_result.r2_test,
-            rmse_score=best_result.rmse_test,
+            test_metrics=best_result.test_metrics,
         )
 
     def fit_from_reinitialized(
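
`best_result` now ranks by the first evaluation metric, matching the "IMPORTANT" note in the docstring. A hedged sketch of that selection rule over a list of results, assuming each result exposes `test_metrics` as a name -> (mean, std) dict and that the caller knows the metric's direction; this is an illustration, not how `Results.best_result` is implemented:

```python
def pick_best(results, metric_name: str, maximize: bool):
    """Return the result whose mean test score for `metric_name` is best."""
    def mean_score(result):
        mean, _std = result.test_metrics[metric_name]
        return mean
    return max(results, key=mean_score) if maximize else min(results, key=mean_score)
```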
@@ -642,7 +669,7 @@ def plot( # noqa: PLR0912, PLR0915
 
         # Re-run prediction with just this model to get the predictions
         y_pred, y_variance = model.predict_mean_and_variance(test_x)
-        r2_score = evaluate(y_pred, test_y, r2_metric())
+        r2_score = evaluate(y_pred, test_y)
 
         # Handle ranges
         input_ranges = input_ranges or {}