1 file changed: +3 −2 lines changed
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore

+from torchtitan.components.ft import init_ft_manager
 from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -102,7 +103,6 @@ def estimate_memory(job_config: JobConfig):
         if not job_config.memory_estimation.disable_fake_mode
         else contextlib.nullcontext()
     ):
-
         logger.info(
             f"Building {train_spec.name} {job_config.model.flavor} {model_config}"
         )
@@ -122,7 +122,8 @@ def estimate_memory(job_config: JobConfig):
         model.train()

         # build optimizer after applying parallelisms to the model
-        optimizers = build_optimizers([model], job_config)
+        ft_manager = init_ft_manager(job_config)
+        optimizers = build_optimizers([model], job_config, ft_manager)
         lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
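For context, the change threads a fault-tolerance manager through optimizer construction: init_ft_manager(job_config) is called once and its result is passed as a new third argument to build_optimizers. Below is a minimal sketch of the updated call pattern; the model and JobConfig setup are simplified placeholders (build_model is a hypothetical helper, not part of this diff), and only the last three lines reflect the change itself.

    from torchtitan.components.ft import init_ft_manager
    from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
    from torchtitan.config_manager import JobConfig

    job_config = JobConfig()          # placeholder: normally populated from CLI/TOML args
    model = build_model(job_config)   # hypothetical helper standing in for the real model setup

    # New in this diff: build the fault-tolerance manager and hand it to the optimizer builder.
    ft_manager = init_ft_manager(job_config)
    optimizers = build_optimizers([model], job_config, ft_manager)
    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)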