2424from tests .statespace .test_utilities import load_nile_test_data
2525
2626pytest .importorskip ("jax" )
27- pytest .importorskip ("numpyro " )
27+ pytest .importorskip ("nutpie " )
2828
2929
3030floatX = pytensor .config .floatX
@@ -78,7 +78,8 @@ def idata(pymc_mod, rng, mock_pymc_sample):
7878 tune = 1 ,
7979 chains = 1 ,
8080 random_seed = rng ,
81- nuts_sampler = "numpyro" ,
81+ nuts_sampler = "nutpie" ,
82+ nuts_sampler_kwargs = {"backend" : "jax" , "gradient_backend" : "jax" },
8283 progressbar = False ,
8384 )
8485 with freeze_dims_and_data (pymc_mod ):
@@ -101,7 +102,8 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
101102 tune = 1 ,
102103 chains = 1 ,
103104 random_seed = rng ,
104- nuts_sampler = "numpyro" ,
105+ nuts_sampler = "nutpie" ,
106+ nuts_sampler_kwargs = {"backend" : "jax" , "gradient_backend" : "jax" },
105107 progressbar = False ,
106108 )
107109 with freeze_dims_and_data (pymc_mod ):
@@ -123,8 +125,7 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
123125@pytest .mark .parametrize ("kind" , ["conditional" , "unconditional" ])
124126def test_sampling_methods (group , kind , ss_mod , idata , rng ):
125127 f = getattr (ss_mod , f"sample_{ kind } _{ group } " )
126- with pytest .warns (UserWarning , match = "The RandomType SharedVariables" ):
127- test_idata = f (idata , random_seed = rng )
128+ test_idata = f (idata , random_seed = rng )
128129
129130 if kind == "conditional" :
130131 for output in ["filtered" , "predicted" , "smoothed" ]:
@@ -142,10 +143,9 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
142143def test_forecast (filter_output , ss_mod , idata , rng ):
143144 time_idx = idata .posterior .coords ["time" ].values
144145
145- with pytest .warns (UserWarning , match = "The RandomType SharedVariables" ):
146- forecast_idata = ss_mod .forecast (
147- idata , start = time_idx [- 1 ], periods = 10 , filter_output = filter_output , random_seed = rng
148- )
146+ forecast_idata = ss_mod .forecast (
147+ idata , start = time_idx [- 1 ], periods = 10 , filter_output = filter_output , random_seed = rng
148+ )
149149
150150 assert forecast_idata .coords ["time" ].values .shape == (10 ,)
151151 assert forecast_idata .forecast_latent .dims == ("chain" , "draw" , "time" , "state" )
0 commit comments