File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change 99from jaxtyping import Array , Int
1010from scipy .optimize import linear_sum_assignment
1111from typing import Optional
12+ from jax .scipy .linalg import cho_factor , cho_solve
1213
1314def has_tpu ():
1415 try :
@@ -198,10 +199,12 @@ def find_permutation(
198199 return perm
199200
200201
201- def psd_solve (A ,b ):
202+ def psd_solve (A , b , diagonal_boost = 1e-9 ):
202203 """A wrapper for coordinating the linalg solvers used in the library for psd matrices."""
203- A = A + 1e-6
204- return jnp .linalg .solve (A ,b )
204+ A = symmetrize (A ) + diagonal_boost * jnp .eye (A .shape [- 1 ])
205+ L , lower = cho_factor (A , lower = True )
206+ x = cho_solve ((L , lower ), b )
207+ return x
205208
206209def symmetrize (A ):
207210 """Symmetrize one or more matrices."""
You can’t perform that action at this time.
0 commit comments