from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
- from typing import Generator, List, Tuple, TypeVar, Union, cast
+ from typing import Callable, Generator, Optional, TypeVar, Union, cast

import torch
from torch.distributed import Work
from torch.distributed.tensor import DTensor, _DTensorSpec
- from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
+ from torch.utils._pytree import (
+     KeyPath,
+     TreeSpec,
+     tree_flatten_with_path,
+     tree_unflatten,
+ )

from torchft.checkpointing.transport import CheckpointTransport
from torchft.process_group import ProcessGroup
@@ -32,7 +37,7 @@ class _TensorMeta:
    shape: torch.Size
    dtype: torch.dtype
    storage_offset: int
-     stride: Tuple[int, ...]
+     stride: tuple[int, ...]
    nbytes: int


@@ -61,13 +66,15 @@ class _StateDictMeta:
    Args:
        step: the step of the checkpoint to verify consistency
        treespec: the pytree spec of the state dict
+         paths: the key path of each leaf in the state dict
        non_tensor_leaves: the metadata for each tensor in the state dict and any
            non-tensor leaves in the state dict
    """

    step: int
    treespec: TreeSpec
-     non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
+     paths: list[KeyPath]
+     non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]

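For readers unfamiliar with the pytree key-path API, here is a minimal sketch of what `tree_flatten_with_path` returns and how it feeds the new `paths` field; the values shown in comments are illustrative:

```python
import torch
from torch.utils._pytree import tree_flatten_with_path

# Flattening pairs every leaf with its key path, roughly:
#   [((MappingKey(key='w'),), tensor([0., 0.])),
#    ((MappingKey(key='lr'),), 0.1)]
leaves, treespec = tree_flatten_with_path({"w": torch.zeros(2), "lr": 0.1})
paths = [key_path for key_path, _ in leaves]  # what _StateDictMeta.paths stores
```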
@contextmanager
@@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]:
    logger.info(f"{name} took {dur}s")


- def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
+ def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
    return (
        _cast_tensor(tensor, torch.uint8),
        _TensorMeta(
@@ -95,12 +102,16 @@ def _prepare_state_dict(
    state_dict: object,
    step: int,
    device: torch.device,
- ) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
-     leaves, treespec = tree_flatten(state_dict)
+ ) -> tuple[_StateDictMeta, list[torch.Tensor]]:
+     leaves: list[tuple[KeyPath, object]]
+     leaves, treespec = tree_flatten_with_path(state_dict)
+ 
+     paths: list[KeyPath] = []
+     non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
+     tensors: list[torch.Tensor] = []
+     for key_path, v in leaves:
+         paths.append(key_path)

-     non_tensor_leaves = []
-     tensors = []
-     for v in leaves:
        if isinstance(v, DTensor):
            tensor, tensor_meta = _prepare_tensor(v._local_tensor)

@@ -123,6 +134,7 @@ def _prepare_state_dict(
        _StateDictMeta(
            step=step,
            treespec=treespec,
+             paths=paths,
            non_tensor_leaves=non_tensor_leaves,
        ),
        tensors,
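A note on the layout `_prepare_state_dict` now produces, sketched below with illustrative values: tensor leaves are replaced by metadata placeholders in `non_tensor_leaves`, while their raw bytes travel separately in `tensors`, and both line up positionally with `paths`.

```python
# Illustrative layout for a state dict like {"w": <tensor>, "step": 3}:
#   paths             = [(MappingKey(key='w'),), (MappingKey(key='step'),)]
#   non_tensor_leaves = [_TensorMeta(...), 3]    # placeholder, then the raw int
#   tensors           = [<uint8 view of w>]      # streamed separately over the PG
```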
@@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    caveat that the cast tensor may be larger than the original tensor due to
    the differences in striding.
    """
+     assert (
+         type(tensor) is torch.Tensor
+     ), f"can only cast standard tensors not {type(tensor)}"
    storage = tensor.untyped_storage()
    ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
    assert ret.untyped_storage() is storage, "storage should be the same"
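The new assertion rejects tensor subclasses such as DTensor, since the helper operates on raw storage. Below is a minimal sketch of the storage-level behavior the docstring warns about, mirroring the `torch.tensor(storage, ...)` call used here:

```python
import torch

base = torch.arange(16, dtype=torch.float32)
view = base[::2]  # 8 logical elements, but the backing storage holds 16 floats
raw = torch.tensor(view.untyped_storage(), dtype=torch.uint8, device=view.device)
assert raw.nbytes == 16 * 4  # the uint8 "cast" spans the whole storage
```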
@@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]):
    This is a checkpoint transport that uses the process group to transfer checkpoints.
    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.
+ 
    Args:
-         state_dict: a callable that returns the state dict to be transferred
+         pg: the process group to use for communication
+         timeout: the timeout for communication
+         device: the device to use for tensors
+         state_dict: if specified, this function will be called to do an in-place
+             receive into the returned state_dict. This is much faster than
+             having to allocate new tensors and transfer them to the CPU.
    """

    def __init__(
-         self, pg: ProcessGroup, timeout: timedelta, device: torch.device
+         self,
+         pg: ProcessGroup,
+         timeout: timedelta,
+         device: torch.device,
+         state_dict: Optional[Callable[[], object]] = None,
    ) -> None:
-         self._work: List[Work] = []
+         self._work: list[Work] = []
        self._pg = pg
        self._timeout = timeout
        self._device = device
+         self._state_dict = state_dict

    def metadata(self) -> str:
        return "<n/a>"
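A hypothetical construction sketch for the new optional `state_dict` argument; `pg` and `model` are assumed to already exist and are not part of this diff:

```python
from datetime import timedelta

import torch

# Assumed: `pg` is a connected torchft ProcessGroup, `model` is an nn.Module.
transport = PGTransport(
    pg=pg,
    timeout=timedelta(seconds=30),
    device=torch.device("cuda"),
    # Called lazily inside recv_checkpoint; returning the live state dict lets
    # payloads be received directly into existing GPU tensors.
    state_dict=lambda: model.state_dict(),
)
```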
@@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None:
        pass

    def send_checkpoint(
-         self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+         self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        with _timeit("preparing state_dict"):
            meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
@@ -186,20 +212,29 @@ def send_checkpoint(

        with _timeit("send tensors"):
            for i, t in enumerate(tensors):
+                 original_device = t.device
                t = t.to(self._device)
                for dst_rank in dst_ranks:
                    work.append(self._pg.send([t], dst_rank, tag=3 + i))

-                 # allow 3 concurrent transfers at a time to avoid OOMs
-                 while len(work) > (3 * len(dst_ranks)):
-                     work.pop(0).wait(timeout)
+                 # if we did a host-to-device copy, wait for the sends to
+                 # complete so we can free the copy and avoid OOMs
+                 if original_device == torch.device("cpu"):
+                     for w in work:
+                         w.wait(timeout)
+                     work = []

            for w in work:
                w.wait(timeout)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
+         state_dict = self._state_dict() if self._state_dict else {}
+         state_dict_leaves, _ = tree_flatten_with_path(state_dict)
+ 
+         dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)
+ 
        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
        self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
        length = cast(int, len_t.item())
@@ -213,18 +248,34 @@ def recv_checkpoint(
        assert meta.step == step

        i: int = 0
+         works: list[Work] = []

-         def recv(v: _TensorMeta) -> torch.Tensor:
+         def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
            nonlocal i

-             t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
-             # TODO: parallelize receives
-             self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
+             inplace = dst_tensors.get(path)
+             if (
+                 isinstance(inplace, torch.Tensor)
+                 and inplace.device.type == self._device.type
+             ):
+                 if isinstance(inplace, DTensor):
+                     inplace = inplace._local_tensor
+                 t = _cast_tensor(inplace, torch.uint8)
+                 assert (
+                     t.nbytes == v.nbytes
+                 ), "inplace tensor storage must be the same size"
+             else:
+                 t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+ 
+             work = self._pg.recv([t], src_rank, tag=3 + i)
            i += 1

-             # TODO: allow in place receives to avoid having to copy to cpu to
-             # avoid OOMs
-             t = t.cpu()
+             if inplace is None:
+                 # if not receiving in place, copy to CPU to avoid OOMing
+                 work.wait(timeout)
+                 t = t.cpu()
+             else:
+                 works.append(work)

            return torch.as_strided(
                t.view(v.dtype),
@@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor:
            )

        values = []
-         for v in meta.non_tensor_leaves:
+         for path, v in zip(meta.paths, meta.non_tensor_leaves):
            if isinstance(v, _TensorMeta):
-                 values.append(recv(v))
+                 values.append(recv(path, v))
            elif isinstance(v, _DTensorMeta):
-                 tensor = recv(v.local)
+                 tensor = recv(path, v.local)
                # pyre-fixme[29]: DTensor is not a function
                values.append(DTensor(tensor, v.spec, requires_grad=False))
            else:
                values.append(v)

+         for work in works:
+             work.wait(timeout)
+ 
        return tree_unflatten(values, meta.treespec)
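Putting it together, a rough sketch of the recovery flow with these changes; ranks, step, and timeouts are illustrative, and `transport`/`model` are assumed from the earlier sketch:

```python
# Sender (healthy worker): streams the length, serialized metadata, then one
# buffer per tensor leaf.
transport.send_checkpoint(
    dst_ranks=[1],
    step=10,
    state_dict=model.state_dict(),
    timeout=timedelta(seconds=60),
)

# Receiver (recovering worker): with a state_dict callable configured, tensor
# payloads land in place in existing device tensors; leaves without a matching
# destination fall back to a fresh buffer that is copied to CPU.
restored = transport.recv_checkpoint(
    src_rank=0, metadata=transport.metadata(), step=10, timeout=timedelta(seconds=60)
)
```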