Skip to content

Commit 4641994

Browse files
authored
Merge pull request #332 from calebweinreb/lgssm_with_diagonal_noise
LGSSM with diagonal noise
2 parents d19e92a + d326769 commit 4641994

File tree

4 files changed

+291
-94
lines changed

4 files changed

+291
-94
lines changed

dynamax/linear_gaussian_ssm/inference.py

Lines changed: 113 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import jax.numpy as jnp
22
import jax.random as jr
33
from jax import lax
4-
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
54
from functools import wraps
65
import inspect
6+
import warnings
7+
8+
from tensorflow_probability.substrates.jax.distributions import (
9+
MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
10+
MultivariateNormalFullCovariance as MVN)
711

812
from jax.tree_util import tree_map
913
from jaxtyping import Array, Float
@@ -41,10 +45,22 @@ class ParamsLGSSMDynamics(NamedTuple):
4145
:param cov: dynamics covariance $Q$
4246
4347
"""
44-
weights: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], ParameterProperties]
45-
bias: Union[Float[Array, "state_dim"], Float[Array, "ntime state_dim"], ParameterProperties]
46-
input_weights: Union[Float[Array, "state_dim input_dim"], Float[Array, "ntime state_dim input_dim"], ParameterProperties]
47-
cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties]
48+
weights: Union[ParameterProperties,
49+
Float[Array, "state_dim state_dim"],
50+
Float[Array, "ntime state_dim state_dim"]]
51+
52+
bias: Union[ParameterProperties,
53+
Float[Array, "state_dim"],
54+
Float[Array, "ntime state_dim"]]
55+
56+
input_weights: Union[ParameterProperties,
57+
Float[Array, "state_dim input_dim"],
58+
Float[Array, "ntime state_dim input_dim"]]
59+
60+
cov: Union[ParameterProperties,
61+
Float[Array, "state_dim state_dim"],
62+
Float[Array, "ntime state_dim state_dim"],
63+
Float[Array, "state_dim_triu"]]
4864

4965

5066
class ParamsLGSSMEmissions(NamedTuple):
@@ -60,11 +76,24 @@ class ParamsLGSSMEmissions(NamedTuple):
6076
:param cov: emission covariance $R$
6177
6278
"""
63-
weights: Union[Float[Array, "emission_dim state_dim"], Float[Array, "ntime emission_dim state_dim"], ParameterProperties]
64-
bias: Union[Float[Array, "emission_dim"], Float[Array, "ntime emission_dim"], ParameterProperties]
65-
input_weights: Union[Float[Array, "emission_dim input_dim"], Float[Array, "ntime emission_dim input_dim"], ParameterProperties]
66-
cov: Union[Float[Array, "emission_dim emission_dim"], Float[Array, "ntime emission_dim emission_dim"], Float[Array, "emission_dim_triu"], ParameterProperties]
67-
79+
weights: Union[ParameterProperties,
80+
Float[Array, "emission_dim state_dim"],
81+
Float[Array, "ntime emission_dim state_dim"]]
82+
83+
bias: Union[ParameterProperties,
84+
Float[Array, "emission_dim"],
85+
Float[Array, "ntime emission_dim"]]
86+
87+
input_weights: Union[ParameterProperties,
88+
Float[Array, "emission_dim input_dim"],
89+
Float[Array, "ntime emission_dim input_dim"]]
90+
91+
cov: Union[ParameterProperties,
92+
Float[Array, "emission_dim emission_dim"],
93+
Float[Array, "ntime emission_dim emission_dim"],
94+
Float[Array, "emission_dim"],
95+
Float[Array, "ntime emission_dim"],
96+
Float[Array, "emission_dim_triu"]]
6897

6998

7099
class ParamsLGSSM(NamedTuple):
@@ -115,14 +144,46 @@ class PosteriorGSSMSmoothed(NamedTuple):
115144

116145

