 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -143,50 +144,29 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        async_mode = ckpt_config.async_mode.lower()
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost.  Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-            The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-            by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-            We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-            into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-            which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-            support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding.  Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-            TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states

-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
         )

+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
             IntervalType.SECONDS
@@ -201,6 +181,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")

+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

@@ -224,10 +205,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")

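Editor's note, not part of the diff: the `self.mp` process started above is a dedicated checkpoint-writer process that receives `(state_dict, checkpoint_id)` pairs over a multiprocessing queue (see `maybe_wait_for_staging` further down). The standalone sketch below shows only the general spawn-a-writer-process pattern; the names `checkpoint_worker`, `to_worker`, and `from_worker` are illustrative, and `torch.save` stands in for the distributed-checkpoint save the real worker performs.

# Illustrative sketch only: a minimal background checkpoint-writer process.
from multiprocessing import get_context

import torch


def checkpoint_worker(recv_queue, send_queue):
    # Runs in a separate process: block until the trainer hands over a
    # (state_dict, checkpoint_id) pair, write it, then acknowledge.
    while True:
        item = recv_queue.get()
        if item is None:  # sentinel: shut the worker down
            return
        state_dict, checkpoint_id = item
        # torch.save is a simplified stand-in so this sketch runs without a
        # process group; the real worker uses torch.distributed.checkpoint.
        torch.save(state_dict, f"{checkpoint_id}.pt")
        send_queue.put(checkpoint_id)  # signal that the save finished


if __name__ == "__main__":
    ctx = get_context("spawn")
    to_worker, from_worker = ctx.Queue(), ctx.Queue()
    worker = ctx.Process(
        target=checkpoint_worker, args=(to_worker, from_worker), daemon=True
    )
    worker.start()
    to_worker.put(({"step": torch.tensor(1)}, "step-1"))
    print(from_worker.get())  # "step-1" once the file is on disk
    to_worker.put(None)
    worker.join()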
@@ -241,8 +218,61 @@ def __del__(self):
             self.mp.join()

     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()

+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states; the others will error.
+
+            The solution to this problem is optimizer flattening: it landed in #127071
+            and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+            kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per PP rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+            We solve this in the Model and Optimizer wrapper classes by flattening the
+            state dicts from each object into one state dict before saving/loading.
+            We rely on the individual state_dicts to not collide, which is guaranteed for
+            the model by correct pipeline splitting and for the optimizer by the flattening
+            support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+            TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")

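Editor's note on point (1) of the docstring above, not part of the diff: the collision avoidance comes from `StateDictOptions(flatten_optimizer_state_dict=True)` in `torch.distributed.checkpoint.state_dict`, which keys optimizer state by parameter name instead of by integer `param_group` index. The standalone sketch below assumes a recent PyTorch where these helpers accept a plain, non-parallelized module without an initialized process group.

# Illustrative sketch of the flattening behaviour referenced in (1).
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

model = nn.Linear(4, 4)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optim.step()

flat = get_optimizer_state_dict(
    model,
    optim,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)
# Optimizer state is now keyed by parameter names ("weight", "bias") rather
# than by positional indices, so different pipeline stages produce disjoint keys.
print(sorted(flat.keys()))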
@@ -325,31 +355,8 @@ def _async_wait(self) -> None:
                 self.async_future.result()

     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-            self.staging = True
-            self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True

     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -359,6 +366,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return

         begin = time.monotonic()
@@ -382,26 +391,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )

+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload the state_dict to CPU memory."""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+            self.staging = True
+            self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+            if self.sending_to_checkpoint_mp:
+                # Copy the staged state_dict to the checkpoint writer process.
+                def sync_func():
+                    self.mp_queue_send.put_nowait(
+                        (self.cpu_offload_state_dict, self.staging_id)
+                    )
+
+                # This may be a faster way to do zero-overhead checkpoint staging,
+                # but we need more thorough investigation before switching to this
+                # method.
+                # self.my_thread = threading.Thread(target=sync_func).start()
+                sync_func()
+                self.sending_to_checkpoint_mp = False

     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
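Editor's note, not part of the diff: to make the staging path added above easier to follow in isolation, recall that `cpu_staging()` allocates pinned (and shared) CPU buffers once via `_create_cpu_state_dict`, re-copies the GPU state dict into them with `non_blocking=True` on `staging_stream`, and `wait_for_staging()` synchronizes that stream before the result is consumed. The sketch below reproduces that pattern as a standalone script; it assumes a CUDA device and uses the same private `torch.distributed._state_dict_utils` helpers imported at the top of this diff (private APIs that may change).

# Standalone illustration of the pinned-memory staging pattern; not the PR's code.
import torch
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

if torch.cuda.is_available():
    gpu_state = {"w": torch.randn(1024, 1024, device="cuda")}

    # Allocate pinned, shareable CPU buffers once; reuse them on every save.
    cpu_buffers = _create_cpu_state_dict(gpu_state, pin_memory=True, share_memory=True)

    staging_stream = torch.cuda.Stream()
    # Make the side stream wait for work already queued on the current stream
    # so the copy does not read tensors before they are fully produced.
    staging_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(staging_stream):
        # Device-to-host copies are queued on the side stream so the training
        # (default) stream is not blocked while weights are offloaded.
        cpu_buffers = _copy_state_dict(gpu_state, cpu_buffers, non_blocking=True)

    # Equivalent of wait_for_staging(): the copies must finish before the CPU
    # buffers are handed to the checkpoint writer process (or to TorchFT).
    staging_stream.synchronize()
    assert torch.equal(cpu_buffers["w"], gpu_state["w"].cpu())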