Skip to content

Commit 2d5adb6

Browse files
committed
Update test_grads for Conformal emulators
1 parent 2404b9b commit 2d5adb6

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

tests/emulators/test_grads.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from autoemulate.core.types import TensorLike
66
from autoemulate.emulators import GAUSSIAN_PROCESS_EMULATORS, PYTORCH_EMULATORS
7+
from autoemulate.emulators.conformal import Conformal, ConformalMLP
78
from autoemulate.emulators.gaussian_process.exact import GaussianProcess
89
from autoemulate.emulators.transformed.base import TransformedEmulator
910
from autoemulate.transforms.pca import PCATransform
@@ -18,8 +19,8 @@
1819

1920

2021
def get_pytest_param_yof(model, x_t, y_t, o, f):
21-
return (
22-
pytest.param(
22+
if o and f and model.supports_uq:
23+
return pytest.param(
2324
model,
2425
x_t,
2526
y_t,
@@ -30,9 +31,21 @@ def get_pytest_param_yof(model, x_t, y_t, o, f):
3031
reason="Full covariance sampling not implemented",
3132
),
3233
)
33-
if (o and f and model.supports_uq)
34-
else (model, x_t, y_t, o, f)
35-
)
34+
35+
if (not o) and issubclass(model, Conformal):
36+
return pytest.param(
37+
model,
38+
x_t,
39+
y_t,
40+
o,
41+
f,
42+
marks=pytest.mark.xfail(
43+
raises=ValueError,
44+
reason="Conformal emulators require sampling for predictions",
45+
),
46+
)
47+
48+
return (model, x_t, y_t, o, f)
3649

3750

3851
def get_parametrize_cases():
@@ -52,7 +65,8 @@ def get_parametrize_cases():
5265
output_from_samples_and_full_covariance_cases_cases = [
5366
get_pytest_param_yof(model, x_t, y_t, o, f)
5467
for model, x_t, y_t, o, f in itertools.product(
55-
[GaussianProcess],
68+
# ConformalMLP also included here as not tested above since
69+
[GaussianProcess, ConformalMLP],
5670
X_TRANSFORMS,
5771
Y_TRANSFORMS,
5872
[False, True],

0 commit comments

Comments
 (0)