Skip to content

Commit b59ca7c

Browse files
Authored by Chen666, chenhao388, and vmoens
[Feature] Extended support to include NPU use cases in addition to existing CUDA-compatible scenarios (#1460)
Co-authored-by: chenhao388 <[email protected]> Co-authored-by: Vincent Moens <[email protected]>
1 parent 3e856cb commit b59ca7c

File tree

2 files changed

+168
-34
lines changed

2 files changed

+168
-34
lines changed

test/_utils_internal.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pathlib
77
import shutil
88
import tempfile
9+
import warnings
910

1011
import numpy as np
1112
import torch
@@ -32,6 +33,10 @@ def prod(sequence):
3233
return int(np.prod(sequence))
3334

3435

36+
def is_npu_available():
    """Return True if a torch NPU backend is present and reports a usable device.

    Guard clause first: torch builds without the (experimental) NPU backend
    have no ``torch.npu`` attribute at all, so probe for it before calling
    ``is_available()``.
    """
    if not hasattr(torch, "npu"):
        return False
    return torch.npu.is_available()
38+
39+
3540
def get_available_devices():
3641
devices = [torch.device("cpu")]
3742
n_cuda = torch.cuda.device_count()
@@ -40,6 +45,16 @@ def get_available_devices():
4045
devices += [torch.device(f"cuda:{i}")]
4146
if i == 1:
4247
break
48+
if is_npu_available():
49+
warnings.warn(
50+
"torch_npu is an experimental feature and not currently included in tensordict CI/CD."
51+
)
52+
n_npu = torch.npu.device_count()
53+
if n_npu > 0:
54+
for i in range(n_npu):
55+
devices += [torch.device(f"npu:{i}")]
56+
if i == 1:
57+
break
4358
# TODO: MPS and NPU would be worth considering but it's a lot of work
4459
# for example, many ops are tested with various dtypes but MPS struggles with
4560
# float64. Shared mem can also cause trouble.

0 commit comments

Comments (0)