Skip to content

Commit 76ea91b

Browse files
add sparse method on pytorch backend
1 parent 783216b commit 76ea91b

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Add `tc.AnalogCircuit` for digital-analog hybrid simulation.
88

9+
- Add sparse matrix related methods for pytorch backend.
10+
911
## v1.4.0
1012

1113
### Added

tensorcircuit/backends/pytorch_backend.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from operator import mul
1010
from functools import reduce, partial
1111

12+
from scipy.sparse import coo_matrix
1213
import tensornetwork
1314
from tensornetwork.backends.pytorch import pytorch_backend
1415
from .abstract_backend import ExtendedBackend
@@ -302,6 +303,9 @@ def kron(self, a: Tensor, b: Tensor) -> Tensor:
302303
return torchlib.kron(a, b)
303304

304305
def numpy(self, a: Tensor) -> Tensor:
306+
if self.is_sparse(a):
307+
a = a.coalesce()
308+
return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape)
305309
a = a.cpu()
306310
if a.is_conj():
307311
return a.resolve_conj().numpy()
@@ -381,6 +385,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
381385
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
382386
return torchlib.sort(a, dim=axis).values
383387

388+
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
389+
return torchlib.argsort(a, dim=axis)
390+
384391
def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
385392
"""
386393
Corresponds to torch.all.
@@ -467,6 +474,39 @@ def where(
467474
def reverse(self, a: Tensor) -> Tensor:
468475
return torchlib.flip(a, dims=(-1,))
469476

477+
def coo_sparse_matrix(
478+
self, indices: Tensor, values: Tensor, shape: Tensor
479+
) -> Tensor:
480+
# Convert COO format to PyTorch sparse tensor
481+
indices = self.convert_to_tensor(indices)
482+
return torchlib.sparse_coo_tensor(self.transpose(indices), values, shape)
483+
484+
def sparse_dense_matmul(
485+
self,
486+
sp_a: Tensor,
487+
b: Tensor,
488+
) -> Tensor:
489+
# Matrix multiplication between sparse and dense tensor
490+
return torchlib.sparse.mm(sp_a, b)
491+
492+
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
493+
try:
494+
# Convert COO to CSR format if supported
495+
return coo.to_sparse_csr()
496+
except AttributeError as e:
497+
if not strict:
498+
return coo
499+
else:
500+
raise e
501+
502+
def to_dense(self, sp_a: Tensor) -> Tensor:
503+
# Convert sparse tensor to dense
504+
return sp_a.to_dense()
505+
506+
def is_sparse(self, a: Tensor) -> bool:
507+
# Check if tensor is sparse
508+
return a.is_sparse or a.is_sparse_csr # type: ignore
509+
470510
def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
471511
# torch native tree_map not support multiple pytree args
472512
# return torchlib.utils._pytree.tree_map(f, *pytrees)

tests/test_backends.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def f(x):
6161
np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
6262

6363

64-
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
64+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
6565
def test_sparse_csr_from_coo(backend):
6666
# Create a sparse matrix in COO format
6767
values = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0]))
@@ -583,7 +583,7 @@ def test_arg_cmp(backend):
583583
)
584584

585585

586-
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
586+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
587587
def test_argsort(backend):
588588
# Test basic argsort functionality
589589
a = tc.array_to_tensor(np.array([3, 1, 2]), dtype="float32")
@@ -987,7 +987,7 @@ def grad(A, x):
987987
np.testing.assert_allclose(n_grad, a_grad, atol=1e-3)
988988

989989

990-
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
990+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
991991
def test_sparse_methods(backend):
992992
values = tc.backend.convert_to_tensor(np.array([1.0, 2.0]))
993993
values = tc.backend.cast(values, "complex64")

0 commit comments

Comments
 (0)