|
9 | 9 | from operator import mul
|
10 | 10 | from functools import reduce, partial
|
11 | 11 |
|
| 12 | +from scipy.sparse import coo_matrix |
12 | 13 | import tensornetwork
|
13 | 14 | from tensornetwork.backends.pytorch import pytorch_backend
|
14 | 15 | from .abstract_backend import ExtendedBackend
|
@@ -302,6 +303,9 @@ def kron(self, a: Tensor, b: Tensor) -> Tensor:
|
302 | 303 | return torchlib.kron(a, b)
|
303 | 304 |
|
304 | 305 | def numpy(self, a: Tensor) -> Tensor:
|
| 306 | + if self.is_sparse(a): |
| 307 | + a = a.coalesce() |
| 308 | + return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape) |
305 | 309 | a = a.cpu()
|
306 | 310 | if a.is_conj():
|
307 | 311 | return a.resolve_conj().numpy()
|
@@ -381,6 +385,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
|
381 | 385 | def sort(self, a: Tensor, axis: int = -1) -> Tensor:
|
382 | 386 | return torchlib.sort(a, dim=axis).values
|
383 | 387 |
|
| 388 | + def argsort(self, a: Tensor, axis: int = -1) -> Tensor: |
| 389 | + return torchlib.argsort(a, dim=axis) |
| 390 | + |
384 | 391 | def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
|
385 | 392 | """
|
386 | 393 | Corresponds to torch.all.
|
@@ -467,6 +474,39 @@ def where(
|
467 | 474 | def reverse(self, a: Tensor) -> Tensor:
|
468 | 475 | return torchlib.flip(a, dims=(-1,))
|
469 | 476 |
|
| 477 | + def coo_sparse_matrix( |
| 478 | + self, indices: Tensor, values: Tensor, shape: Tensor |
| 479 | + ) -> Tensor: |
| 480 | + # Convert COO format to PyTorch sparse tensor |
| 481 | + indices = self.convert_to_tensor(indices) |
| 482 | + return torchlib.sparse_coo_tensor(self.transpose(indices), values, shape) |
| 483 | + |
| 484 | + def sparse_dense_matmul( |
| 485 | + self, |
| 486 | + sp_a: Tensor, |
| 487 | + b: Tensor, |
| 488 | + ) -> Tensor: |
| 489 | + # Matrix multiplication between sparse and dense tensor |
| 490 | + return torchlib.sparse.mm(sp_a, b) |
| 491 | + |
| 492 | + def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor: |
| 493 | + try: |
| 494 | + # Convert COO to CSR format if supported |
| 495 | + return coo.to_sparse_csr() |
| 496 | + except AttributeError as e: |
| 497 | + if not strict: |
| 498 | + return coo |
| 499 | + else: |
| 500 | + raise e |
| 501 | + |
| 502 | + def to_dense(self, sp_a: Tensor) -> Tensor: |
| 503 | + # Convert sparse tensor to dense |
| 504 | + return sp_a.to_dense() |
| 505 | + |
| 506 | + def is_sparse(self, a: Tensor) -> bool: |
| 507 | + # Check if tensor is sparse |
| 508 | + return a.is_sparse or a.is_sparse_csr # type: ignore |
| 509 | + |
470 | 510 | def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
|
471 | 511 | # torch native tree_map not support multiple pytree args
|
472 | 512 | # return torchlib.utils._pytree.tree_map(f, *pytrees)
|
|
0 commit comments