Skip to content

Commit 08b732c

Browse files
committed
Add n_samples to APIs
1 parent 9a1d9e2 commit 08b732c

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

autoemulate/core/compare.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
log_level: str = "progress_bar",
7777
tuning_metric: str | Metric = "r2",
7878
evaluation_metrics: list[str | Metric] | None = None,
79+
n_samples: int = 1000,
7980
):
8081
"""
8182
Initialize the AutoEmulate class.
@@ -135,6 +136,9 @@ def __init__(
135136
Each entry can be a string shortcut or a MetricConfig object.
136137
IMPORTANT: The first metric in the list is used to
137138
determine the best model.
139+
n_samples: int
140+
Number of samples to generate to predict mean when emulator does not have a
141+
mean directly available. Defaults to 1000.
138142
"""
139143
Results.__init__(self)
140144
self.random_seed = random_seed
@@ -192,6 +196,7 @@ def __init__(
192196
# Set up logger and ModelSerialiser for saving models
193197
self.logger, self.progress_bar = get_configured_logger(log_level)
194198
self.model_serialiser = ModelSerialiser(self.logger)
199+
self.n_samples = n_samples
195200

196201
# Run compare
197202
self.compare()
@@ -489,6 +494,7 @@ def compare(self):
489494
n_bootstraps=self.n_bootstraps,
490495
device=self.device,
491496
metrics=self.evaluation_metrics,
497+
n_samples=self.n_samples,
492498
)
493499
test_metrics = bootstrap(
494500
transformed_emulator,
@@ -497,6 +503,7 @@ def compare(self):
497503
n_bootstraps=self.n_bootstraps,
498504
device=self.device,
499505
metrics=self.evaluation_metrics,
506+
n_samples=self.n_samples,
500507
)
501508

502509
# Log all test metrics from test_metrics dictionary

autoemulate/core/model_selection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def cross_validate(
6464
device: DeviceLike = "cpu",
6565
random_seed: int | None = None,
6666
metrics: list[Metric] | None = None,
67+
n_samples: int = 1000,
6768
):
6869
"""
6970
Cross validate model performance using the given `cv` strategy.
@@ -88,6 +89,9 @@ def cross_validate(
8889
Optional random seed for reproducibility.
8990
metrics: list[TorchMetrics] | None
9091
List of metrics to compute. If None, uses r2 and rmse.
92+
n_samples: int
93+
Number of samples to generate to predict mean when emulator does not have a
94+
mean directly available. Defaults to 1000.
9195
9296
Returns
9397
-------
@@ -146,7 +150,7 @@ def cross_validate(
146150
# compute and save results
147151
y_pred = transformed_emulator.predict(x_val)
148152
for metric in metrics:
149-
score = evaluate(y_pred, y_val, metric)
153+
score = evaluate(y_pred, y_val, metric, n_samples=n_samples)
150154
cv_results[metric.name].append(score)
151155
return cv_results
152156

0 commit comments

Comments
 (0)