diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 470b4b12039f..4bcf587980a2 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -187,12 +187,13 @@
 
 try:
     from .utils.zero_cost_checkpoint import (
+        NonZCCEMACallback,
         ZeroCostCheckpointCallback,
         ZeroCostCheckpointManager,
         get_fused_param_mappings,
     )
 except (ImportError, ModuleNotFoundError):
-    ZeroCostCheckpointManager, get_fused_param_mappings = None, None
+    ZeroCostCheckpointManager, NonZCCEMACallback, get_fused_param_mappings = None, None, None
 from .utils.helper import (  # nested_truncate,
     broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
@@ -854,6 +855,9 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
 
         logger.info("Create zero cost checkpoint manager done.")
 
+    def add_non_zcc_ema_callback(self, resume_from_checkpoint):
+        self.add_callback(NonZCCEMACallback(resume_from_checkpoint, self.args, self.sharding_io))
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -1026,6 +1030,8 @@ def train(
 
         if self.args.enable_zero_cost_checkpoint:
             self.create_zcc_manager(model, resume_from_checkpoint)
+        elif self.args.zcc_save_ema_coef is not None:
+            self.add_non_zcc_ema_callback(resume_from_checkpoint)
 
         logger.info(f"{self.runtime_timer.log()}")
         logger.info("***** Running training *****")
@@ -1450,7 +1456,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
 
                     # For ZCC EMA
-                    if self.args.enable_zero_cost_checkpoint:
+                    if self.args.enable_zero_cost_checkpoint or self.args.zcc_save_ema_coef is not None:
                         tr_loss_for_zcc = tr_loss.clone()
                         dist.all_reduce(
                             tr_loss_for_zcc, dist.ReduceOp.SUM
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 7599c4d49cc8..b4e14222b39b 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -2049,7 +2049,7 @@ def is_context_parallel_supported():
             assert (
                 self.save_steps % self.zcc_ema_interval == 0
             ), f"save_steps[{self.save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
-        if self.zcc_save_ema_coef is not None:
+        if self.enable_zero_cost_checkpoint and self.zcc_save_ema_coef is not None:
             assert (
                 self.zcc_workers_num == 1
             ), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now."
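
With the wiring above, EMA no longer depends on the full zero cost checkpoint machinery: when enable_zero_cost_checkpoint is off but zcc_save_ema_coef is set, Trainer.train() registers NonZCCEMACallback instead of creating a ZCC manager. A minimal configuration sketch under that reading (field names come from the diffs; the values, output_dir, and sharding settings are illustrative placeholders, and sharded saving is required because EMABuffer asserts sharding_io is not None):

from paddlenlp.trainer import TrainingArguments

# Hypothetical arguments for trainer-side EMA with ZCC disabled (values are examples only).
args = TrainingArguments(
    output_dir="./out",                 # placeholder
    enable_zero_cost_checkpoint=False,  # ZCC off -> EMA handled by NonZCCEMACallback
    zcc_save_ema_coef=0.999,            # EMA decay coefficient used in EMABuffer._ema_impl
    zcc_ema_interval=10,                # accumulate EMA every 10 global steps
    save_steps=100,                     # keep divisible by zcc_ema_interval so saves align with EMA updates
    sharding="stage1",                  # example sharding setup so sharding_io is available
    save_sharded_model=True,            # EMABuffer requires save_sharded_model
)
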
diff --git a/paddlenlp/trainer/utils/zero_cost_checkpoint.py b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
index 7e652b0b1cda..807758e86d9c 100644
--- a/paddlenlp/trainer/utils/zero_cost_checkpoint.py
+++ b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -1019,3 +1019,86 @@ def manage_offload_chunk(self):
         logger.info(
             f"[ZCC Worker{self.worker_id}] All numel: {self.all_numel}, Offload chunks: {self.offload_chunks}, Chunk size: {self.chunk_size_in_numel}]"
         )
+
+
+class EMABuffer:
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+        self.master_weights = {}
+        self.model_params = {}
+        self.args = args
+        self.sharding_io = sharding_io
+        self.offload = offload
+        if resume_from_checkpoint is not None:
+            self._load(resume_from_checkpoint)
+
+    def _ema_path(self, base_path):
+        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        path = path.replace("optimizer", "ema")
+        return os.path.join(base_path, path)
+
+    def _load(self, resume_from_checkpoint):
+        ema_path = self._ema_path(resume_from_checkpoint)
+        if not os.path.exists(ema_path):
+            return
+
+        logger.info(f"Loading EMA checkpoint from {resume_from_checkpoint} ...")
+        with device_guard("cpu"):
+            ema_state_dict = paddle.load(ema_path)
+        logger.info(f"Load EMA checkpoint from {resume_from_checkpoint} done")
+
+        self.master_weights = ema_state_dict.pop("master_weights")
+        self.model_params = ema_state_dict
+
+    def save(self, global_step):
+        base_path = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
+        ema_path = self._ema_path(base_path)
+        ema_state_dict = {"master_weights": self.master_weights}
+        ema_state_dict.update(self.model_params)
+        os.makedirs(base_path, exist_ok=True)
+        logger.info(f"Saving EMA checkpoint to {base_path} ...")
+        paddle.save(ema_state_dict, ema_path)
+        logger.info(f"Save EMA checkpoint to {base_path} done")
+
+    def ema_accumulate(self, global_step, loss, ema_loss_threshold):
+        if ema_loss_threshold is None or loss < ema_loss_threshold:
+            logger.info(f"EMA accumulating for step {global_step} ...")
+            self._ema_impl(
+                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
+                ema_state_dict=self.master_weights,
+            )
+            self._ema_impl(
+                state_dict=self.sharding_io.manipulate_state_dict_and_config(
+                    unwrap_model(self.sharding_io.model),
+                    merge_tensor_parallel=False,
+                )[0],
+                ema_state_dict=self.model_params,
+            )
+            logger.info(f"EMA accumulate done for step {global_step}")
+
+    def _ema_impl(self, state_dict, ema_state_dict):
+        ema_coef = self.args.zcc_save_ema_coef
+        for k, v in state_dict.items():
+            if k in ema_state_dict:
+                ema_tensor = ema_state_dict[k]
+                ema_tensor = ema_coef * ema_tensor.cuda() + (1 - ema_coef) * v.cuda()
+                ema_tensor.name = v.name
+                v = ema_tensor
+                del ema_tensor
+
+            if self.offload:
+                v_pin = v.pin_memory()
+                v_pin.name = v.name
+                v = v_pin
+            ema_state_dict[k] = v
+
+
+class NonZCCEMACallback(TrainerCallback):
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.global_step % args.zcc_ema_interval == 0:
+            self.buffer.ema_accumulate(state.global_step, state.loss, args.zcc_ema_loss_threshold)
+        if control.should_save:
+            self.buffer.save(state.global_step)
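
For reference, EMABuffer.save writes one file per checkpoint directory: its name is the optimizer shard name with "optimizer" replaced by "ema", and its payload holds the EMA master weights under a "master_weights" key with the EMA model parameters as the remaining top-level entries. A small offline inspection sketch under that assumption (the path below is a placeholder; the real file name depends on optimizer_name_suffix):

import paddle

# Placeholder path; the actual name follows _ema_path(), i.e. an "ema*" variant of the optimizer file.
ema_state = paddle.load("out/checkpoint-1000/ema.pdopt", return_numpy=True)
master_weights = ema_state.pop("master_weights")  # EMA of the fp32 master weights
model_params = ema_state                          # EMA model parameters (remaining keys)
print(f"{len(master_weights)} master weights, {len(model_params)} model params")
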