paddleformers/trainer/trainer.py (21 changes: 15 additions & 6 deletions)
@@ -1493,9 +1493,12 @@ def split_dtensor_by_axis(dtensor, axis=0):
raise ValueError(f"unsupported type: {type(dtensors)}")
return global_micro_batchs

def optimizer_step(self, args, parameters_list=None):
def optimizer_step(self, args, model, parameters_list=None):
if parameters_list is None:
parameters_list = []

optimizer_was_run = True
if args.enable_auto_parallel and self.args.offload_optim:
if not args.enable_auto_parallel and self.args.offload_optim:
self._reload_optimizer()

if self.do_grad_scaling:
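
This hunk threads `model` into `optimizer_step` (the call site later in the diff passes `model=model`) and flips the offload guard so optimizer state is reloaded and offloaded only when auto parallel is not enabled. A minimal sketch of that control flow, using a hypothetical `TinyTrainer` rather than the real trainer class:

```python
# Sketch only: hypothetical stand-ins mirroring the patched optimizer_step
# control flow; this is not the real PaddleFormers trainer API.
class TinyTrainer:
    def __init__(self, offload_optim: bool, enable_auto_parallel: bool):
        self.offload_optim = offload_optim
        self.enable_auto_parallel = enable_auto_parallel

    def _reload_optimizer(self):
        print("reload optimizer state to device")

    def _offload_optimizer(self):
        print("offload optimizer state to host")

    def optimizer_step(self, model, parameters_list=None):
        if parameters_list is None:
            parameters_list = []
        # After the patch, offload/reload only applies to the manual
        # (non auto-parallel) path.
        if not self.enable_auto_parallel and self.offload_optim:
            self._reload_optimizer()
        print(f"step over {len(parameters_list)} params of {model}")
        if not self.enable_auto_parallel and self.offload_optim:
            self._offload_optimizer()


TinyTrainer(offload_optim=True, enable_auto_parallel=False).optimizer_step("model")
```
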
@@ -1515,11 +1518,12 @@ def optimizer_step(self, args, parameters_list=None):
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)
elif isinstance(self.optimizer, HybridParallelOptimizer):
parameters_list = [t if t.is_contiguous() else t.contiguous() for t in parameters_list]
self.optimizer._step(parameters_list)
else:
self.optimizer.step()

if args.enable_auto_parallel and self.args.offload_optim:
if not args.enable_auto_parallel and self.args.offload_optim:
self._offload_optimizer()

if optimizer_was_run:
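
The added list comprehension makes every parameter contiguous before `HybridParallelOptimizer._step` receives the list. A paddle-free sketch of the same guard, with a hypothetical `FakeTensor` standing in for `paddle.Tensor` and its `is_contiguous`/`contiguous` methods:

```python
# Hypothetical stand-in for a tensor exposing Paddle-style layout queries.
class FakeTensor:
    def __init__(self, contiguous: bool):
        self._contiguous = contiguous

    def is_contiguous(self) -> bool:
        return self._contiguous

    def contiguous(self) -> "FakeTensor":
        # Returns a contiguous copy, mirroring the role of Tensor.contiguous().
        return FakeTensor(contiguous=True)


params = [FakeTensor(True), FakeTensor(False)]
# Same normalization as the patch: only copy tensors that need it.
params = [t if t.is_contiguous() else t.contiguous() for t in params]
assert all(t.is_contiguous() for t in params)
```
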
@@ -1533,6 +1537,10 @@

if not args.enable_auto_parallel and (args.release_grads or enable_release_grads):
self.optimizer.clear_grad(set_to_zero=False)
if args.pipeline_parallel_degree > 1:
for _, buffers in model._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer._clear_grad_storage()
else:
self.optimizer.clear_grad()
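
When gradients are released after the step and pipeline parallelism is active, the new branch also clears the gradient storage held by the pipeline communication buffers in the model's `_chunk_2_comm_buffers` mapping. A sketch of that loop with hypothetical buffer and model stand-ins:

```python
# Hypothetical stand-ins mirroring the attribute shapes the patch touches.
class CommBuffer:
    def __init__(self, name: str):
        self.name = name
        self.grad_storage = bytearray(16)  # pretend gradient storage

    def _clear_grad_storage(self):
        self.grad_storage = None
        print(f"cleared grad storage of {self.name}")


class PipelineModel:
    def __init__(self):
        # chunk id -> list of communication buffers, as in the diff.
        self._chunk_2_comm_buffers = {
            0: [CommBuffer("chunk0/buf0")],
            1: [CommBuffer("chunk1/buf0")],
        }


model = PipelineModel()
pipeline_parallel_degree = 2
if pipeline_parallel_degree > 1:
    for _, buffers in model._chunk_2_comm_buffers.items():
        for buffer in buffers:
            buffer._clear_grad_storage()
```
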

@@ -1819,6 +1827,8 @@ def _inner_training_loop(
if not self.args.enable_auto_parallel:
with sync_context:
if "step_control" in inspect.signature(self.training_step).parameters:
tr_loss_step = self.training_step(model, inputs, step_control=step_control)
else:
tr_loss_step = self.training_step(model, inputs)
else:
tr_loss_step = self.training_step(model, inputs)
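
The loop now inspects the `training_step` signature so overrides that do not accept `step_control` keep working. The check is plain `inspect.signature`; a self-contained example with two hypothetical step functions:

```python
import inspect


def old_training_step(model, inputs):
    return f"loss({model}, {inputs})"


def new_training_step(model, inputs, step_control=0):
    return f"loss({model}, {inputs}, step={step_control})"


def run_step(training_step, model, inputs, step_control):
    # Pass step_control only if the callee accepts it, as in the diff.
    if "step_control" in inspect.signature(training_step).parameters:
        return training_step(model, inputs, step_control=step_control)
    return training_step(model, inputs)


print(run_step(old_training_step, "m", "x", 3))
print(run_step(new_training_step, "m", "x", 3))
```
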
@@ -1943,7 +1953,7 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
)

self.optimizer_step(args, parameters_list=parameters_list)
self.optimizer_step(args, model=model, parameters_list=parameters_list)

self.timers and self.timers("optimizer-step").stop()

@@ -2120,7 +2130,7 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
if self.args.enable_auto_parallel or self.args.world_size <= 1:
return paddle.io.BatchSampler(
dataset=self.train_dataset,
shuffle=False,
shuffle=shuffle,
batch_size=total_batch_size,
drop_last=self.args.dataloader_drop_last,
)
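
The sampler branch now honors the `shuffle` flag instead of hard-coding `shuffle=False`. A small sketch of the construction, assuming Paddle is installed and using a hypothetical `RangeDataset`:

```python
# Sketch, assuming Paddle is available; mirrors the sampler construction in
# the patched _get_train_sampler, now forwarding the shuffle flag.
import paddle


class RangeDataset(paddle.io.Dataset):
    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return self.n


shuffle = True  # previously hard-coded to False in this branch
sampler = paddle.io.BatchSampler(
    dataset=RangeDataset(10),
    shuffle=shuffle,
    batch_size=4,
    drop_last=False,
)
for batch_indices in sampler:
    print(batch_indices)
```
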
@@ -2801,7 +2811,6 @@ def _wrap_distributed_optimizer(self, optimizer):
and self.args.moe_sharding_parallel_degree >= 1
and self.args.expert_parallel_degree > 1
and self.args.sharding_parallel_degree > 1
and not self.args.reorder_pipeline_priority
):
from ..utils import MoEHybridParallelOptimizer
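
The last hunk removes `reorder_pipeline_priority` from the gate that selects `MoEHybridParallelOptimizer`, so the choice now depends only on the parallelism degrees. A hedged sketch of the resulting predicate; the dataclass and helper below are hypothetical, and only the clauses visible in this hunk are modeled:

```python
from dataclasses import dataclass


@dataclass
class ParallelArgs:
    # Hypothetical subset of the trainer args referenced in the hunk.
    moe_sharding_parallel_degree: int = 1
    expert_parallel_degree: int = 2
    sharding_parallel_degree: int = 2


def use_moe_hybrid_optimizer(args: ParallelArgs) -> bool:
    # After the patch, reorder_pipeline_priority no longer disables this path.
    return (
        args.moe_sharding_parallel_degree >= 1
        and args.expert_parallel_degree > 1
        and args.sharding_parallel_degree > 1
    )


print(use_moe_hybrid_optimizer(ParallelArgs()))
```
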
