 
 import pandas as pd
 
-from autoemulate.core.metrics import AVAILABLE_METRICS
+from autoemulate.core.metrics import Metric, get_metric
 from autoemulate.core.types import ModelParams
 from autoemulate.emulators.transformed.base import TransformedEmulator
 
@@ -18,8 +18,8 @@ def __init__(
         model_name: str,
         model: TransformedEmulator,
         params: ModelParams,
-        test_metrics: dict[str, tuple[float, float]],
-        train_metrics: dict[str, tuple[float, float]],
+        test_metrics: dict[Metric, tuple[float, float]],
+        train_metrics: dict[Metric, tuple[float, float]],
     ):
         """Initialize a Result object.
 
@@ -141,32 +141,32 @@ def summarize(self) -> pd.DataFrame:
             "params": [result.params for result in self.results],
         }
 
-        # Collect all unique metric names from all results
+        # Collect all unique metrics from all results
         all_test_metrics = set()
         all_train_metrics = set()
         for result in self.results:
             all_test_metrics.update(result.test_metrics.keys())
             all_train_metrics.update(result.train_metrics.keys())
 
         # Add test metrics columns
-        for metric_name in sorted(all_test_metrics):
-            data[f"{metric_name}_test"] = [
-                result.test_metrics.get(metric_name, (float("nan"), float("nan")))[0]
+        for metric in sorted(all_test_metrics):
+            data[f"{metric}_test"] = [
+                result.test_metrics.get(metric, (float("nan"), float("nan")))[0]
                 for result in self.results
             ]
-            data[f"{metric_name}_test_std"] = [
-                result.test_metrics.get(metric_name, (float("nan"), float("nan")))[1]
+            data[f"{metric}_test_std"] = [
+                result.test_metrics.get(metric, (float("nan"), float("nan")))[1]
                 for result in self.results
             ]
 
         # Add train metrics columns
-        for metric_name in sorted(all_train_metrics):
-            data[f"{metric_name}_train"] = [
-                result.train_metrics.get(metric_name, (float("nan"), float("nan")))[0]
+        for metric in sorted(all_train_metrics):
+            data[f"{metric}_train"] = [
+                result.train_metrics.get(metric, (float("nan"), float("nan")))[0]
                 for result in self.results
             ]
-            data[f"{metric_name}_train_std"] = [
-                result.train_metrics.get(metric_name, (float("nan"), float("nan")))[1]
+            data[f"{metric}_train_std"] = [
+                result.train_metrics.get(metric, (float("nan"), float("nan")))[1]
                 for result in self.results
             ]
 
@@ -177,13 +177,13 @@ def summarize(self) -> pd.DataFrame:
 
     summarise = summarize
 
-    def best_result(self, metric_name: str | None = None) -> Result:
+    def best_result(self, metric: str | Metric | None = None) -> Result:
         """
         Get the model with the best result based on the given metric.
 
         Parameters
         ----------
-        metric_name : str | None
+        metric : str | Metric | None
             The name of the metric to use for comparison. If None, uses the first
             available metric found in the results. The metric should exist in the
             test_metrics of the results.
@@ -202,51 +202,44 @@ def best_result(self, metric_name: str | None = None) -> Result:
             raise ValueError(msg)
 
         # If metric_name is None, use the first available metric
-        if metric_name is None:
+        if metric is None:
             # Collect all available metrics
-            available_metrics = set()
-            for result in self.results:
-                available_metrics.update(result.test_metrics.keys())
+            available_metrics = [
+                m for result in self.results for m in result.test_metrics
+            ]
 
             if not available_metrics:
                 msg = "No metrics available in results."
                 raise ValueError(msg)
 
             # Use the first metric
-            metric_name = next(iter(available_metrics))
-            logger.info("Using metric '%s' to determine best result.", metric_name)
+            metric_selected = available_metrics[0]
+            logger.info("Using metric '%s' to determine best result.", metric_selected)
         else:
             # Check if the specified metric exists in at least one result
-            if not any(metric_name in result.test_metrics for result in self.results):
+            if not any(metric in result.test_metrics for result in self.results):
                 available_metrics = set()
                 for result in self.results:
                     available_metrics.update(result.test_metrics.keys())
                 msg = (
-                    f"Metric '{metric_name}' not found in any results. "
+                    f"Metric '{metric}' not found in any results. "
                     f"Available metrics: {sorted(available_metrics)}"
                 )
                 raise ValueError(msg)
-
-            logger.info("Using metric '%s' to determine best result.", metric_name)
-
-        # Determine if we are maximizing or minimizing the metric
-        # from the metric name
-        assert metric_name is not None  # for pyright
-        metric_config = AVAILABLE_METRICS.get(metric_name)
-        if metric_config is None:
-            msg = f"Metric '{metric_name}' not found in AVAILABLE_METRICS."
-            raise ValueError(msg)
-        metric_maximize = metric_config.maximize
+            metric_selected = get_metric(metric)
+            logger.info("Using metric '%s' to determine best result.", metric_selected)
 
         # Select best result based on whether we're maximizing or minimizing
-        if metric_maximize:
+        if metric_selected.maximize:
             return max(
                 self.results,
-                key=lambda r: r.test_metrics.get(metric_name, (float("-inf"), 0))[0],
+                key=lambda r: r.test_metrics.get(metric_selected, (float("-inf"), 0))[
+                    0
+                ],
             )
         return min(
             self.results,
-            key=lambda r: r.test_metrics.get(metric_name, (float("inf"), 0))[0],
+            key=lambda r: r.test_metrics.get(metric_selected, (float("inf"), 0))[0],
         )
 
     def get_result(self, result_id: int) -> Result:
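
A minimal usage sketch of the reworked API (not part of the diff). It assumes a results container exposing summarize() and best_result() has already been produced by a comparison or tuning run, and that "r2" is a name get_metric() can resolve; both are assumptions for illustration only.

    # Summary DataFrame: one "<metric>_test", "<metric>_test_std",
    # "<metric>_train" and "<metric>_train_std" column per metric found
    df = results.summarize()

    # With no argument, the first test metric found in the results is used
    best = results.best_result()

    # A string is resolved to a Metric via get_metric(); whether max or min
    # wins follows the metric's maximize flag
    best_r2 = results.best_result("r2")
    print(best_r2.model_name, best_r2.params)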