2020import threading
2121from abc import ABC
2222from datetime import timedelta
23- from typing import TYPE_CHECKING , Dict , List , Optional , Type
23+ from typing import Dict , List , Optional , Tuple , Type , TYPE_CHECKING , Union
2424
2525import torch
2626import torch .distributed as dist
3131from torch .distributed import (
3232 BroadcastOptions ,
3333 DeviceMesh ,
34+ get_rank ,
35+ init_device_mesh ,
3436 PrefixStore ,
3537 ProcessGroup as BaseProcessGroup ,
3638 ProcessGroupGloo as BaseProcessGroupGloo ,
3739 ProcessGroupNCCL as BaseProcessGroupNCCL ,
3840 Store ,
3941 TCPStore ,
40- get_rank ,
4142)
42- from torch .distributed .distributed_c10d import Work , _world
43+ from torch .distributed .distributed_c10d import _world , Work
4344from torch .futures import Future
4445
4546if TYPE_CHECKING :
@@ -130,17 +131,7 @@ def size(self) -> int:
130131 def getBackendName (self ) -> str :
131132 raise NotImplementedError ("not implemented" )
132133
133- def register (self , name : str ) -> "ProcessGroup" :
134- """
135- Registers the process group with the global registry. This enables usage
136- with things like functional_collectives which are compilable.
137-
138- This should only be called once.
139-
140- Args:
141- name: name must be a unique name for this process group
142- """
143-
134+ def _register (self , name : str ) -> str :
144135 group_name = f"{ self .getBackendName ()} :{ name } "
145136
146137 # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
158149 devices = ["cpu" ]
159150 dist .Backend .register_backend (group_name , create_pg , devices = devices )
160151
152+ return group_name
153+
154+ def register (self , name : str ) -> "ProcessGroup" :
155+ """
156+ Registers the process group with the global registry. This enables usage
157+ with things like functional_collectives which are compilable.
158+
159+ This should only be called once.
160+
161+ Args:
162+ name: name must be a unique name for this process group
163+ """
164+
165+ group_name = self ._register (name )
166+
161167 return dist .new_group (
162168 ranks = [dist .get_rank ()],
163169 backend = group_name ,
@@ -244,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244250 This is a reconfigurable version of ProcessGroupGloo.
245251 """
246252
247- PG_CLASS : Type [BaseProcessGroup ] = (
248- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249- )
253+ PG_CLASS : Type [
254+ BaseProcessGroup
255+ ] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
250256
def getBackendName(self) -> str:
    """Return the unique c10d backend name for the reconfigurable Gloo PG."""
    return "torchft-gloo"
@@ -263,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263269 abort when reconfiguring, we need to ensure this is safe.
264270 """
265271
266- PG_CLASS : Type [BaseProcessGroup ] = (
267- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268- )
272+ PG_CLASS : Type [
273+ BaseProcessGroup
274+ ] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
269275
def getBackendName(self) -> str:
    """Return the unique c10d backend name for the reconfigurable NCCL PG."""
    return "torchft-nccl"
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
def size(self) -> int:
    """World size as seen by the manager: the current participant count."""
    return self._manager.num_participants()
498504
def getBackendName(self) -> str:
    """Delegate to the backend name of the process group wrapped by the manager."""
    return self._manager._pg.getBackendName()
507+
499508
500509class _BabyWork (Work ):
501510 def __init__ (
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
689698 logger .exception (f"got unexpected error in future handler: { e } " )
690699
691700 def _get_future (self , op_id : int ) -> Future [object ]:
692-
693701 with self ._futures_lock :
694702 fut = Future () # pyre-fixme[29]: is not a function
695703 self ._futures [op_id ] = fut
@@ -737,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
737745 ProcessGroupBabyNCCL.
738746 """
739747
740- PG_CLASS : Type [BaseProcessGroup ] = (
741- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
742- )
748+ PG_CLASS : Type [
749+ BaseProcessGroup
750+ ] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
743751
def getBackendName(self) -> str:
    """Return the unique c10d backend name for the subprocess-backed Gloo PG."""
    return "torchft-baby-gloo"
@@ -761,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
761769 tensors may leak in the current PyTorch implementation. TODO fix
762770 """
763771
764- PG_CLASS : Type [BaseProcessGroup ] = (
765- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
766- )
772+ PG_CLASS : Type [
773+ BaseProcessGroup
774+ ] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
767775 WORK_CLASS = _BabyWorkNCCL
768776
769777 def getBackendName (self ) -> str :
@@ -797,3 +805,184 @@ def extend_device_mesh(
797805 mesh = mesh .mesh .unsqueeze (dim ),
798806 mesh_dim_names = tuple (mesh_dim_names ),
799807 )
808+
809+
class ManagedDeviceMesh(DeviceMesh):
    """A ``DeviceMesh`` that injects a fault-tolerant replicate dimension.

    Wraps an optional inner ``DeviceMesh`` (which covers every dimension
    *except* the replicate one) plus a ``ManagedProcessGroup`` standing in
    for the replicate dimension, so collectives along that dimension are
    routed through the manager-controlled process group.

    Args:
        mesh: inner mesh without the replicate dim, or ``None`` when this
            instance represents only the replicate dim.
        mesh_dim_names: names of ALL dims, including the replicate dim.
        replicate_pg: managed process group backing the replicate dim.
        replicate_dim: index of the replicate dim within ``mesh_dim_names``.
        parent: the mesh this one was sliced from, or ``None`` for the root.
    """

    def __init__(
        self,
        mesh: Optional[DeviceMesh],
        mesh_dim_names: Tuple[str, ...],
        replicate_pg: ManagedProcessGroup,
        replicate_dim: int,
        parent: Optional["ManagedDeviceMesh"],
    ) -> None:
        self.mesh = mesh
        self.mesh_dim_names = mesh_dim_names
        self.replicate_pg = replicate_pg
        self.replicate_dim = replicate_dim
        self.replicate_dim_name = mesh_dim_names[replicate_dim]
        self.parent = parent
        # Flattened sub-meshes created via _flatten(), keyed by the flat name.
        self.flatten_meshes: Dict[str, DeviceMesh] = {}

    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
        """Slice the mesh by one dim name or a tuple of dim names."""
        if isinstance(mesh_dim_names, str):
            if mesh_dim_names == self.replicate_dim_name:
                # Slicing out the replicate dim yields a mesh backed solely
                # by the managed process group.
                return ManagedDeviceMesh(
                    mesh=None,
                    mesh_dim_names=(mesh_dim_names,),
                    replicate_pg=self.replicate_pg,
                    replicate_dim=0,
                    parent=self,
                )
            elif mesh_dim_names in self.flatten_meshes:
                return self.flatten_meshes[mesh_dim_names]
            else:
                return self.mesh[mesh_dim_names]
        else:
            assert isinstance(mesh_dim_names, tuple)
            if self.replicate_dim_name not in mesh_dim_names:
                # No replicate dim requested: defer entirely to the inner mesh.
                return self.mesh[mesh_dim_names]
            else:
                # Fix: the original executed `mesh_dim_name.index(...)` (a
                # NameError) in the branch where the replicate name was NOT
                # present (so `.index()` would have raised ValueError anyway),
                # and indexed the inner mesh with a tuple containing the
                # replicate name it does not have. Slice the inner mesh by
                # the non-replicate names and wrap the result.
                inner_names = tuple(
                    n for n in mesh_dim_names if n != self.replicate_dim_name
                )
                return ManagedDeviceMesh(
                    self.mesh[inner_names],
                    mesh_dim_names,
                    self.replicate_pg,
                    mesh_dim_names.index(self.replicate_dim_name),
                    parent=self,
                )

    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
        """Return the process group for ``mesh_dim``.

        With ``mesh_dim=None`` this is only valid when the mesh wraps the
        replicate dim alone (``self.mesh is None``).
        """
        if mesh_dim is None:
            assert self.mesh is None
            return self.replicate_pg
        elif mesh_dim == self.replicate_dim_name:
            return self.replicate_pg
        else:
            return self.mesh.get_group(mesh_dim)

    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
        """Create a flattened view; registered on the root mesh so that
        ``__getitem__`` lookups by the flat name find it."""
        flatten_mesh = _FlattenDeviceMesh(self)
        if self.parent is None:
            self.flatten_meshes[mesh_dim_name] = flatten_mesh
        else:
            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
        return flatten_mesh

    def size(self, mesh_dim: Optional[int] = None) -> int:
        """Total size, or the size of one dim when ``mesh_dim`` is given."""
        if mesh_dim is None:
            if self.mesh is None:
                return self.replicate_pg.size()
            else:
                return self.mesh.size() * self.replicate_pg.size()
        elif mesh_dim == self.replicate_dim:
            return self.replicate_pg.size()
        else:
            # NOTE(review): the inner mesh lacks the replicate dim, so for
            # indices past replicate_dim this is off by one — TODO confirm
            # callers only pass inner-mesh-relative indices here.
            return self.mesh.size(mesh_dim)

    @property
    def ndim(self) -> int:
        # Inner mesh dims plus the replicate dim.
        return self.mesh.ndim + 1

    @property
    def shape(self) -> Tuple[int, ...]:
        ret = list(self.mesh.shape)
        ret.insert(self.replicate_dim, self.replicate_pg.size())
        # Fix: original built `ret` but fell through and returned None.
        return tuple(ret)

    def get_rank(self) -> int:
        # NOTE(review): crashes when self.mesh is None (replicate-only view);
        # presumably only called on the full mesh — confirm.
        return self.mesh.get_rank()

    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        """Local rank along ``mesh_dim``; with ``None``, a flattened rank
        composed replicate-major from the replicate PG and the inner mesh."""
        if mesh_dim is None:
            if self.mesh is None:
                return get_rank(self.replicate_pg)

            # The replicate-major composition below assumes the replicate
            # dim comes first.
            assert self.replicate_dim == 0, "replicate_dim must be the first one"
            other_dim_size = self.mesh.size()
            other_dim_rank = self.mesh.get_local_rank()
            replicate_pg_rank = get_rank(self.replicate_pg)
            return other_dim_size * replicate_pg_rank + other_dim_rank
        elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
            return get_rank(self.replicate_pg)
        else:
            return self.mesh.get_local_rank(mesh_dim)

    def get_all_groups(self) -> List[ProcessGroup]:
        raise NotImplementedError
911+
912+
class _FlattenDeviceMesh(DeviceMesh):
    """Minimal flattened view returned by ``ManagedDeviceMesh._flatten``.

    Only the total ``size()`` and the flattened ``get_local_rank()`` are
    supported; every other ``DeviceMesh`` operation is deliberately left
    unimplemented.
    """

    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
        # The managed mesh this view flattens; all supported queries
        # delegate to it.
        self.managed_mesh = managed_mesh

    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
        raise NotImplementedError

    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
        raise NotImplementedError

    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
        raise NotImplementedError

    def size(self, mesh_dim: Optional[int] = None) -> int:
        """Total size of the flattened mesh; per-dim queries are invalid."""
        assert mesh_dim is None
        return self.managed_mesh.size()

    @property
    def ndim(self) -> int:
        raise NotImplementedError

    @property
    def shape(self) -> Tuple[int, ...]:
        raise NotImplementedError

    def get_rank(self) -> int:
        raise NotImplementedError

    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        """Flattened local rank; per-dim queries are invalid."""
        assert mesh_dim is None
        return self.managed_mesh.get_local_rank()

    def get_all_groups(self) -> List[ProcessGroup]:
        raise NotImplementedError
947+
948+
def ft_init_device_mesh(
    *,
    device_type: str,
    mesh_shape: Tuple[int, ...],
    mesh_dim_names: Tuple[str, ...],
    replicate_dim: int,
    manager: "Manager",
) -> "ManagedDeviceMesh":
    """Initialize a fault-tolerant device mesh.

    Builds a regular ``DeviceMesh`` over every dimension except the
    replicate one, wires a ``ManagedProcessGroup`` (Gloo or NCCL depending
    on ``device_type``) into ``manager``, and returns a
    ``ManagedDeviceMesh`` combining the two.

    Args:
        device_type: "cpu" (Gloo) or "cuda" (NCCL).
        mesh_shape: sizes of all dims, including the replicate dim.
        mesh_dim_names: names of all dims, including the replicate dim.
        replicate_dim: index of the replicate dim in the two tuples above.
        manager: the torchft manager; its ``_pg`` is replaced by the new
            process group (side effect).

    Raises:
        ValueError: if ``device_type`` is neither "cpu" nor "cuda".
    """
    # We have to lie to DeviceMesh that the replicate_dim has only
    # 1 rank, so drop it before calling init_device_mesh.
    _mesh_shape = list(mesh_shape)
    _mesh_shape.pop(replicate_dim)
    _mesh_dim_names = list(mesh_dim_names)
    _mesh_dim_names.pop(replicate_dim)
    mesh = init_device_mesh(
        device_type,
        mesh_shape=tuple(_mesh_shape),
        mesh_dim_names=tuple(_mesh_dim_names),
    )

    if device_type == "cpu":
        pg = ProcessGroupGloo()
    elif device_type == "cuda":
        pg = ProcessGroupNCCL()
    else:
        # Fix: was a bare `raise ValueError()` with no message.
        raise ValueError(f"unsupported device_type: {device_type!r}")

    manager._pg = pg
    replicate_pg = ManagedProcessGroup(manager)
    # We have to use MultiProcessTestCase, otherwise c10d will complain
    # the same backend has been registered.
    replicate_pg.register(mesh_dim_names[replicate_dim])

    return ManagedDeviceMesh(
        mesh=mesh,
        mesh_dim_names=mesh_dim_names,
        replicate_pg=replicate_pg,
        replicate_dim=replicate_dim,
        parent=None,
    )
0 commit comments