 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -143,49 +144,28 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        async_mode = ckpt_config.async_mode.lower()
+        # Staging is needed both for async checkpointing with pinned memory and
+        # for fault tolerance, which keeps a CPU copy of the local state so it
+        # can be handed to a joining replica.
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost.  Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-            The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-            by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-            We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-            into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-            which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-            support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding.  Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-            TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
 
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-            }
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -199,11 +179,11 @@ def __init__(
         self.time_sync_result = None
         self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
 
         self.mp = None
-        async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.DISABLED:
             self.async_mode = AsyncMode.DISABLED
         elif async_mode == AsyncMode.ASYNC:
@@ -223,10 +203,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")
 
@@ -240,8 +216,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step, so a CPU copy exists even before the first checkpoint save.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost.  Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+            The solution to this problem is optimizer flattening: it landed in #127071
+            and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+            kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+            We solve this in the Model and Optimizer wrapper classes by flattening the
+            state dicts from each object into one state dict before saving/loading.
+            We rely on the individual state_dicts to not collide, which is guaranteed for
+            the model by correct pipeline splitting and for the optimizer by the flattening
+            support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+            TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+            }
+        )
+        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
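To make the note in `_initialize_states` concrete, here is a minimal editorial sketch (not part of this change; the dictionary contents are invented for illustration) of why positionally indexed optimizer state dicts from different PP ranks collide, and why keying state by fully qualified parameter name avoids it:

# Plain optimizer state dicts are indexed by position, so two PP ranks emit
# identical keys for different layers and a naive merge silently drops one:
rank0_optim = {"state": {0: {"step": 10}}, "param_groups": [{"params": [0]}]}  # layers.0
rank1_optim = {"state": {0: {"step": 10}}, "param_groups": [{"params": [0]}]}  # layers.1
merged = {**rank0_optim, **rank1_optim}  # rank0's entries are overwritten

# Flattened state dicts key optimizer state by fully qualified parameter name,
# so entries from different pipeline stages cannot collide:
rank0_flat = {"layers.0.weight.step": 10}
rank1_flat = {"layers.1.weight.step": 10}
merged_flat = {**rank0_flat, **rank1_flat}  # both stages survive save/load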
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
@@ -324,31 +353,8 @@ def _async_wait(self) -> None:
                 self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-            self.staging = True
-            self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -358,6 +364,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -381,26 +389,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory."""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}")
+            # Pinned memory allows the GPU-to-CPU copy below to be truly
+            # asynchronous; shared memory lets the checkpoint background
+            # process read the staged tensors without copying their data again.
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+            self.staging = True
+            self.staging_id = checkpoint_id
+
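For readers unfamiliar with the staging idiom used in cpu_staging(), this standalone sketch (made-up tensors, not code from this diff) shows the underlying PyTorch pattern: a non_blocking copy into pinned CPU memory issued on a side stream is only safe to read after that stream is synchronized, which is what wait_for_staging() below provides.

import torch

# Assumes a CUDA device is available.
gpu_tensor = torch.randn(4, device="cuda")
staging_stream = torch.cuda.Stream()

# Pinned (page-locked) CPU memory is required for a truly asynchronous copy.
cpu_buffer = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)

with torch.cuda.stream(staging_stream):
    # copy_ returns immediately; the transfer runs in the background on staging_stream.
    cpu_buffer.copy_(gpu_tensor, non_blocking=True)

# cpu_buffer is only safe to read once the side stream has finished the copy.
staging_stream.synchronize()
print(cpu_buffer)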
+    def wait_for_staging(self) -> None:
+        # Block until the asynchronous device-to-CPU copy issued in
+        # cpu_staging() has completed.
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        # Return the staged CPU state_dict, waiting for staging to finish first.
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
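As a hedged illustration only: a helper like the one below (hypothetical name, not an API from this commit or from torchft) is the kind of consumer these accessors enable, producing a CPU copy of the local training state that a fault-tolerance layer could hand to a joining replica.

from typing import Any, Dict

def snapshot_for_new_replica(checkpointer) -> Dict[str, Any]:
    # Stage without a checkpoint_id: the copy is meant for in-memory recovery,
    # not for the checkpoint background process.
    checkpointer.cpu_staging(None)
    # staging_results() blocks until the non-blocking GPU-to-CPU copy finishes.
    return checkpointer.staging_results()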
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+            if self.sending_to_checkpoint_mp:
+                # Hand the staged CPU state_dict over to the checkpoint
+                # background process.
+                def sync_func():
+                    self.mp_queue_send.put_nowait(
+                        (self.cpu_offload_state_dict, self.staging_id)
+                    )
+
+                # This may be a faster way to do zero-overhead checkpoint
+                # staging, but we need more thorough investigation before
+                # switching to this method.
+                # self.my_thread = threading.Thread(target=func).start()
+                sync_func()
+                self.sending_to_checkpoint_mp = False
 
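For orientation, here is a minimal sketch of how a training loop might drive save() and maybe_wait_for_staging(); the loop, train_step(), and num_steps are placeholders rather than TorchTitan's actual training script, and the ordering shown is just one reasonable arrangement that keeps the staged copy consistent.

step = 0
while step < num_steps:
    step += 1

    # Before mutating parameters/optimizer state again, finish any staging
    # copy left over from the previous iteration (and, when async
    # checkpointing is on, hand it to the background checkpoint process).
    checkpoint.maybe_wait_for_staging()

    train_step()  # forward/backward/optimizer step

    # With async_with_pinned_mem this mostly just launches the GPU-to-CPU
    # staging copy and returns; the copy overlaps whatever runs between
    # here and the wait at the top of the next iteration.
    checkpoint.save(step)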
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint: