Skip to content

Commit 1bd38af

Browse files
committed
Update
[ghstack-poisoned]
1 parent 0a8ab01 commit 1bd38af

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,16 @@ def tearDown(self):
111111
def test_save(self, *_):
112112
"""Test that calling save() writes a checkpoint file to disk."""
113113
job_config = DummyJobConfig(job=self.dummy_job)
114+
ft_manager = mock.Mock()
115+
ft_manager.enabled = False
114116
manager = CheckpointManager(
115117
dummy_dataloader,
116118
dummy_model_parts,
117119
dummy_optimizers,
118120
dummy_lr_schedulers,
119121
{"trainer": self.trainer_state},
120122
job_config,
123+
ft_manager,
121124
)
122125
step = 20
123126
manager.save(curr_step=step, force=True)
@@ -141,13 +144,16 @@ def test_save(self, *_):
141144
def test_load(self, *_):
142145
"""Test that load() properly reads the checkpoint file from disk and restores state."""
143146
job_config = DummyJobConfig(job=self.dummy_job)
147+
ft_manager = mock.Mock()
148+
ft_manager.enabled = False
144149
manager = CheckpointManager(
145150
dummy_dataloader,
146151
dummy_model_parts,
147152
dummy_optimizers,
148153
dummy_lr_schedulers,
149154
{"trainer": self.trainer_state},
150155
job_config,
156+
ft_manager,
151157
)
152158
step = 30
153159
manager.save(curr_step=step, force=True)
@@ -179,13 +185,16 @@ def test_purge_stale_checkpoints_rank_zero(self, *_):
179185
"""
180186
job_config = DummyJobConfig(job=self.dummy_job)
181187
job_config.checkpoint.keep_latest_k = 3
188+
ft_manager = mock.Mock()
189+
ft_manager.enabled = False
182190
manager = CheckpointManager(
183191
dummy_dataloader,
184192
dummy_model_parts,
185193
dummy_optimizers,
186194
dummy_lr_schedulers,
187195
{"trainer": self.trainer_state},
188196
job_config,
197+
ft_manager,
189198
)
190199
steps = [10, 20, 30, 40, 50]
191200
for s in steps:
@@ -223,13 +232,16 @@ def test_purge_stale_checkpoints_rank_nonzero(self, *_):
223232
"""
224233
job_config = DummyJobConfig(job=self.dummy_job)
225234
job_config.checkpoint.keep_latest_k = 3
235+
ft_manager = mock.Mock()
236+
ft_manager.enabled = False
226237
manager = CheckpointManager(
227238
dummy_dataloader,
228239
dummy_model_parts,
229240
dummy_optimizers,
230241
dummy_lr_schedulers,
231242
{"trainer": self.trainer_state},
232243
job_config,
244+
ft_manager,
233245
)
234246
steps = [10, 20, 30, 40, 50]
235247
for s in steps:
@@ -260,13 +272,16 @@ def test_async_save_calls_async_wait(self, *_):
260272
# Set async_mode to "async" in the job configuration.
261273
job_config = DummyJobConfig(job=self.dummy_job)
262274
job_config.checkpoint.async_mode = "async"
275+
ft_manager = mock.Mock()
276+
ft_manager.enabled = False
263277
manager = CheckpointManager(
264278
dummy_dataloader,
265279
dummy_model_parts,
266280
dummy_optimizers,
267281
dummy_lr_schedulers,
268282
{"trainer": self.trainer_state},
269283
job_config,
284+
ft_manager,
270285
)
271286
# First save: should schedule an async save.
272287
manager.save(curr_step=10, force=False)

0 commit comments

Comments
 (0)