11import multiprocessing
2-
3- backup_ForkingPickler = multiprocessing .reduction .ForkingPickler
4- backup_dump = multiprocessing .reduction .dump
52import os
63from functools import partial
74
2522 reset_singletons ,
2623)
2724
25+ backup_ForkingPickler = multiprocessing .reduction .ForkingPickler
26+ backup_dump = multiprocessing .reduction .dump
27+
2828# (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY)
2929step_info_list = [(8 , 4 , 2 ), (3 , 4 , 2 ), (1 , 6 , 3 )]
3030ckpt_config_list = [
@@ -201,8 +201,8 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr
201201@pytest .mark .parametrize ("step_info" , step_info_list )
202202@pytest .mark .parametrize ("ckpt_config" , ckpt_config_list )
203203def test_ckpt_mm (step_info , ckpt_config , init_dist_and_model ): # noqa # pylint: disable=unused-import
204- from internlm .core .context import global_context as gpc
205204 from internlm .checkpoint .checkpoint_manager import CheckpointLoadMask
205+ from internlm .core .context import global_context as gpc
206206
207207 ckpt_config = Config (ckpt_config )
208208 total_step , checkpoint_every , oss_snapshot_freq = step_info
@@ -222,6 +222,8 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
222222 )
223223
224224 model , opim = init_dist_and_model
225+ gpc .config ._add_item ("ckpt" , dict ())
226+ gpc .config .ckpt ._add_item ("universal_ckpt" , dict (enable = False , aysnc_save = True , broadcast_load = False ))
225227 train_state = TrainState (gpc .config , None )
226228 if isinstance (opim , HybridZeroOptimizer ):
227229 print ("Is HybridZeroOptimizer!" , flush = True )
@@ -297,9 +299,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
297299
298300
299301def query_quit_file (rank , world_size = 2 ):
302+ from internlm .checkpoint .checkpoint_manager import CheckpointSaveType
300303 from internlm .core .context import global_context as gpc
301304 from internlm .initialize import initialize_distributed_env
302- from internlm .checkpoint .checkpoint_manager import CheckpointSaveType
303305
304306 ckpt_config = Config (
305307 dict (
@@ -348,8 +350,6 @@ def query_quit_file(rank, world_size=2):
348350
349351
350352def test_quit_siganl_handler (): # noqa # pylint: disable=unused-import
351- import multiprocessing
352-
353353 # we do hack here to workaround the bug of 3rd party library dill, which only occurs in this unittest:
354354 # https://github.com/uqfoundation/dill/issues/380
355355 multiprocessing .reduction .ForkingPickler = backup_ForkingPickler
0 commit comments