Skip to content

Commit 2c86497

Browse files
authored
Add scipy.linalg.lu() decomposition support (#2787)
This PR adds `dpnp.scipy.linalg.lu()` with support for all three output modes: default `(P, L, U)`, `permute_l=True (PL, U)`, and `p_indices=True` `(p, L, U)`, including batched inputs. Fixes: #2786
1 parent 7803d3a commit 2c86497

File tree

8 files changed

+886
-15
lines changed

8 files changed

+886
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
2626
* Added implementation of `dpnp.ndarray.__bytes__` method [#2671](https://github.com/IntelPython/dpnp/pull/2671)
2727
* Added implementation of `dpnp.divmod` [#2674](https://github.com/IntelPython/dpnp/pull/2674)
2828
* Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595)
29+
* Added implementation of `dpnp.scipy.linalg.lu` (SciPy-compatible) [#2787](https://github.com/IntelPython/dpnp/pull/2787)
2930

3031
### Changed
3132

dpnp/scipy/linalg/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
3636
"""
3737

38-
from ._decomp_lu import lu_factor, lu_solve
38+
from ._decomp_lu import lu, lu_factor, lu_solve
3939

4040
__all__ = [
41+
"lu",
4142
"lu_factor",
4243
"lu_solve",
4344
]

dpnp/scipy/linalg/_decomp_lu.py

Lines changed: 148 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,154 @@
4646
)
4747

4848
from ._utils import (
49+
dpnp_lu,
4950
dpnp_lu_factor,
5051
dpnp_lu_solve,
5152
)
5253

5354

55+
def lu(
56+
a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
57+
):
58+
"""
59+
Compute LU decomposition of a matrix with partial pivoting.
60+
61+
The decomposition satisfies::
62+
63+
A = P @ L @ U
64+
65+
where `P` is a permutation matrix, `L` is lower triangular with unit
66+
diagonal elements, and `U` is upper triangular. If `permute_l` is set to
67+
``True`` then `L` is returned already permuted and hence satisfying
68+
``A = L @ U``.
69+
70+
For full documentation refer to :obj:`scipy.linalg.lu`.
71+
72+
Parameters
73+
----------
74+
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
75+
Input array to decompose.
76+
permute_l : bool, optional
77+
Perform the multiplication ``P @ L`` (Default: do not permute).
78+
79+
Default: ``False``.
80+
overwrite_a : {None, bool}, optional
81+
Whether to overwrite data in `a` (may increase performance).
82+
83+
Default: ``False``.
84+
check_finite : {None, bool}, optional
85+
Whether to check that the input matrix contains only finite numbers.
86+
Disabling may give a performance gain, but may result in problems
87+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
88+
89+
Default: ``True``.
90+
p_indices : bool, optional
91+
If ``True`` the permutation information is returned as row indices
92+
instead of a permutation matrix.
93+
94+
Default: ``False``.
95+
96+
Returns
97+
-------
98+
**(If ``permute_l`` is ``False``)**
99+
100+
p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray
101+
If `p_indices` is ``False`` (default), the permutation matrix.
102+
The permutation matrix always has a real dtype (``float32`` or
103+
``float64``) even when `a` is complex, since it only contains
104+
0s and 1s.
105+
If `p_indices` is ``True``, a 1-D (or batched) array of row
106+
permutation indices such that ``A = L[p] @ U``.
107+
l : (..., M, K) dpnp.ndarray
108+
Lower triangular or trapezoidal matrix with unit diagonal.
109+
``K = min(M, N)``.
110+
u : (..., K, N) dpnp.ndarray
111+
Upper triangular or trapezoidal matrix.
112+
113+
**(If ``permute_l`` is ``True``)**
114+
115+
pl : (..., M, K) dpnp.ndarray
116+
Permuted ``L`` matrix: ``pl = P @ L``.
117+
``K = min(M, N)``.
118+
u : (..., K, N) dpnp.ndarray
119+
Upper triangular or trapezoidal matrix.
120+
121+
Notes
122+
-----
123+
Permutation matrices are costly since they are nothing but row reorder of
124+
``L`` and hence indices are strongly recommended to be used instead if the
125+
permutation is required. The relation in the 2D case then becomes simply
126+
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
127+
to avoid complicated indexing tricks.
128+
129+
In the 2D case, if one has the indices however, for some reason, the
130+
permutation matrix is still needed then it can be constructed by
131+
``dpnp.eye(M)[P, :]``.
132+
133+
Warning
134+
-------
135+
This function synchronizes in order to validate array elements
136+
when ``check_finite=True``, and also synchronizes to compute the
137+
permutation from LAPACK pivot indices.
138+
139+
See Also
140+
--------
141+
:obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix
142+
(compact representation).
143+
:obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using
144+
the LU factorization of a matrix.
145+
146+
Examples
147+
--------
148+
>>> import dpnp as np
149+
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8],
150+
... [7, 5, 6, 6], [5, 4, 4, 8]])
151+
>>> p, l, u = np.scipy.linalg.lu(A)
152+
>>> np.allclose(A, p @ l @ u)
153+
array(True)
154+
155+
Retrieve the permutation as row indices with ``p_indices=True``:
156+
157+
>>> p, l, u = np.scipy.linalg.lu(A, p_indices=True)
158+
>>> p
159+
array([1, 3, 0, 2])
160+
>>> np.allclose(A, l[p] @ u)
161+
array(True)
162+
163+
Return the permuted ``L`` directly with ``permute_l=True``:
164+
165+
>>> pl, u = np.scipy.linalg.lu(A, permute_l=True)
166+
>>> np.allclose(A, pl @ u)
167+
array(True)
168+
169+
Non-square matrices are supported:
170+
171+
>>> B = np.array([[1, 2, 3], [4, 5, 6]])
172+
>>> p, l, u = np.scipy.linalg.lu(B)
173+
>>> np.allclose(B, p @ l @ u)
174+
array(True)
175+
176+
Batched input:
177+
178+
>>> C = np.random.randn(3, 2, 4, 4)
179+
>>> p, l, u = np.scipy.linalg.lu(C)
180+
>>> np.allclose(C, p @ l @ u)
181+
array(True)
182+
183+
"""
184+
185+
dpnp.check_supported_arrays_type(a)
186+
assert_stacked_2d(a)
187+
188+
return dpnp_lu(
189+
a,
190+
overwrite_a=overwrite_a,
191+
check_finite=check_finite,
192+
p_indices=p_indices,
193+
permute_l=permute_l,
194+
)
195+
196+
54197
def lu_factor(a, overwrite_a=False, check_finite=True):
55198
"""
56199
Compute the pivoted LU decomposition of `a` matrix.
@@ -180,13 +323,13 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
180323
181324
"""
182325

183-
lu, piv = lu_and_piv
184-
dpnp.check_supported_arrays_type(lu, piv, b)
185-
assert_stacked_2d(lu)
186-
assert_stacked_square(lu)
326+
lu_matrix, piv = lu_and_piv
327+
dpnp.check_supported_arrays_type(lu_matrix, piv, b)
328+
assert_stacked_2d(lu_matrix)
329+
assert_stacked_square(lu_matrix)
187330

188331
return dpnp_lu_solve(
189-
lu,
332+
lu_matrix,
190333
piv,
191334
b,
192335
trans=trans,

0 commit comments

Comments
 (0)