Skip to content

Commit 554f009

Browse files
committed
Disable use_cuda for local_sgd_integ_tests
1 parent 3724f7c commit 554f009

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchft/local_sgd_integ_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torchft.local_sgd import DiLoCo, LocalSGD
1717
from torchft.manager import Manager
1818
from torchft.manager_integ_test import FailureInjector, MyModel, Runner
19-
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
19+
from torchft.process_group import ProcessGroupGloo, ProcessGroupBabyNCCL
2020

2121
logger: logging.Logger = logging.getLogger(__name__)
2222

@@ -197,9 +197,11 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
197197

198198

199199
class LocalSGDIntegTest(TestCase):
200+
# TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
201+
# Because of that the test is disabled for cuda
200202
@parameterized.expand(
201203
[
202-
(True,),
204+
# (True,),
203205
(False,),
204206
]
205207
)
@@ -259,7 +261,7 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
259261

260262
@parameterized.expand(
261263
[
262-
(True,),
264+
# (True,),
263265
(False,),
264266
]
265267
)
@@ -319,7 +321,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
319321

320322
@parameterized.expand(
321323
[
322-
(True,),
324+
# (True,),
323325
(False,),
324326
]
325327
)

0 commit comments

Comments
 (0)