File tree Expand file tree Collapse file tree 2 files changed +168
-34
lines changed Expand file tree Collapse file tree 2 files changed +168
-34
lines changed Original file line number Diff line number Diff line change 66import pathlib
77import shutil
88import tempfile
9+ import warnings
910
1011import numpy as np
1112import torch
@@ -32,6 +33,10 @@ def prod(sequence):
3233 return int (np .prod (sequence ))
3334
3435
def is_npu_available():
    """Check whether an Ascend NPU backend can be used from this process.

    Returns ``True`` only when ``torch`` exposes an ``npu`` namespace
    (i.e. the ``torch_npu`` extension is installed) and that backend
    reports at least one usable device; ``False`` otherwise.
    """
    # Guard first: stock PyTorch builds have no `torch.npu` attribute at all,
    # so probing availability directly would raise AttributeError.
    if not hasattr(torch, "npu"):
        return False
    return torch.npu.is_available()
39+
3540def get_available_devices ():
3641 devices = [torch .device ("cpu" )]
3742 n_cuda = torch .cuda .device_count ()
@@ -40,6 +45,16 @@ def get_available_devices():
4045 devices += [torch .device (f"cuda:{ i } " )]
4146 if i == 1 :
4247 break
48+ if is_npu_available ():
49+ warnings .warn (
50+ "torch_npu is an experimental feature and not currently included in tensordict CI/CD."
51+ )
52+ n_npu = torch .npu .device_count ()
53+ if n_npu > 0 :
54+ for i in range (n_npu ):
55+ devices += [torch .device (f"npu:{ i } " )]
56+ if i == 1 :
57+ break
4358 # TODO: MPS and NPU would be worth considering but it's a lot of work
4459 # for example, many ops are tested with various dtypes but MPS struggles with
4560 # float64. Shared mem can also cause trouble.
You can’t perform that action at this time.
0 commit comments