Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions autoemulate/emulators/gaussian_process/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
GaussianProcessLike,
TensorLike,
)
from autoemulate.data.utils import set_random_seed
from autoemulate.emulators.base import GaussianProcessEmulator
from autoemulate.emulators.gaussian_process import CovarModuleFn, MeanModuleFn
from autoemulate.transforms.standardize import StandardizeTransform
Expand Down Expand Up @@ -358,7 +357,6 @@ def __init__(
epochs: int = 50,
lr: float = 2e-1,
early_stopping: EarlyStopping | None = None,
seed: int | None = None,
device: DeviceLike | None = None,
scheduler_cls: type[LRScheduler] | None = None,
scheduler_params: dict | None = None,
Expand Down Expand Up @@ -398,8 +396,6 @@ def __init__(
Learning rate for the optimizer. Defaults to 2e-1.
early_stopping: EarlyStopping | None
An optional EarlyStopping callback. Defaults to None.
seed: int | None
Random seed for reproducibility. If None, no seed is set. Defaults to None.
device: DeviceLike | None
Device to run the model on. If None, uses the default device (usually CPU or
GPU). Defaults to None.
Expand All @@ -411,9 +407,6 @@ def __init__(
# Init device
TorchDeviceMixin.__init__(self, device=device)

if seed is not None:
set_random_seed(seed)

# Convert to 2D tensors if needed and move to device
x, y = self._move_tensors_to_device(*self._convert_to_tensors(x, y))

Expand Down Expand Up @@ -570,18 +563,18 @@ def __init__(
super().__init__(
x,
y,
standardize_x,
standardize_y,
likelihood_cls,
mean_module_fn,
covar_module_fn,
fixed_mean_params,
fixed_covar_params,
posterior_predictive,
epochs,
lr,
early_stopping,
device,
standardize_x=standardize_x,
standardize_y=standardize_y,
likelihood_cls=likelihood_cls,
mean_module_fn=mean_module_fn,
covar_module_fn=covar_module_fn,
fixed_mean_params=fixed_mean_params,
fixed_covar_params=fixed_covar_params,
posterior_predictive=posterior_predictive,
epochs=epochs,
lr=lr,
early_stopping=early_stopping,
device=device,
**scheduler_params,
)

Expand Down
9 changes: 6 additions & 3 deletions tests/emulators/test_gaussian_process_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,12 @@ def test_gp_corr_deterministic_with_seed(sample_data_y1d, new_data_y1d, device):
x2, _ = new_data_y1d
seed = 42
new_seed = 43
model1 = GaussianProcessCorrelated(x, y, device=device, seed=seed)
model2 = GaussianProcessCorrelated(x, y, device=device, seed=new_seed)
model3 = GaussianProcessCorrelated(x, y, device=device, seed=seed)
set_random_seed(seed)
model1 = GaussianProcessCorrelated(x, y, device=device)
set_random_seed(new_seed)
model2 = GaussianProcessCorrelated(x, y, device=device)
set_random_seed(seed)
model3 = GaussianProcessCorrelated(x, y, device=device)
model1.fit(x, y)
pred1 = model1.predict(x2)
model2.fit(x, y)
Expand Down