11import jax .numpy as jnp
22import jax .random as jr
33from jax import lax
4- from tensorflow_probability .substrates .jax .distributions import MultivariateNormalFullCovariance as MVN
54from functools import wraps
65import inspect
6+ import warnings
7+
8+ from tensorflow_probability .substrates .jax .distributions import (
9+ MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank ,
10+ MultivariateNormalFullCovariance as MVN )
711
812from jax .tree_util import tree_map
913from jaxtyping import Array , Float
@@ -41,10 +45,22 @@ class ParamsLGSSMDynamics(NamedTuple):
4145 :param cov: dynamics covariance $Q$
4246
4347 """
44- weights : Union [Float [Array , "state_dim state_dim" ], Float [Array , "ntime state_dim state_dim" ], ParameterProperties ]
45- bias : Union [Float [Array , "state_dim" ], Float [Array , "ntime state_dim" ], ParameterProperties ]
46- input_weights : Union [Float [Array , "state_dim input_dim" ], Float [Array , "ntime state_dim input_dim" ], ParameterProperties ]
47- cov : Union [Float [Array , "state_dim state_dim" ], Float [Array , "ntime state_dim state_dim" ], Float [Array , "state_dim_triu" ], ParameterProperties ]
48+ weights : Union [ParameterProperties ,
49+ Float [Array , "state_dim state_dim" ],
50+ Float [Array , "ntime state_dim state_dim" ]]
51+
52+ bias : Union [ParameterProperties ,
53+ Float [Array , "state_dim" ],
54+ Float [Array , "ntime state_dim" ]]
55+
56+ input_weights : Union [ParameterProperties ,
57+ Float [Array , "state_dim input_dim" ],
58+ Float [Array , "ntime state_dim input_dim" ]]
59+
60+ cov : Union [ParameterProperties ,
61+ Float [Array , "state_dim state_dim" ],
62+ Float [Array , "ntime state_dim state_dim" ],
63+ Float [Array , "state_dim_triu" ]]
4864
4965
5066class ParamsLGSSMEmissions (NamedTuple ):
@@ -60,11 +76,24 @@ class ParamsLGSSMEmissions(NamedTuple):
6076 :param cov: emission covariance $R$
6177
6278 """
63- weights : Union [Float [Array , "emission_dim state_dim" ], Float [Array , "ntime emission_dim state_dim" ], ParameterProperties ]
64- bias : Union [Float [Array , "emission_dim" ], Float [Array , "ntime emission_dim" ], ParameterProperties ]
65- input_weights : Union [Float [Array , "emission_dim input_dim" ], Float [Array , "ntime emission_dim input_dim" ], ParameterProperties ]
66- cov : Union [Float [Array , "emission_dim emission_dim" ], Float [Array , "ntime emission_dim emission_dim" ], Float [Array , "emission_dim_triu" ], ParameterProperties ]
67-
79+ weights : Union [ParameterProperties ,
80+ Float [Array , "emission_dim state_dim" ],
81+ Float [Array , "ntime emission_dim state_dim" ]]
82+
83+ bias : Union [ParameterProperties ,
84+ Float [Array , "emission_dim" ],
85+ Float [Array , "ntime emission_dim" ]]
86+
87+ input_weights : Union [ParameterProperties ,
88+ Float [Array , "emission_dim input_dim" ],
89+ Float [Array , "ntime emission_dim input_dim" ]]
90+
91+ cov : Union [ParameterProperties ,
92+ Float [Array , "emission_dim emission_dim" ],
93+ Float [Array , "ntime emission_dim emission_dim" ],
94+ Float [Array , "emission_dim" ],
95+ Float [Array , "ntime emission_dim" ],
96+ Float [Array , "emission_dim_triu" ]]
6897
6998
7099class ParamsLGSSM (NamedTuple ):
@@ -115,14 +144,46 @@ class PosteriorGSSMSmoothed(NamedTuple):
115144
116145
117146# Helper functions
118- # _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
119- def _get_params (x , dim , t ):
147+
def _get_one_param(x, dim, t):
    """Return the value of a single parameter at time index ``t``.

    Three forms are recognised, checked in this order:

    * a callable is evaluated as ``x(t)``;
    * an object whose ``ndim`` equals ``dim + 1`` is indexed along its
      first axis (``x[t]``);
    * anything else is returned unchanged.

    Args:
        x: a callable of ``t``, or an array-like object exposing ``ndim``.
        dim: number of axes the parameter has at a single time step.
        t: time index.

    Returns:
        The parameter value to use at time ``t``.
    """
    if callable(x):
        return x(t)

    # One extra axis relative to `dim` means the first axis is indexed by t.
    has_leading_time_axis = x.ndim == dim + 1
    return x[t] if has_leading_time_axis else x
