 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -586,29 +589,59 @@ def __init__(
         self._timeout = timeout

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
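+        # The event is only set for CUDA ops: waiting on it blocks the
+        # caller's current stream, not the CPU, until the worker's op is done.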
+        if event is not None:
+            event.wait()
         return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this, and NCCL's wait already only blocks
+        # the stream rather than the CPU, so there is no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns True if any of the tensors in the object are CUDA tensors.

-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False

-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
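+        """Makes this op's stream current for the duration of the block, if set."""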
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess, all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)

@@ -679,7 +709,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +753,76 @@ def _worker(
                 return
             tx.put(None)

-            work = {}
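+            # Maps the parent process's stream identity (device + stream_id)
+            # to a worker-local CUDA stream that mirrors it.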
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0

             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross-stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
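+                    # Run the op on the mirrored stream when one exists;
+                    # otherwise fall back to the default stream behavior.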
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the CUDA event to make sure we
+                        # don't start the operation until the tensors are ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait, not the CPU,
+                        # when no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register an event on the stream that we can pass to
+                        # the main process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
@@ -746,23 +836,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -792,6 +867,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +881,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut

     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None

+        is_cuda = _is_any_cuda(args)
+
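+        # Capture the identity of the caller's current stream and record an
+        # interprocess event on it so the worker can mirror the stream and
+        # order the op after any pending work on the input tensors.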
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )

         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and that any
+        errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} {p.exitcode=}")

     def allreduce(
         self,
@@ -952,8 +1059,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL