Skip to content

Commit 6ce611c

Browse files
committed
Make possible to resume training with augmented model
1 parent 5e10e20 commit 6ce611c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

scripts/train_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def main(cfg_file_path):
131131
logger.warning(F'No model path to resume from. Using {TRAINED_MODEL_PTH_FILE}.')
132132
PICK_UP_MODEL = TRAINED_MODEL_PTH_FILE
133133
cfg.MODEL.WEIGHTS = PICK_UP_MODEL
134-
trainer = CocoTrainer(cfg)
135-
trainer.resume_or_load(resume=True)
134+
135+
resume_trigger=True
136136
else:
137137
PASSED_ZOO_MODEL = 'model_zoo_checkpoint_url' in MODEL_WEIGHTS.keys()
138138
INIT_MODEL_WEIGHTS = MODEL_WEIGHTS['init_model_weights'] if 'init_model_weights' in MODEL_WEIGHTS.keys() else False
@@ -154,8 +154,10 @@ def main(cfg_file_path):
154154
else:
155155
logger.info(f"Fine-tuning from {cfg.MODEL.WEIGHTS}")
156156

157-
trainer = AugmentedCocoTrainer(cfg) if DATA_AUGMENTATION else CocoTrainer(cfg)
158-
trainer.resume_or_load(resume=False)
157+
resume_trigger=False
158+
159+
trainer = AugmentedCocoTrainer(cfg) if DATA_AUGMENTATION else CocoTrainer(cfg)
160+
trainer.resume_or_load(resume=resume_trigger)
159161
trainer.train()
160162
written_files.append(os.path.join(WORKING_DIR, TRAINED_MODEL_PTH_FILE))
161163

0 commit comments

Comments
 (0)