156+
def _get_params(params, num_timesteps, t):
    """Gather all dynamics and emission parameters for time index ``t``.

    Every parameter except the emission covariance is resolved with
    ``_get_one_param`` (2 axes for weight/covariance matrices, 1 axis for
    biases). The emission covariance is resolved from its shape:

    * 1 axis: returned unchanged;
    * more than 2 axes: indexed at ``t``;
    * 2 axes, first axis length differs from ``num_timesteps``: returned
      unchanged;
    * 2 axes, first axis length equals ``num_timesteps`` and the second
      does not: row ``t`` is returned;
    * 2 axes, both of length ``num_timesteps``: ambiguous; returned
      unchanged and a warning is issued.

    Args:
        params: object exposing ``dynamics`` and ``emissions``, each with
            ``weights``, ``input_weights``, ``bias`` and ``cov`` fields.
        num_timesteps: number of time steps; used only to disambiguate a
            2-axis emission covariance.
        t: time index.

    Returns:
        The tuple ``(F, B, b, Q, H, D, d, R)``: dynamics weights, dynamics
        input weights, dynamics bias, dynamics covariance, emission weights,
        emission input weights, emission bias and emission covariance.

    Raises:
        AssertionError: if the emission covariance is a callable.

    Warns:
        UserWarning: if the emission covariance has shape
            ``(num_timesteps, num_timesteps)``.
    """
    dynamics = params.dynamics
    emissions = params.emissions
    emission_cov = emissions.cov

    # Raised explicitly rather than via `assert` so the check still runs
    # under `python -O`; exception type and message are kept as before.
    if callable(emission_cov):
        raise AssertionError("Emission covariance cannot be a callable.")

    F = _get_one_param(dynamics.weights, 2, t)
    B = _get_one_param(dynamics.input_weights, 2, t)
    b = _get_one_param(dynamics.bias, 1, t)
    Q = _get_one_param(dynamics.cov, 2, t)
    H = _get_one_param(emissions.weights, 2, t)
    D = _get_one_param(emissions.input_weights, 2, t)
    d = _get_one_param(emissions.bias, 1, t)

    cov_shape = emission_cov.shape
    num_axes = len(cov_shape)
    if num_axes == 1:
        R = _get_one_param(emission_cov, 1, t)
    elif num_axes > 2:
        R = _get_one_param(emission_cov, 2, t)
    else:
        leading_is_time = cov_shape[0] == num_timesteps
        if leading_is_time and cov_shape[1] != num_timesteps:
            # First axis matches the number of time steps, second does not:
            # treat each row as the value for one time step.
            R = _get_one_param(emission_cov, 1, t)
        else:
            R = _get_one_param(emission_cov, 2, t)
            if leading_is_time:
                # Both axes equal num_timesteps, so the shape alone cannot
                # tell a per-step array from a single square matrix.
                warnings.warn(
                    "Emission covariance has shape (N,N) where N is the number of timesteps. "
                    "The covariance will be interpreted as static and non-diagonal. To "
                    "specify a dynamic and diagonal covariance, pass it as a 3D array.")

    return F, B, b, Q, H, D, d, R
