 import logging
 import threading
 import time
+import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
+from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
 from unittest import TestCase

 import torch
 import torch.distributed as dist
 from parameterized import parameterized
 from torch import nn, optim
+from torch._dynamo.utils import timed

 from torchft._torchft import LighthouseServer
 from torchft.ddp import DistributedDataParallel
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupGloo
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

 logger: logging.Logger = logging.getLogger(__name__)

@@ -69,10 +71,14 @@ def check(self, rank: int, step: int) -> None:
                 raise InjectedFailure(f"injected failure {rank=} {step=}")


-class TrainLoop(Protocol):
+# R is the train loop's return type (covariant, so a TrainLoop[Tensor] is
+# usable where a TrainLoop[object] is expected)
+R = TypeVar("R", covariant=True)
+
+
+class TrainLoop(Protocol[R]):
     def __call__(
         self, rank: int, store_port: int, device: torch.device, runner: "Runner"
-    ) -> Dict[str, Dict[str, object]]: ...
+    ) -> R: ...


 @dataclass
@@ -81,15 +87,15 @@ class Runner:
     num_replicas: int
     lighthouse_address: str
     failure_injector: FailureInjector
-    train_loop: TrainLoop
+    train_loop: TrainLoop[object]

     use_cuda: bool = False
     world_size: int = 1
     attempts: int = 3
     manager_args: Dict[str, object] = field(default_factory=dict)
     train_loop_args: Dict[str, Any] = field(default_factory=dict)

-    def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
+    def _replica_main(self) -> List[object]:
         store = dist.TCPStore(
             host_name="localhost",
             port=0,
@@ -131,7 +137,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:

         return [fut.result() for fut in futures]

-    def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
+    def run_replica(self) -> List[object]:
         for i in range(self.attempts):
             try:
                 print(
@@ -391,3 +397,92 @@ def test_quorum_timeout(self) -> None:
             "status: Cancelled, message.*Timeout expired",
         ):
             manager.should_commit(timeout=timedelta(seconds=0.01))
+
+    @parameterized.expand(
+        [
+            (True,),  # Test with CUDA
+            (False,),  # Test without CUDA (CPU)
+        ]
+    )
+    def test_manager_allreduce(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
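+        # (two GPUs because the test runs two replicas, each on its own device)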
+
+        # The manager supports allreduce, but we found an issue where the future
+        # callback fired before the allreduce had completed. This test ensures
+        # that the callback performs stream synchronization.
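+        # (With NCCL the collective runs on a separate CUDA stream; a callback
+        # that reads the result without first synchronizing with that stream
+        # can observe incomplete data.)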
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=all_reduce_callback,
+                    use_cuda=use_cuda,
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            results = []
+            for fut in as_completed(futures):
+                try:
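+                    # fut.result() re-raises any exception from the replica;
+                    # run_replica returns one result per local rank, so [0] is
+                    # rank 0's tensor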
+                    results.append(fut.result()[0])
+                except Exception as e:
+                    print(e, flush=True)
+                    traceback.print_exc()
+                    raise
+
+        lighthouse.shutdown()
+
+        print(results)
+        r0, r1 = results
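+        # in the CUDA variant the replicas ran on different devices
+        # (cuda:0 / cuda:1), so compare values only, not device placement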
+        torch.testing.assert_close(r0, r1, check_device=False)
+
+
+def all_reduce_callback(
+    rank: int,
+    store_port: int,
+    device: torch.device,
+    runner: Runner,
+) -> Optional[torch.Tensor]:
+    with ExitStack() as stack:
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
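+        # Use NCCL for GPU tensors (ProcessGroupBabyNCCL runs the NCCL group in
+        # a subprocess so a wedged collective can be recovered from); Gloo
+        # suffices on CPU.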
+        if device.type == "cuda":
+            pg = ProcessGroupBabyNCCL()
+        else:
+            pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=lambda x: None,
+            state_dict=lambda: None,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
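+            # both replicas share this host, so give each manager its own port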
+            port=19530 + runner.replica_id,
+            timeout=timedelta(seconds=10),
+            quorum_timeout=timedelta(seconds=10),
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
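+        # guarantee the manager is shut down even if the allreduce below raises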
+        stack.callback(lambda: manager.shutdown(wait=False))
+
+        manager.start_quorum()
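+        # All-reduce a ones tensor across the replicas. wait() blocks until the
+        # collective (including any pending GPU stream work) has finished, so
+        # t1 holds the reduced value when it is returned.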
+        t1 = torch.ones((1, 3), device=device)
+        fut = manager.allreduce(t1)
+        fut.wait()
+        return t1
+    return None  # not reached: the with-block above always returns