117146
# Helper functions
118-
# _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
119-
def _get_params(x, dim, t):
147+
148+
def _get_one_param(x, dim, t):
149+
"""Helper function to get one parameter at time t."""
120150
if callable(x):
121151
return x(t)
122152
elif x.ndim == dim + 1:
123153
return x[t]
124154
else:
125155
return x
156+
157+
def _get_params(params, num_timesteps, t):
158+
"""Helper function to get parameters at time t."""
159+
assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."
160+
161+
F = _get_one_param(params.dynamics.weights, 2, t)
162+
B = _get_one_param(params.dynamics.input_weights, 2, t)
163+
b = _get_one_param(params.dynamics.bias, 1, t)
164+
Q = _get_one_param(params.dynamics.cov, 2, t)
165+
H = _get_one_param(params.emissions.weights, 2, t)
166+
D = _get_one_param(params.emissions.input_weights, 2, t)
167+
d = _get_one_param(params.emissions.bias, 1, t)
168+
169+
if len(params.emissions.cov.shape) == 1:
170+
R = _get_one_param(params.emissions.cov, 1, t)
171+
elif len(params.emissions.cov.shape) > 2:
172+
R = _get_one_param(params.emissions.cov, 2, t)
173+
elif params.emissions.cov.shape[0] != num_timesteps:
174+
R = _get_one_param(params.emissions.cov, 2, t)
175+
elif params.emissions.cov.shape[1] != num_timesteps:
176+
R = _get_one_param(params.emissions.cov, 1, t)
177+
else:
178+
R = _get_one_param(params.emissions.cov, 2, t)
179+
warnings.warn(
180+
"Emission covariance has shape (N,N) where N is the number of timesteps. "
181+
"The covariance will be interpreted as static and non-diagonal. To "
182+
"specify a dynamic and diagonal covariance, pass it as a 3D array.")
183+
184+
return F, B, b, Q, H, D, d, R
185+
186+
126187
_zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape)
127188

128189

@@ -199,7 +260,6 @@ def _condition_on(m, P, H, D, d, R, u, y):
199260
S = (R + H * P * H')
200261
K = P * H' * S^{-1}
201262
PP = P - K S K' = Sigma_cond
202-
**Note! This can be done more efficiently when R is diagonal.**
203263
204264
Args:
205265
m (D_hid,): prior mean.
@@ -215,9 +275,25 @@ def _condition_on(m, P, H, D, d, R, u, y):
215275
mu_pred (D_hid,): predicted mean.
216276
Sigma_pred (D_hid,D_hid): predicted covariance.
217277
"""
218-
# Compute the Kalman gain
219-
S = R + H @ P @ H.T
220-
K = psd_solve(S, H @ P).T
278+
if R.ndim == 2:
279+
S = R + H @ P @ H.T
280+
K = psd_solve(S, H @ P).T
281+
else:
282+
# Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
283+
# (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
284+
I = jnp.eye(P.shape[0])
285+
U = H @ jnp.linalg.cholesky(P)
286+
X = U / R[:, None]
287+
S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
288+
"""
289+
# Could alternatively use U=H and C=P
290+
R_inv = jnp.diag(1.0 / R)
291+
P_inv = psd_solve(P, jnp.eye(P.shape[0]))
292+
S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
293+
"""
294+
K = P @ H.T @ S_inv
295+
S = jnp.diag(R) + H @ P @ H.T
296+
221297
Sigma_cond = P - K @ S @ K.T
222298
mu_cond = m + K @ (y - D @ u - d - H @ m)
223299
return mu_cond, symmetrize(Sigma_cond)
@@ -285,6 +361,8 @@ def wrapper(*args, **kwargs):
285361
return wrapper
286362

287363

364+
365+
288366
def lgssm_joint_sample(
289367
params: ParamsLGSSM,
290368
key: PRNGKey,
@@ -302,7 +380,6 @@ def lgssm_joint_sample(
302380
latent states and emissions
303381
304382
"""
305-
306383
params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs)
307384

308385
def _sample_transition(key, F, B, b, Q, x_tm1, u):
@@ -311,17 +388,15 @@ def _sample_transition(key, F, B, b, Q, x_tm1, u):
311388

312389
def _sample_emission(key, H, D, d, R, x, u):
313390
mean = H @ x + D @ u + d
391+
R = jnp.diag(R) if R.ndim==1 else R
314392
return MVN(mean, R).sample(seed=key)
315393

316394
def _sample_initial(key, params, inputs):
317395
key1, key2 = jr.split(key)
318396

319397
initial_state = MVN(params.initial.mean, params.initial.cov).sample(seed=key1)
320398

