diff --git a/autoemulate/emulators/gaussian_process/exact.py b/autoemulate/emulators/gaussian_process/exact.py index 3cba175b1..2c0536549 100644 --- a/autoemulate/emulators/gaussian_process/exact.py +++ b/autoemulate/emulators/gaussian_process/exact.py @@ -18,7 +18,6 @@ GaussianProcessLike, TensorLike, ) -from autoemulate.data.utils import set_random_seed from autoemulate.emulators.base import GaussianProcessEmulator from autoemulate.emulators.gaussian_process import CovarModuleFn, MeanModuleFn from autoemulate.transforms.standardize import StandardizeTransform @@ -358,7 +357,6 @@ def __init__( epochs: int = 50, lr: float = 2e-1, early_stopping: EarlyStopping | None = None, - seed: int | None = None, device: DeviceLike | None = None, scheduler_cls: type[LRScheduler] | None = None, scheduler_params: dict | None = None, @@ -398,8 +396,6 @@ def __init__( Learning rate for the optimizer. Defaults to 2e-1. early_stopping: EarlyStopping | None An optional EarlyStopping callback. Defaults to None. - seed: int | None - Random seed for reproducibility. If None, no seed is set. Defaults to None. device: DeviceLike | None Device to run the model on. If None, uses the default device (usually CPU or GPU). Defaults to None. @@ -411,9 +407,6 @@ def __init__( # Init device TorchDeviceMixin.__init__(self, device=device) - if seed is not None: - set_random_seed(seed) - # Convert to 2D tensors if needed and move to device x, y = self._move_tensors_to_device(*self._convert_to_tensors(x, y)) @@ -570,18 +563,18 @@ def __init__( super().__init__( x, y, - standardize_x, - standardize_y, - likelihood_cls, - mean_module_fn, - covar_module_fn, - fixed_mean_params, - fixed_covar_params, - posterior_predictive, - epochs, - lr, - early_stopping, - device, + standardize_x=standardize_x, + standardize_y=standardize_y, + likelihood_cls=likelihood_cls, + mean_module_fn=mean_module_fn, + covar_module_fn=covar_module_fn, + fixed_mean_params=fixed_mean_params, + fixed_covar_params=fixed_covar_params, + posterior_predictive=posterior_predictive, + epochs=epochs, + lr=lr, + early_stopping=early_stopping, + device=device, **scheduler_params, ) diff --git a/tests/emulators/test_gaussian_process_exact.py b/tests/emulators/test_gaussian_process_exact.py index 55b1baad8..abcf88103 100644 --- a/tests/emulators/test_gaussian_process_exact.py +++ b/tests/emulators/test_gaussian_process_exact.py @@ -167,9 +167,12 @@ def test_gp_corr_deterministic_with_seed(sample_data_y1d, new_data_y1d, device): x2, _ = new_data_y1d seed = 42 new_seed = 43 - model1 = GaussianProcessCorrelated(x, y, device=device, seed=seed) - model2 = GaussianProcessCorrelated(x, y, device=device, seed=new_seed) - model3 = GaussianProcessCorrelated(x, y, device=device, seed=seed) + set_random_seed(seed) + model1 = GaussianProcessCorrelated(x, y, device=device) + set_random_seed(new_seed) + model2 = GaussianProcessCorrelated(x, y, device=device) + set_random_seed(seed) + model3 = GaussianProcessCorrelated(x, y, device=device) model1.fit(x, y) pred1 = model1.predict(x2) model2.fit(x, y)