diff --git a/paconvert/global_var.py b/paconvert/global_var.py
index 25175e61b..03b8b9bf9 100644
--- a/paconvert/global_var.py
+++ b/paconvert/global_var.py
@@ -684,16 +684,10 @@ class GlobalManager:
         "torch.scatter",

         # xiangyu
-        # "torch.cuda.current_device",
-        # "torch.cuda.device_count",
-        # "torch.cuda.empty_cache",
         "torch.cuda.get_device_properties",
         "torch.cuda.get_rng_state",
         "torch.cuda.is_current_stream_capturing",
         "torch.cuda.manual_seed_all",
-        # "torch.cuda.memory_allocated",
-        # "torch.cuda.memory_reserved",
-        # "torch.cuda.set_device",
         "torch.cuda.set_rng_state",
         # "torch.get_default_device",
         "torch.get_device_module",
@@ -701,4 +695,13 @@ class GlobalManager:
         "torch.device",
         #Additional additions
         "torch.cuda.get_device_capability",
+
+        # geyuqiang
+        "torch.cuda.current_device",
+        "torch.cuda.device_count",
+        "torch.cuda.empty_cache",
+        "torch.cuda.memory_allocated",
+        "torch.cuda.memory_reserved",
+        "torch.cuda.set_device",
+        "torch.cuda.current_stream",
     ]
diff --git a/tests/test_cuda_current_device.py b/tests/test_cuda_current_device.py
index 27ba4a24c..53583be69 100644
--- a/tests/test_cuda_current_device.py
+++ b/tests/test_cuda_current_device.py
@@ -32,7 +32,7 @@ def compare(
     rtol=1.0e-6,
     atol=0.0,
 ):
-    assert pytorch_result == int(paddle_result.replace("gpu:", ""))
+    assert pytorch_result == paddle_result


 obj = CudaGetDeviceAPIBase("torch.cuda.current_device")
diff --git a/tests/test_cuda_set_device.py b/tests/test_cuda_set_device.py
index 8b27f4463..0fd94a9c8 100644
--- a/tests/test_cuda_set_device.py
+++ b/tests/test_cuda_set_device.py
@@ -29,7 +29,7 @@ def test_case_1():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        torch.cuda.set_device(0)
+        torch.cuda.set_device("cuda:0")
         result = torch.cuda.current_device()
         """
     )