@@ -76,6 +76,7 @@ def __init__(
         log_level: str = "progress_bar",
         tuning_metric: str | Metric = "r2",
         evaluation_metrics: list[str | Metric] | None = None,
+        n_samples: int = 1000,
     ):
         """
         Initialize the AutoEmulate class.
@@ -135,6 +136,9 @@ def __init__(
             Each entry can be a string shortcut or a MetricConfig object.
             IMPORTANT: The first metric in the list is used to
             determine the best model.
+        n_samples: int
+            Number of samples to draw when estimating the predictive mean for
+            emulators that do not expose a mean directly. Defaults to 1000.
         """
         Results.__init__(self)
         self.random_seed = random_seed
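For intuition, here is a minimal sketch of the kind of Monte Carlo mean estimate this parameter controls, assuming the emulator's prediction comes back as a `torch.distributions.Distribution`; the helper name and signature are illustrative, not the library's API:

```python
import torch

def estimate_mean(
    dist: torch.distributions.Distribution, n_samples: int = 1000
) -> torch.Tensor:
    # Draw n_samples from the predictive distribution and average them to
    # approximate the mean when no closed-form mean is available.
    # (Illustrative sketch, not the library's actual implementation.)
    samples = dist.sample(torch.Size([n_samples]))
    return samples.mean(dim=0)
```

Larger `n_samples` values reduce the variance of the estimate at the cost of extra prediction time.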
@@ -192,6 +196,7 @@ def __init__(
         # Set up logger and ModelSerialiser for saving models
         self.logger, self.progress_bar = get_configured_logger(log_level)
         self.model_serialiser = ModelSerialiser(self.logger)
+        self.n_samples = n_samples

         # Run compare
         self.compare()
@@ -489,6 +494,7 @@ def compare(self):
             n_bootstraps=self.n_bootstraps,
             device=self.device,
             metrics=self.evaluation_metrics,
+            n_samples=self.n_samples,
         )
         test_metrics = bootstrap(
             transformed_emulator,
@@ -497,6 +503,7 @@ def compare(self):
             n_bootstraps=self.n_bootstraps,
             device=self.device,
             metrics=self.evaluation_metrics,
+            n_samples=self.n_samples,
         )

         # Log all test metrics from test_metrics dictionary
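A hedged usage sketch of the new keyword; the import path and toy data are assumptions based on this diff, not verified against the package:

```python
import numpy as np
from autoemulate import AutoEmulate  # assumed import path

x = np.random.rand(100, 2)      # toy inputs
y = np.sin(x[:, 0]) + x[:, 1]   # toy outputs

# n_samples sets how many draws are used for sampling-based mean estimates
# during model comparison; it defaults to 1000 per the diff above.
ae = AutoEmulate(x, y, n_samples=500)
```

Because `compare()` runs inside `__init__`, the value takes effect immediately for both the validation and test `bootstrap` calls shown above.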