17
17
import numpy as np
18
18
import numpy .typing as npt
19
19
from numba import njit
20
+ from pymc .initial_point import PointType
20
21
from pymc .model import Model , modelcontext
21
22
from pymc .pytensorf import inputvars , join_nonshared_inputs , make_shared_replacements
22
23
from pymc .step_methods .arraystep import ArrayStepShared
@@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915
125
126
num_particles : int = 10 ,
126
127
batch : tuple [float , float ] = (0.1 , 0.1 ),
127
128
model : Optional [Model ] = None ,
129
+ initial_point : PointType | None = None ,
130
+ compile_kwargs : dict | None = None , # pylint: disable=unused-argument
128
131
):
129
132
model = modelcontext (model )
130
- initial_values = model .initial_point ()
133
+ if initial_point is None :
134
+ initial_point = model .initial_point ()
131
135
if vars is None :
132
136
vars = model .value_vars
133
137
else :
@@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915
150
154
self .m = self .bart .m
151
155
self .response = self .bart .response
152
156
153
- shape = initial_values [value_bart .name ].shape
157
+ shape = initial_point [value_bart .name ].shape
154
158
155
159
self .shape = 1 if len (shape ) == 1 else shape [0 ]
156
160
@@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915
217
221
218
222
self .num_particles = num_particles
219
223
self .indices = list (range (1 , num_particles ))
220
- shared = make_shared_replacements (initial_values , vars , model )
221
- self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
224
+ shared = make_shared_replacements (initial_point , vars , model )
225
+ self .likelihood_logp = logp (initial_point , [model .datalogp ], vars , shared )
222
226
self .all_particles = [
223
227
[ParticleTree (self .a_tree ) for _ in range (self .m )] for _ in range (self .trees_shape )
224
228
]
0 commit comments