@@ -19,17 +19,18 @@
 import logging
 import queue
 import threading
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -58,9 +59,9 @@
     BroadcastOptions,
     ReduceOp,
     Work,
-    _world,
 )
 from torch.futures import Future
+from torch.utils._pytree import tree_any

 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -586,29 +587,52 @@ def __init__(
         self._timeout = timeout

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
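+        # The worker returns an interprocess CUDA event for ops that ran on a
+        # stream; waiting on it blocks the current stream, not the CPU, until
+        # the op has completed.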
+        if event is not None:
+            event.wait()
         return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.

-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
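+
+    Example (illustrative):
+        _is_any_cuda([torch.zeros(1), {"x": torch.zeros(1, device="cuda")}])
+        returns True, since one leaf tensor lives on a CUDA device.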
+    """
+    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)

-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
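+    # set_stream() makes the op's own stream current inside the with-block so
+    # that wait() and event recording target the same stream the collective
+    # was launched on.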
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)

@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                 return
             tx.put(None)

-            work = {}
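+            # One cached CUDA Stream per parent (device, stream_id) pair, plus
+            # metadata for every in-flight op, keyed by op_id.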
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0

             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross-stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait, not the CPU,
+                        # when no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register an event on the stream that we can pass to
+                        # the main process.
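+                        # interprocess=True allows the event handle to be
+                        # shared with the parent process via the queue.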
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
         except Exception as e:
             logger.exception("worker errored")
             tx.put(e)
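+            # Re-raising makes the child exit non-zero so that the failure is
+            # visible to _assert_alive() in the parent process.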
+            raise

     def _future_handler(self, future_queue: mp.Queue) -> None:
         try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut

     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None

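+        # Snapshot the parent's current stream identity and record an
+        # interprocess event on it so the worker can run the op on a matching
+        # stream and wait until the input tensors are ready before starting.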
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )

         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and that any
+        errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} {p.exitcode=}")

     def allreduce(
         self,
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
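+
+    Illustrative usage sketch (assumes a TCPStore is reachable at
+    store_addr):
+
+        pg = ProcessGroupBabyNCCL(timeout=60.0)
+        pg.configure(store_addr, rank, world_size)
+        work = pg.allreduce([tensor], ReduceOp.SUM)
+        work.wait()  # with NCCL this waits on the stream, not the CPU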
     """

-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL