From 19330338a25510e88388686eef9a41648c05c895 Mon Sep 17 00:00:00 2001
From: colecitrenbaum
Date: Sun, 30 Nov 2025 18:30:13 -0800
Subject: [PATCH 1/4] Fix off-by-one indexing for nonstationary transitions;
 add a test comparing hmm_two_filter_smoother and hmm_smoother with a
 nonstationary transition kernel.

---
 dynamax/hidden_markov_model/inference.py      |  3 +-
 dynamax/hidden_markov_model/inference_test.py | 28 +++++++++++++++----
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py
index e1c2f1bd..d891f2b8 100644
--- a/dynamax/hidden_markov_model/inference.py
+++ b/dynamax/hidden_markov_model/inference.py
@@ -189,7 +189,8 @@ def _step(carry, t):
         """Backward filtering step."""
         log_normalizer, backward_pred_probs = carry
 
-        A = get_trans_mat(transition_matrix, transition_fn, t)
+        # A = get_trans_mat(transition_matrix, transition_fn, t)
+        A = get_trans_mat(transition_matrix, transition_fn, t-1)
         ll = log_likelihoods[t]
 
         # Condition on emission at time t, being careful not to overflow.
diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py
index abb76add..49f9efff 100644
--- a/dynamax/hidden_markov_model/inference_test.py
+++ b/dynamax/hidden_markov_model/inference_test.py
@@ -7,7 +7,6 @@
 from jax import vmap
 import dynamax.hidden_markov_model.inference as core
 import dynamax.hidden_markov_model.parallel_inference as parallel
-
 from jax.scipy.special import logsumexp
 
 def big_log_joint(initial_probs, transition_matrix, log_likelihoods):
@@ -56,12 +55,16 @@ def random_hmm_args_nonstationary(key, num_timesteps, num_states, scale=1.0):
 
     # we use numpy so we can assign to the matrix.
     # Then we convert to jnp.
+    keys_A = jr.split(k2, num_timesteps - 1)
     trans_mat = jnp.zeros((num_timesteps - 1, num_states, num_states))
-    for t in range(num_timesteps):
-        A = jr.uniform(k2, (num_states, num_states))
-        A /= A.sum(1, keepdims=True)
-        trans_mat = trans_mat.at[t].set(A)
-    return initial_probs, jnp.array(trans_mat), log_likelihoods
+    for t in range(num_timesteps - 1):
+        A = jr.uniform(keys_A[t], (num_states, num_states))
+        A /= A.sum(1, keepdims=True)
+        trans_mat = trans_mat.at[t].set(A)
+
+    return initial_probs, trans_mat, log_likelihoods
+
+
 
 def test_hmm_filter(key=0, num_timesteps=3, num_states=2):
     """
@@ -169,6 +172,16 @@ def test_hmm_smoother(key=0, num_timesteps=5, num_states=2):
         smoothed_probs_t = jnp.sum(joint, axis=tuple(jnp.arange(t)) + tuple(jnp.arange(t + 1, num_timesteps)))
         assert jnp.allclose(posterior.smoothed_probs[t], smoothed_probs_t, atol=1e-4)
 
+def test_two_filter_vs_smoother_nonstationary(key=0, num_timesteps=6, num_states=3):
+    key = jr.PRNGKey(key)
+    init, A_t, log_lkhds = random_hmm_args_nonstationary(key, num_timesteps, num_states)
+
+    post_two = core.hmm_two_filter_smoother(init, A_t, log_lkhds)
+    post_rts = core.hmm_smoother(init, A_t, log_lkhds)
+
+    assert jnp.allclose(post_two.smoothed_probs, post_rts.smoothed_probs, atol=1e-4)
+
+
 
 def test_hmm_fixed_lag_smoother(key=0, num_timesteps=5, num_states=2):
     """
@@ -351,3 +364,6 @@ def test_parallel_posterior_sample(
 
     # Compare the joint distributions
     assert jnp.allclose(blj_sample, blj, rtol=0, atol=eps)
+if __name__ == "__main__":
+    test_two_filter_vs_smoother_nonstationary()
+    test_hmm_non_stationary()

From a9405ffdcc6ef63a27b2d67e22030a66911b4094 Mon Sep 17 00:00:00 2001
From: Cole Citrenbaum <89095168+colecitrenbaum@users.noreply.github.com>
Date: Mon, 1 Dec 2025 09:57:17 -0800
Subject: [PATCH 2/4] Fix transition matrix index in backward step

Clean up the backward step: drop the commented-out lookup now that the
transition matrix is retrieved at index t-1.
---
 dynamax/hidden_markov_model/inference.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py
index d891f2b8..964499a8 100644
--- a/dynamax/hidden_markov_model/inference.py
+++ b/dynamax/hidden_markov_model/inference.py
@@ -188,8 +188,6 @@ def hmm_backward_filter(
     def _step(carry, t):
         """Backward filtering step."""
         log_normalizer, backward_pred_probs = carry
-
-        # A = get_trans_mat(transition_matrix, transition_fn, t)
         A = get_trans_mat(transition_matrix, transition_fn, t-1)
         ll = log_likelihoods[t]
 
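[Editor's note, not part of the patch series] The off-by-one fixed in the two
patches above follows from the shape convention: with T timesteps there are
only T-1 transition matrices, and the matrix at index t-1 governs the step
z_{t-1} -> z_t, so the backward recursion at time t must look up index t-1.
Below is a minimal sketch of that recursion under the same convention;
backward_messages is a hypothetical helper for illustration, not dynamax's
hmm_backward_filter:

    import jax.numpy as jnp
    import jax.random as jr

    def backward_messages(trans_mats, log_likelihoods):
        """Unnormalized backward messages beta[t] proportional to p(y_{t+1:T} | z_t)."""
        T, K = log_likelihoods.shape
        beta = jnp.zeros((T, K)).at[T - 1].set(jnp.ones(K))
        for t in range(T - 1, 0, -1):
            # Condition on the emission at time t (shift by the max to avoid
            # overflow), then pull the message back one step. The matrix for
            # z_{t-1} -> z_t is trans_mats[t - 1], not trans_mats[t].
            b = beta[t] * jnp.exp(log_likelihoods[t] - log_likelihoods[t].max())
            beta = beta.at[t - 1].set(trans_mats[t - 1] @ b)
        return beta

    # Quick check with random row-normalized transition matrices.
    key_A, key_y = jr.split(jr.PRNGKey(0))
    T, K = 5, 3
    A = jr.uniform(key_A, (T - 1, K, K))
    A = A / A.sum(-1, keepdims=True)  # each row sums to one
    print(backward_messages(A, jr.normal(key_y, (T, K))).shape)  # (5, 3)

Indexing with t instead of t-1 here would both skip trans_mats[0] and read one
element past the end of the (T-1)-length stack, which is exactly the bug the
patches correct.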
From 1819ee5aa8dcb16fb412864930d597da0f7955cc Mon Sep 17 00:00:00 2001
From: colecitrenbaum
Date: Mon, 1 Dec 2025 19:17:13 -0800
Subject: [PATCH 3/4] Pin jax, jaxlib, tensorflow-probability, and numpy
 versions in requirements

---
 dynamax/hidden_markov_model/inference_test.py          | 5 ++---
 dynamax/linear_gaussian_ssm/parallel_inference_test.py | 2 +-
 pyproject.toml                                         | 8 ++++----
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py
index 49f9efff..b8aaa255 100644
--- a/dynamax/hidden_markov_model/inference_test.py
+++ b/dynamax/hidden_markov_model/inference_test.py
@@ -5,6 +5,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 from jax import vmap
+
 import dynamax.hidden_markov_model.inference as core
 import dynamax.hidden_markov_model.parallel_inference as parallel
 from jax.scipy.special import logsumexp
@@ -364,6 +365,4 @@ def test_parallel_posterior_sample(
 
     # Compare the joint distributions
     assert jnp.allclose(blj_sample, blj, rtol=0, atol=eps)
-if __name__ == "__main__":
-    test_two_filter_vs_smoother_nonstationary()
-    test_hmm_non_stationary()
+
diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py
index c397929d..0f343549 100644
--- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py
+++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py
@@ -15,7 +15,7 @@
 from jax import vmap
 
 
-allclose = partial(jnp.allclose, atol=1e-4)
+allclose = partial(jnp.allclose, atol=1e-3)
 
 def make_static_lgssm_params():
     """Create a static LGSSM with fixed parameters."""
diff --git a/pyproject.toml b/pyproject.toml
index e6431c23..673943d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,15 +7,15 @@
 name = "dynamax"
 dynamic = ["version"]
 requires-python = ">= 3.10"
 dependencies = [
-    "jax>=0.3.15",
-    "jaxlib",
+    "jax==0.4.26",
+    "jaxlib==0.4.26",
+    "tensorflow-probability>=0.23,<0.25",
     "fastprogress",
     "optax",
-    "tensorflow_probability",
     "scikit-learn",
     "jaxtyping",
     "typing-extensions",
-    "numpy"
+    "numpy>=1.24,<2.0"
 ]
 authors = [
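[Editor's note, not part of the patch series] Besides the version pins, the
patch above also loosens the shared LGSSM test tolerance from atol=1e-4 to
1e-3 and removes the ad hoc __main__ test runner added in PATCH 1/4. Binding
the tolerance once with functools.partial keeps a single comparator for the
whole test module. A small illustrative check of what the looser bound
accepts; the snippet mirrors the test file's helper but is not part of it:

    from functools import partial
    import jax.numpy as jnp

    # One shared comparator for the module, as in parallel_inference_test.py.
    allclose = partial(jnp.allclose, atol=1e-3)

    x = jnp.array([1.0, 2.0])
    print(allclose(x, x + 5e-4))  # True: differences below atol pass
    print(allclose(x, x + 5e-3))  # False: differences above atol fail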
From ad77fa2e54e5f5648a36fd9db1e92ad0dd9e40dc Mon Sep 17 00:00:00 2001
From: Scott Linderman
Date: Mon, 1 Dec 2025 21:42:24 -0800
Subject: [PATCH 4/4] Unpin jax versions and switch to tfp-nightly, which
 should work with the latest jax versions more reliably.

---
 pyproject.toml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 673943d9..bd7e910f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,15 +7,15 @@
 name = "dynamax"
 dynamic = ["version"]
 requires-python = ">= 3.10"
 dependencies = [
-    "jax==0.4.26",
-    "jaxlib==0.4.26",
-    "tensorflow-probability>=0.23,<0.25",
+    "jax",
+    "jaxlib",
+    "tfp-nightly",
     "fastprogress",
     "optax",
     "scikit-learn",
     "jaxtyping",
     "typing-extensions",
-    "numpy"
+    "numpy"
 ]
 authors = [
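[Editor's note, not part of the patch series] With the pins removed, the
dependency's package name becomes tfp-nightly, but the import path is still
tensorflow_probability, and dynamax uses its JAX substrate. A quick smoke
test for the unpinned combination, assuming both packages are installed;
this snippet is illustrative and not part of the patch:

    import jax
    import jax.numpy as jnp
    import tensorflow_probability.substrates.jax as tfp

    # Confirm the substrate imports cleanly against the installed jax.
    print("jax:", jax.__version__)
    dist = tfp.distributions.Categorical(probs=jnp.ones(3) / 3)
    print(dist.sample(seed=jax.random.PRNGKey(0)))  # a random state in {0, 1, 2}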