321-
H0 = _get_params(params.emissions.weights, 2, 0)
322-
D0 = _get_params(params.emissions.input_weights, 2, 0)
323-
d0 = _get_params(params.emissions.bias, 1, 0)
324-
R0 = _get_params(params.emissions.cov, 2, 0)
399+
H0, D0, d0, R0 = _get_params(params, num_timesteps, 0)[4:]
325400
u0 = tree_map(lambda x: x[0], inputs)
326401

327402
initial_emission = _sample_emission(key2, H0, D0, d0, R0, initial_state, u0)
@@ -331,15 +406,8 @@ def _step(prev_state, args):
331406
key, t, inpt = args
332407
key1, key2 = jr.split(key, 2)
333408

334-
# Shorthand: get parameters and inputs for time index t
335-
F = _get_params(params.dynamics.weights, 2, t)
336-
B = _get_params(params.dynamics.input_weights, 2, t)
337-
b = _get_params(params.dynamics.bias, 1, t)
338-
Q = _get_params(params.dynamics.cov, 2, t)
339-
H = _get_params(params.emissions.weights, 2, t)
340-
D = _get_params(params.emissions.input_weights, 2, t)
341-
d = _get_params(params.emissions.bias, 1, t)
342-
R = _get_params(params.emissions.cov, 2, t)
409+
# Get parameters and inputs for time index t
410+
F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
343411

344412
# Sample from transition and emission distributions
345413
state = _sample_transition(key1, F, B, b, Q, prev_state, inpt)
@@ -386,23 +454,26 @@ def lgssm_filter(
386454
num_timesteps = len(emissions)
387455
inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs
388456

457+
def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
458+
m = H @ pred_mean + D @ u + d
459+
if R.ndim==2:
460+
S = R + H @ pred_cov @ H.T
461+
return MVN(m, S).log_prob(y)
462+
else:
463+
L = H @ jnp.linalg.cholesky(pred_cov)
464+
return MVNLowRank(m, R, L).log_prob(y)
465+
466+
389467
def _step(carry, t):
390468
ll, pred_mean, pred_cov = carry
391469

392470
# Shorthand: get parameters and inputs for time index t
393-
F = _get_params(params.dynamics.weights, 2, t)
394-
B = _get_params(params.dynamics.input_weights, 2, t)
395-
b = _get_params(params.dynamics.bias, 1, t)
396-
Q = _get_params(params.dynamics.cov, 2, t)
397-
H = _get_params(params.emissions.weights, 2, t)
398-
D = _get_params(params.emissions.input_weights, 2, t)
399-
d = _get_params(params.emissions.bias, 1, t)
400-
R = _get_params(params.emissions.cov, 2, t)
471+
F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
401472
u = inputs[t]
402473
y = emissions[t]
403474

404475
# Update the log likelihood
405-
ll += MVN(H @ pred_mean + D @ u + d, H @ pred_cov @ H.T + R).log_prob(y)
476+
ll += _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y)
406477

407478
# Condition on this emission
408479
filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, H, D, d, R, u, y)
@@ -450,11 +521,8 @@ def _step(carry, args):
450521
smoothed_mean_next, smoothed_cov_next = carry
451522
t, filtered_mean, filtered_cov = args
452523

453-
# Shorthand: get parameters and inputs for time index t
454-
F = _get_params(params.dynamics.weights, 2, t)
455-
B = _get_params(params.dynamics.input_weights, 2, t)
456-
b = _get_params(params.dynamics.bias, 1, t)
457-
Q = _get_params(params.dynamics.cov, 2, t)
524+
# Get parameters and inputs for time index t
525+
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
458526
u = inputs[t]
459527

460528
# This is like the Kalman gain but in reverse
@@ -522,10 +590,7 @@ def _step(carry, args):
522590
key, filtered_mean, filtered_cov, t = args
523591

524592
# Shorthand: get parameters and inputs for time index t
525-
F = _get_params(params.dynamics.weights, 2, t)
526-
B = _get_params(params.dynamics.input_weights, 2, t)
527-
b = _get_params(params.dynamics.bias, 1, t)
528-
Q = _get_params(params.dynamics.cov, 2, t)
593+
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
529594
u = inputs[t]
530595

531596
# Condition on next state

0 commit comments

Comments
 (0)