10 changes: 8 additions & 2 deletions paddlenlp/trainer/trainer.py
@@ -187,12 +187,13 @@

try:
    from .utils.zero_cost_checkpoint import (
+        NonZCCEMACallback,
        ZeroCostCheckpointCallback,
        ZeroCostCheckpointManager,
        get_fused_param_mappings,
    )
except (ImportError, ModuleNotFoundError):
-    ZeroCostCheckpointManager, get_fused_param_mappings = None, None
+    ZeroCostCheckpointManager, NonZCCEMACallback, get_fused_param_mappings = None, None, None
from .utils.helper import (  # nested_truncate,
    broadcast_dataset_rank0_model,
    broadcast_dp_optimizer,
@@ -854,6 +855,9 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):

logger.info("Create zero cost checkpoint manager done.")

def add_non_zcc_ema_callback(self, resume_from_checkpoint):
self.add_callback(NonZCCEMACallback(resume_from_checkpoint, self.args, self.sharding_io))

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -1026,6 +1030,8 @@ def train(

        if self.args.enable_zero_cost_checkpoint:
            self.create_zcc_manager(model, resume_from_checkpoint)
+        elif self.args.zcc_save_ema_coef is not None:
+            self.add_non_zcc_ema_callback(resume_from_checkpoint)

        logger.info(f"{self.runtime_timer.log()}")
        logger.info("***** Running training *****")
@@ -1450,7 +1456,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                self.state.epoch = epoch + (step + 1) / steps_in_epoch

                # For ZCC EMA
-                if self.args.enable_zero_cost_checkpoint:
+                if self.args.enable_zero_cost_checkpoint or self.args.zcc_save_ema_coef is not None:
                    tr_loss_for_zcc = tr_loss.clone()
                    dist.all_reduce(
                        tr_loss_for_zcc, dist.ReduceOp.SUM
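With this hunk, the loss all-reduce commented "# For ZCC EMA" now also runs when only zcc_save_ema_coef is set, so the non-ZCC EMA callback added below can compare a globally reduced loss against zcc_ema_loss_threshold. A minimal sketch of that reduction, assuming an initialized paddle.distributed process group (the helper name and the final averaging step are illustrative, not taken from the PR):

import paddle.distributed as dist

def reduced_loss_for_ema(tr_loss):
    # Clone so the running training loss is left untouched, then sum across ranks,
    # mirroring the tr_loss_for_zcc handling in the hunk above.
    tr_loss_for_ema = tr_loss.clone()
    dist.all_reduce(tr_loss_for_ema, dist.ReduceOp.SUM)
    # Averaging over the world size is an assumption; the PR's scaling happens
    # outside the shown context.
    return tr_loss_for_ema / dist.get_world_size()
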
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -2049,7 +2049,7 @@ def is_context_parallel_supported():
        assert (
            self.save_steps % self.zcc_ema_interval == 0
        ), f"save_steps[{self.save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
-        if self.zcc_save_ema_coef is not None:
+        if self.enable_zero_cost_checkpoint and self.zcc_save_ema_coef is not None:
            assert (
                self.zcc_workers_num == 1
            ), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now."
83 changes: 83 additions & 0 deletions paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -1019,3 +1019,86 @@ def manage_offload_chunk(self):
        logger.info(
            f"[ZCC Worker{self.worker_id}] All numel: {self.all_numel}, Offload chunks: {self.offload_chunks}, Chunk size: {self.chunk_size_in_numel}]"
        )


class EMABuffer:
    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
        self.master_weights = {}
        self.model_params = {}
        self.args = args
        self.sharding_io = sharding_io
        self.offload = offload
        if resume_from_checkpoint is not None:
            self._load(resume_from_checkpoint)

    def _ema_path(self, base_path):
        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
        path = path.replace("optimizer", "ema")
        return os.path.join(base_path, path)

    def _load(self, resume_from_checkpoint):
        ema_path = self._ema_path(resume_from_checkpoint)
        if not os.path.exists(ema_path):
            return

        logger.info(f"Loading EMA checkpoint from {resume_from_checkpoint} ...")
        with device_guard("cpu"):
            ema_state_dict = paddle.load(ema_path)
        logger.info(f"Load EMA checkpoint from {resume_from_checkpoint} done")

        self.master_weights = ema_state_dict.pop("master_weights")
        self.model_params = ema_state_dict

    def save(self, global_step):
        base_path = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
        ema_path = self._ema_path(base_path)
        ema_state_dict = {"master_weights": self.master_weights}
        ema_state_dict.update(self.model_params)
        os.makedirs(base_path, exist_ok=True)
        logger.info(f"Saving EMA checkpoint to {base_path} ...")
        paddle.save(ema_state_dict, ema_path)
        logger.info(f"Save EMA checkpoint to {base_path} done")

    def ema_accumulate(self, global_step, loss, ema_loss_threshold):
        if ema_loss_threshold is None or loss < ema_loss_threshold:
            logger.info(f"EMA accumulating for step {global_step} ...")
            self._ema_impl(
                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
                ema_state_dict=self.master_weights,
            )
            self._ema_impl(
                state_dict=self.sharding_io.manipulate_state_dict_and_config(
                    unwrap_model(self.sharding_io.model),
                    merge_tensor_parallel=False,
                )[0],
                ema_state_dict=self.model_params,
            )
            logger.info(f"EMA accumulate done for step {global_step}")

    def _ema_impl(self, state_dict, ema_state_dict):
        ema_coef = self.args.zcc_save_ema_coef
        for k, v in state_dict.items():
            if k in ema_state_dict:
                ema_tensor = ema_state_dict[k]
                ema_tensor = ema_coef * ema_tensor.cuda() + (1 - ema_coef) * v.cuda()
                ema_tensor.name = v.name
                v = ema_tensor
                del ema_tensor

            if self.offload:
                v_pin = v.pin_memory()
                v_pin.name = v.name
                v = v_pin
            ema_state_dict[k] = v


class NonZCCEMACallback(TrainerCallback):
    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.zcc_ema_interval == 0:
            self.buffer.ema_accumulate(state.global_step, state.loss, args.zcc_ema_loss_threshold)
        if control.should_save:
            self.buffer.save(state.global_step)
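_ema_impl applies the standard exponential-moving-average update, ema = coef * ema + (1 - coef) * w, to every tensor that already has an EMA slot and copies tensors it sees for the first time; NonZCCEMACallback then drives it every zcc_ema_interval steps and writes the buffer out whenever a regular checkpoint is saved. A minimal standalone sketch of the update rule, without the offload, naming, and device bookkeeping above (the function name is illustrative):

import paddle

def ema_update(ema_state, new_state, coef):
    # ema <- coef * ema + (1 - coef) * new; tensors without an EMA slot are copied in.
    for name, tensor in new_state.items():
        if name in ema_state:
            ema_state[name] = coef * ema_state[name] + (1 - coef) * tensor
        else:
            ema_state[name] = tensor.clone()
    return ema_state

# Example with coef=0.9: the first call copies w=1.0, the second yields 0.9 * 1.0 + 0.1 * 0.0 = 0.9.
ema = {}
ema = ema_update(ema, {"w": paddle.to_tensor(1.0)}, coef=0.9)
ema = ema_update(ema, {"w": paddle.to_tensor(0.0)}, coef=0.9)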