Skip to content

Commit cadec55

Browse files
authored
Merge pull request #335 from calebweinreb/stable_psd_solve
Stable psd solve
2 parents 4641994 + 44d8c7d commit cadec55

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

dynamax/utils/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from jaxtyping import Array, Int
1010
from scipy.optimize import linear_sum_assignment
1111
from typing import Optional
12+
from jax.scipy.linalg import cho_factor, cho_solve
1213

1314
def has_tpu():
1415
try:
@@ -198,10 +199,12 @@ def find_permutation(
198199
return perm
199200

200201

201-
def psd_solve(A,b):
202+
def psd_solve(A, b, diagonal_boost=1e-9):
202203
"""A wrapper for coordinating the linalg solvers used in the library for psd matrices."""
203-
A = A + 1e-6
204-
return jnp.linalg.solve(A,b)
204+
A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1])
205+
L, lower = cho_factor(A, lower=True)
206+
x = cho_solve((L, lower), b)
207+
return x
205208

206209
def symmetrize(A):
207210
"""Symmetrize one or more matrices."""

0 commit comments

Comments
 (0)