I came across unexpected behavior for stochastic next functions that take the `_period` argument.
With `N_PERIODS = 3` and a binary state, I expected the transition matrix to need shape (2, 2): one row for each of the two period transitions and one column per state value. However, the following does not raise an error:
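```python
import jax.numpy as jnp
import lcm

@lcm.mark.stochastic
def next_dummy_state(_period):
    pass

# Shape (1, 2): a single row, although N_PERIODS = 3 gives two period transitions
DUMMY_STATE_TRANSITION = jnp.array([[0.5, 0.5]])
```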
I checked that the function is actually included in the model: providing an empty transition matrix raises an indexing error, as expected.
Providing a matrix that is "too large", e.g. of shape (4, 2), does not raise an error either.
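My guess (not verified against the lcm source) is that this comes from JAX's out-of-bounds indexing semantics: per JAX's "Sharp Bits" documentation, retrieval operations clamp an out-of-bounds index to the array bounds instead of raising, so `jnp.arange(10)[11]` returns the last element. If the row is looked up by period index, a (1, 2) matrix would silently reuse row 0 for every period, and the extra rows of a (4, 2) matrix would simply never be read; an empty matrix has no row to clamp to, which would fit the indexing error. An explicit shape check would catch all three cases. Here is a minimal sketch; `check_period_transition_shape` is a hypothetical helper, not part of lcm, and it assumes the expected shape is (N_PERIODS - 1, n_states).

```python
def check_period_transition_shape(shape, n_periods, n_states):
    """Hypothetical helper (not part of lcm): validate the shape of a
    transition matrix for a stochastic next function that depends only on
    `_period`.

    Assumption: row t holds the next-state probabilities for the transition
    from period t to t + 1, so there are n_periods - 1 rows and one column
    per value of the state.
    """
    expected = (n_periods - 1, n_states)
    actual = tuple(shape)
    if actual != expected:
        # Fail loudly instead of relying on JAX's clamped indexing
        raise ValueError(
            f"transition matrix has shape {actual}, expected {expected} "
            f"(n_periods={n_periods}, n_states={n_states})"
        )


# Shapes from this issue, for N_PERIODS = 3 and a binary state
check_period_transition_shape((2, 2), n_periods=3, n_states=2)  # passes silently
for bad_shape in [(1, 2), (4, 2), (0, 2)]:  # too small, too large, empty
    try:
        check_period_transition_shape(bad_shape, n_periods=3, n_states=2)
    except ValueError as err:
        print(err)
```

Applied to the example above, `check_period_transition_shape(DUMMY_STATE_TRANSITION.shape, n_periods=3, n_states=2)` would reject the (1, 2) matrix. If lcm instead expects one row per period, i.e. shape (N_PERIODS, n_states), only `expected` needs to change.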