dynamax/hidden_markov_model/inference.py (3 changes: 1 addition & 2 deletions)

@@ -188,8 +188,7 @@ def hmm_backward_filter(
     def _step(carry, t):
         """Backward filtering step."""
         log_normalizer, backward_pred_probs = carry
-
-        A = get_trans_mat(transition_matrix, transition_fn, t)
+        A = get_trans_mat(transition_matrix, transition_fn, t-1)
         ll = log_likelihoods[t]
 
         # Condition on emission at time t, being careful not to overflow.
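Why t-1: the nonstationary transition matrices are stored as a (T-1, K, K) stack (see the test helper below, which builds jnp.zeros((num_timesteps - 1, num_states, num_states))), where entry t parameterizes the transition from step t to step t+1. The backward recursion that conditions on the emission at time t must propagate the message through the matrix into time t, which lives at index t-1. A minimal, hypothetical sketch of that step, not dynamax's actual _step and without the log-space rescaling the real code performs:

import jax.numpy as jnp

def backward_step_sketch(backward_pred_probs, trans_mats, log_likelihoods, t):
    # trans_mats has shape (T-1, K, K); entry t-1 maps states at time t-1
    # to states at time t, so it is the matrix the message into time t
    # must pass through.
    A = trans_mats[t - 1]
    # Condition on the emission at time t (unnormalized here, unlike the
    # real code, which rescales to avoid overflow).
    backward_probs = jnp.exp(log_likelihoods[t]) * backward_pred_probs
    # Predict backward: the result is the message conditioning time t-1.
    return A @ backward_probs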
dynamax/hidden_markov_model/inference_test.py (27 changes: 21 additions & 6 deletions)

@@ -5,9 +5,9 @@
 import jax.numpy as jnp
 import jax.random as jr
 from jax import vmap
-
 import dynamax.hidden_markov_model.inference as core
 import dynamax.hidden_markov_model.parallel_inference as parallel
 
+from jax.scipy.special import logsumexp
 
 def big_log_joint(initial_probs, transition_matrix, log_likelihoods):
Expand Down Expand Up @@ -56,12 +56,16 @@ def random_hmm_args_nonstationary(key, num_timesteps, num_states, scale=1.0):

# we use numpy so we can assign to the matrix.
# Then we convert to jnp.
keys_A = jr.split(k2, num_timesteps - 1)
trans_mat = jnp.zeros((num_timesteps - 1, num_states, num_states))
for t in range(num_timesteps):
A = jr.uniform(k2, (num_states, num_states))
A /= A.sum(1, keepdims=True)
trans_mat = trans_mat.at[t].set(A)
return initial_probs, jnp.array(trans_mat), log_likelihoods
for t in range(num_timesteps - 1):
A = jr.uniform(keys_A[t], (num_states, num_states))
A /= A.sum(1, keepdims=True)
trans_mat = trans_mat.at[t].set(A)

return initial_probs, trans_mat, log_likelihoods



def test_hmm_filter(key=0, num_timesteps=3, num_states=2):
"""
Expand Down Expand Up @@ -169,6 +173,16 @@ def test_hmm_smoother(key=0, num_timesteps=5, num_states=2):
smoothed_probs_t = jnp.sum(joint, axis=tuple(jnp.arange(t)) + tuple(jnp.arange(t + 1, num_timesteps)))
assert jnp.allclose(posterior.smoothed_probs[t], smoothed_probs_t, atol=1e-4)

def test_two_filter_vs_smoother_nonstationary(key=0, num_timesteps=6, num_states=3):
key = jr.PRNGKey(key)
init, A_t, log_lkhds = random_hmm_args_nonstationary(key, num_timesteps, num_states)

post_two = core.hmm_two_filter_smoother(init, A_t, log_lkhds)
post_rts = core.hmm_smoother(init, A_t, log_lkhds)

assert jnp.allclose(post_two.smoothed_probs, post_rts.smoothed_probs, atol=1e-4)



def test_hmm_fixed_lag_smoother(key=0, num_timesteps=5, num_states=2):
"""
Expand Down Expand Up @@ -351,3 +365,4 @@ def test_parallel_posterior_sample(

# Compare the joint distributions
assert jnp.allclose(blj_sample, blj, rtol=0, atol=eps)

dynamax/linear_gaussian_ssm/parallel_inference_test.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@
 from jax import vmap
 
 
-allclose = partial(jnp.allclose, atol=1e-4)
+allclose = partial(jnp.allclose, atol=1e-3)
 
 def make_static_lgssm_params():
     """Create a static LGSSM with fixed parameters."""
pyproject.toml (4 changes: 2 additions & 2 deletions)

@@ -7,11 +7,11 @@ name = "dynamax"
 dynamic = ["version"]
 requires-python = ">= 3.10"
 dependencies = [
-    "jax>=0.3.15",
+    "jax",
     "jaxlib",
+    "tfp-nightly",
     "fastprogress",
     "optax",
-    "tensorflow_probability",
     "scikit-learn",
     "jaxtyping",
     "typing-extensions",