Conversation

@colecitrenbaum
Added test comparing hmm_two_filter_smoother and hmm_smoother for non-stationary kernel.

Update transition matrix retrieval to use t-1.

colecitrenbaum commented Dec 2, 2025

The test failure turned out to be a disagreement between JAX and TensorFlow, which is resolved by pinning a slightly older JAX version in the requirements.

Tests now pass except:

  1. one of the allclose checks in the linear Gaussian SSM tests requires atol=1e-3 rather than atol=1e-4, and
  2. test_parameter_pytree_jittable in parameters_test fails. According to ChatGPT, the meaning of _cache_size changed in JAX under Python 3.12. Confusingly, python -m pytest dynamax/parameters_test.py passes, but python -m pytest --cov=./ -k parameters_test -s (as in the GitHub action) fails this test.

I don't really understand what this test is doing because I'm new to JAX; any guidance? Neither failure is related (as far as I can tell) to the change I made in the backward pass.


# Imports assumed from the surrounding parameters_test.py (make_params is a
# local test helper defined elsewhere in that file):
from jax import jit
import jax.numpy as jnp
from jax.tree_util import tree_map
from dynamax.parameters import ParameterProperties

def test_parameter_pytree_jittable():
    """Test that the parameter PyTree is jittable"""
    # If there's a problem with our PyTree registration, this should catch it.
    params, props = make_params()

    @jit
    def get_trainable(params, props):
        """Return a PyTree of trainable parameters"""
        return tree_map(lambda node, prop: node if prop.trainable else None,
                        params, props,
                        is_leaf=lambda node: isinstance(node, ParameterProperties))

    # first call, jit
    get_trainable(params, props)
    assert get_trainable._cache_size() == 1

    # change param values, don't jit
    params = params._replace(initial=params.initial._replace(probs=jnp.zeros(3)))
    get_trainable(params, props)
    assert get_trainable._cache_size() == 1

    # change param dtype, jit
    params = params._replace(initial=params.initial._replace(probs=jnp.zeros(3, dtype=int)))
    get_trainable(params, props)
    assert get_trainable._cache_size() == 2

    # change props, jit
    props.transitions.transition_matrix.trainable = False
    get_trainable(params, props)
    assert get_trainable._cache_size() == 3
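For context on what the test above is checking (a hedged sketch, not from the PR): jax.jit compiles a function once per combination of input shapes, dtypes, and pytree structure, and reuses the cached compilation otherwise; the private _cache_size helper counts those cached compilations. The same retracing behavior can be observed without the private API by using a Python-side counter, since the Python body only runs when JAX traces:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # executes only during tracing, not on cached calls
    return x * 2

f(jnp.zeros(3))             # first call: traces and compiles
f(jnp.ones(3))              # same shape and dtype: cache hit, no retrace
assert trace_count == 1

f(jnp.zeros(3, dtype=int))  # new dtype: triggers a retrace
assert trace_count == 2
```

Changing only array values reuses the cache, while changing a dtype, shape, or the pytree structure (e.g. which leaves are trainable) forces a new trace, which is why the assertions in the test expect the cache size to grow at those points.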
