Skip to content

Commit 7853578

Browse files
Statespace test cleanup
1 parent 8b8dda7 commit 7853578

File tree

3 files changed

+10
-18
lines changed

3 files changed

+10
-18
lines changed

tests/statespace/core/test_statespace_JAX.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tests.statespace.test_utilities import load_nile_test_data
2525

2626
pytest.importorskip("jax")
27-
pytest.importorskip("numpyro")
27+
pytest.importorskip("nutpie")
2828

2929

3030
floatX = pytensor.config.floatX
@@ -78,7 +78,8 @@ def idata(pymc_mod, rng, mock_pymc_sample):
7878
tune=1,
7979
chains=1,
8080
random_seed=rng,
81-
nuts_sampler="numpyro",
81+
nuts_sampler="nutpie",
82+
nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
8283
progressbar=False,
8384
)
8485
with freeze_dims_and_data(pymc_mod):
@@ -101,7 +102,8 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
101102
tune=1,
102103
chains=1,
103104
random_seed=rng,
104-
nuts_sampler="numpyro",
105+
nuts_sampler="nutpie",
106+
nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
105107
progressbar=False,
106108
)
107109
with freeze_dims_and_data(pymc_mod):
@@ -123,8 +125,7 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
123125
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
124126
def test_sampling_methods(group, kind, ss_mod, idata, rng):
125127
f = getattr(ss_mod, f"sample_{kind}_{group}")
126-
with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
127-
test_idata = f(idata, random_seed=rng)
128+
test_idata = f(idata, random_seed=rng)
128129

129130
if kind == "conditional":
130131
for output in ["filtered", "predicted", "smoothed"]:
@@ -142,10 +143,9 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
142143
def test_forecast(filter_output, ss_mod, idata, rng):
143144
time_idx = idata.posterior.coords["time"].values
144145

145-
with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
146-
forecast_idata = ss_mod.forecast(
147-
idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
148-
)
146+
forecast_idata = ss_mod.forecast(
147+
idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
148+
)
149149

150150
assert forecast_idata.coords["time"].values.shape == (10,)
151151
assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")

tests/statespace/filters/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def ss_mod_no_me():
9393
return ss_mod
9494

9595

96-
@pytest.mark.parametrize("kfilter", filter_names, ids=filter_names)
96+
@pytest.mark.parametrize("kfilter", filter_names)
9797
def test_loglike_vectors_agree(kfilter, pymc_model):
9898
# TODO: This test might be flakey, I've gotten random failures
9999
ss_mod = structural.LevelTrendComponent(order=2).build(

tests/statespace/models/test_SARIMAX.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -432,15 +432,7 @@ def test_SARIMA_with_exogenous(rng, mock_sample):
432432
obs_intercept = ss_mod.ssm["obs_intercept"]
433433
assert obs_intercept.type.shape == (None, ss_mod.k_endog)
434434

435-
intercept_fn = pytensor.function(
436-
inputs=list(explicit_graph_inputs(obs_intercept)), outputs=obs_intercept
437-
)
438435
data_val = rng.normal(size=(100, 2)).astype(floatX)
439-
beta_val = rng.normal(size=(2,)).astype(floatX)
440-
441-
intercept_val = intercept_fn(data_val, beta_val)
442-
np.testing.assert_allclose(intercept_val, intercept_fn(data_val, beta_val))
443-
444436
data_df = pd.DataFrame(
445437
rng.normal(size=(100, 1)),
446438
index=pd.date_range(start="2020-01-01", periods=100, freq="D"),

0 commit comments

Comments
 (0)