185+
186+
def _zeros_if_none(x, shape):
    """Return ``x`` unchanged, or ``jnp.zeros(shape)`` when ``x`` is ``None``.

    Args:
        x: any value, or ``None``.
        shape: shape of the zero array created when ``x`` is ``None``.

    Returns:
        ``x`` itself when it is not ``None``; otherwise a zero array of
        the requested shape.
    """
    if x is None:
        return jnp.zeros(shape)
    return x
127188
128189
@@ -199,7 +260,6 @@ def _condition_on(m, P, H, D, d, R, u, y):
199260 S = (R + H * P * H')
200261 K = P * H' * S^{-1}
201262 PP = P - K S K' = Sigma_cond
202- **Note! This can be done more efficiently when R is diagonal.**
203263
204264 Args:
205265 m (D_hid,): prior mean.
@@ -215,9 +275,25 @@ def _condition_on(m, P, H, D, d, R, u, y):
215275 mu_cond (D_hid,): conditioned mean.
216276 Sigma_cond (D_hid,D_hid): conditioned covariance.
217277 """
218- # Compute the Kalman gain
219- S = R + H @ P @ H .T
220- K = psd_solve (S , H @ P ).T
278+ if R .ndim == 2 :
279+ S = R + H @ P @ H .T
280+ K = psd_solve (S , H @ P ).T
281+ else :
282+ # Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
283+ # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
284+ I = jnp .eye (P .shape [0 ])
285+ U = H @ jnp .linalg .cholesky (P )
286+ X = U / R [:, None ]
287+ S_inv = jnp .diag (1.0 / R ) - X @ psd_solve (I + U .T @ X , X .T )
288+ """
289+ # Could alternatively use U=H and C=P
290+ R_inv = jnp.diag(1.0 / R)
291+ P_inv = psd_solve(P, jnp.eye(P.shape[0]))
292+ S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
293+ """
294+ K = P @ H .T @ S_inv
295+ S = jnp .diag (R ) + H @ P @ H .T
296+
221297 Sigma_cond = P - K @ S @ K .T
222298 mu_cond = m + K @ (y - D @ u - d - H @ m )
223299 return mu_cond , symmetrize (Sigma_cond )
@@ -285,6 +361,8 @@ def wrapper(*args, **kwargs):
285361 return wrapper
286362
287363
364+
365+
288366def lgssm_joint_sample (
289367 params : ParamsLGSSM ,
290368 key : PRNGKey ,
@@ -302,7 +380,6 @@ def lgssm_joint_sample(
302380 latent states and emissions
303381
304382 """
305-
306383 params , inputs = preprocess_params_and_inputs (params , num_timesteps , inputs )
307384
308385 def _sample_transition (key , F , B , b , Q , x_tm1 , u ):
@@ -311,17 +388,15 @@ def _sample_transition(key, F, B, b, Q, x_tm1, u):
311388
312389 def _sample_emission (key , H , D , d , R , x , u ):
313390 mean = H @ x + D @ u + d
391+ R = jnp .diag (R ) if R .ndim == 1 else R
314392 return MVN (mean , R ).sample (seed = key )
315393
316394 def _sample_initial (key , params , inputs ):
317395 key1 , key2 = jr .split (key )
318396
319397 initial_state = MVN (params .initial .mean , params .initial .cov ).sample (seed = key1 )
320398
321- H0 = _get_params (params .emissions .weights , 2 , 0 )
322- D0 = _get_params (params .emissions .input_weights , 2 , 0 )
323- d0 = _get_params (params .emissions .bias , 1 , 0 )
324- R0 = _get_params (params .emissions .cov , 2 , 0 )
399+ H0 , D0 , d0 , R0 = _get_params (params , num_timesteps , 0 )[4 :]
325400 u0 = tree_map (lambda x : x [0 ], inputs )
326401
327402 initial_emission = _sample_emission (key2 , H0 , D0 , d0 , R0 , initial_state , u0 )
@@ -331,15 +406,8 @@ def _step(prev_state, args):
331406 key , t , inpt = args
332407 key1 , key2 = jr .split (key , 2 )
333408
334- # Shorthand: get parameters and inputs for time index t
335- F = _get_params (params .dynamics .weights , 2 , t )
336- B = _get_params (params .dynamics .input_weights , 2 , t )
337- b = _get_params (params .dynamics .bias , 1 , t )
338- Q = _get_params (params .dynamics .cov , 2 , t )
339- H = _get_params (params .emissions .weights , 2 , t )
340- D = _get_params (params .emissions .input_weights , 2 , t )
341- d = _get_params (params .emissions .bias , 1 , t )
342- R = _get_params (params .emissions .cov , 2 , t )
409+ # Get parameters and inputs for time index t
410+ F , B , b , Q , H , D , d , R = _get_params (params , num_timesteps , t )
343411
344412 # Sample from transition and emission distributions
345413 state = _sample_transition (key1 , F , B , b , Q , prev_state , inpt )
@@ -386,23 +454,26 @@ def lgssm_filter(
386454 num_timesteps = len (emissions )
387455 inputs = jnp .zeros ((num_timesteps , 0 )) if inputs is None else inputs
388456
457+ def _log_likelihood (pred_mean , pred_cov , H , D , d , R , u , y ):
458+ m = H @ pred_mean + D @ u + d
459+ if R .ndim == 2 :
460+ S = R + H @ pred_cov @ H .T
461+ return MVN (m , S ).log_prob (y )
462+ else :
463+ L = H @ jnp .linalg .cholesky (pred_cov )
464+ return MVNLowRank (m , R , L ).log_prob (y )
465+
466+
389467 def _step (carry , t ):
390468 ll , pred_mean , pred_cov = carry
391469
392470 # Shorthand: get parameters and inputs for time index t
393- F = _get_params (params .dynamics .weights , 2 , t )
394- B = _get_params (params .dynamics .input_weights , 2 , t )
395- b = _get_params (params .dynamics .bias , 1 , t )
396- Q = _get_params (params .dynamics .cov , 2 , t )
397- H = _get_params (params .emissions .weights , 2 , t )
398- D = _get_params (params .emissions .input_weights , 2 , t )
399- d = _get_params (params .emissions .bias , 1 , t )
400- R = _get_params (params .emissions .cov , 2 , t )
471+ F , B , b , Q , H , D , d , R = _get_params (params , num_timesteps , t )
401472 u = inputs [t ]
402473 y = emissions [t ]
403474
404475 # Update the log likelihood
405- ll += MVN ( H @ pred_mean + D @ u + d , H @ pred_cov @ H . T + R ). log_prob ( y )
476+ ll += _log_likelihood ( pred_mean , pred_cov , H , D , d , R , u , y )
406477
407478 # Condition on this emission
408479 filtered_mean , filtered_cov = _condition_on (pred_mean , pred_cov , H , D , d , R , u , y )
@@ -450,11 +521,8 @@ def _step(carry, args):
450521 smoothed_mean_next , smoothed_cov_next = carry
451522 t , filtered_mean , filtered_cov = args
452523
453- # Shorthand: get parameters and inputs for time index t
454- F = _get_params (params .dynamics .weights , 2 , t )
455- B = _get_params (params .dynamics .input_weights , 2 , t )
456- b = _get_params (params .dynamics .bias , 1 , t )
457- Q = _get_params (params .dynamics .cov , 2 , t )
524+ # Get parameters and inputs for time index t
525+ F , B , b , Q = _get_params (params , num_timesteps , t )[:4 ]
458526 u = inputs [t ]
459527
460528 # This is like the Kalman gain but in reverse
@@ -522,10 +590,7 @@ def _step(carry, args):
522590 key , filtered_mean , filtered_cov , t = args
523591
524592 # Shorthand: get parameters and inputs for time index t
525- F = _get_params (params .dynamics .weights , 2 , t )
526- B = _get_params (params .dynamics .input_weights , 2 , t )
527- b = _get_params (params .dynamics .bias , 1 , t )
528- Q = _get_params (params .dynamics .cov , 2 , t )
593+ F , B , b , Q = _get_params (params , num_timesteps , t )[:4 ]
529594 u = inputs [t ]
530595
531596 # Condition on next state
0 commit comments