@@ -111,13 +111,16 @@ def tearDown(self):
111111 def test_save (self , * _ ):
112112 """Test that calling save() writes a checkpoint file to disk."""
113113 job_config = DummyJobConfig (job = self .dummy_job )
114+ ft_manager = mock .Mock ()
115+ ft_manager .enabled = False
114116 manager = CheckpointManager (
115117 dummy_dataloader ,
116118 dummy_model_parts ,
117119 dummy_optimizers ,
118120 dummy_lr_schedulers ,
119121 {"trainer" : self .trainer_state },
120122 job_config ,
123+ ft_manager ,
121124 )
122125 step = 20
123126 manager .save (curr_step = step , force = True )
@@ -141,13 +144,16 @@ def test_save(self, *_):
141144 def test_load (self , * _ ):
142145 """Test that load() properly reads the checkpoint file from disk and restores state."""
143146 job_config = DummyJobConfig (job = self .dummy_job )
147+ ft_manager = mock .Mock ()
148+ ft_manager .enabled = False
144149 manager = CheckpointManager (
145150 dummy_dataloader ,
146151 dummy_model_parts ,
147152 dummy_optimizers ,
148153 dummy_lr_schedulers ,
149154 {"trainer" : self .trainer_state },
150155 job_config ,
156+ ft_manager ,
151157 )
152158 step = 30
153159 manager .save (curr_step = step , force = True )
@@ -179,13 +185,16 @@ def test_purge_stale_checkpoints_rank_zero(self, *_):
179185 """
180186 job_config = DummyJobConfig (job = self .dummy_job )
181187 job_config .checkpoint .keep_latest_k = 3
188+ ft_manager = mock .Mock ()
189+ ft_manager .enabled = False
182190 manager = CheckpointManager (
183191 dummy_dataloader ,
184192 dummy_model_parts ,
185193 dummy_optimizers ,
186194 dummy_lr_schedulers ,
187195 {"trainer" : self .trainer_state },
188196 job_config ,
197+ ft_manager ,
189198 )
190199 steps = [10 , 20 , 30 , 40 , 50 ]
191200 for s in steps :
@@ -223,13 +232,16 @@ def test_purge_stale_checkpoints_rank_nonzero(self, *_):
223232 """
224233 job_config = DummyJobConfig (job = self .dummy_job )
225234 job_config .checkpoint .keep_latest_k = 3
235+ ft_manager = mock .Mock ()
236+ ft_manager .enabled = False
226237 manager = CheckpointManager (
227238 dummy_dataloader ,
228239 dummy_model_parts ,
229240 dummy_optimizers ,
230241 dummy_lr_schedulers ,
231242 {"trainer" : self .trainer_state },
232243 job_config ,
244+ ft_manager ,
233245 )
234246 steps = [10 , 20 , 30 , 40 , 50 ]
235247 for s in steps :
@@ -260,13 +272,16 @@ def test_async_save_calls_async_wait(self, *_):
260272 # Set async_mode to "async" in the job configuration.
261273 job_config = DummyJobConfig (job = self .dummy_job )
262274 job_config .checkpoint .async_mode = "async"
275+ ft_manager = mock .Mock ()
276+ ft_manager .enabled = False
263277 manager = CheckpointManager (
264278 dummy_dataloader ,
265279 dummy_model_parts ,
266280 dummy_optimizers ,
267281 dummy_lr_schedulers ,
268282 {"trainer" : self .trainer_state },
269283 job_config ,
284+ ft_manager ,
270285 )
271286 # First save: should schedule an async save.
272287 manager .save (curr_step = 10 , force = False )
0 commit comments