From a103219f3bc1908625042f0d7be8a3d90aa92ebb Mon Sep 17 00:00:00 2001
From: zhengchenyu
Date: Thu, 25 Sep 2025 14:31:06 +0800
Subject: [PATCH] Support setting total_train_batch_size

---
 src/transformers/integrations/deepspeed.py | 25 ++++++++++++++++++++--
 src/transformers/trainer.py                | 23 ++++++++++++++++++++
 src/transformers/training_args.py          | 24 +++++++++++++++++++++
 3 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 47d7a7ffcb5f..c5624a0c0cc4 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -182,9 +182,30 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         # First, override TrainingArguments based on DeepSpeed config to ensure compatibility
         self.override_training_args_from_deepspeed(args)
 
+        tp_size = self.config.get("tensor_parallel", {}).get("autotp_size", 1)
+        dp_world_size = args.world_size // tp_size
+        if args.total_train_batch_size > 0:
+            if (
+                args.total_train_batch_size_mode == "strict"
+                and args.total_train_batch_size % (args.train_batch_size * dp_world_size) != 0
+            ):
+                raise ValueError(
+                    f"Cannot find gradient_accumulation_steps to match total_train_batch_size of "
+                    f"{args.total_train_batch_size} when train batch size is {args.train_batch_size} "
+                    f"and dp world size is {dp_world_size}."
+                )
+            new_gradient_accumulation_steps = max(
+                1, round(args.total_train_batch_size / (args.train_batch_size * dp_world_size))
+            )
+            logger.info(
+                f"Updated gradient_accumulation_steps from {args.gradient_accumulation_steps} "
+                f"to {new_gradient_accumulation_steps}"
+            )
+            args.gradient_accumulation_steps = new_gradient_accumulation_steps
+
         # DeepSpeed does:
-        # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        # train_batch_size = dp_world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
+        train_batch_size = dp_world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0cd8fcf8cd14..ab5468927dce 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2350,6 +2350,27 @@ def get_total_train_batch_size(self, args) -> int:
         dp_world_size = args.world_size // self.get_tp_size()
         return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
 
+    def update_gradient_accumulation_steps(self, args) -> None:
+        dp_world_size = args.world_size // self.get_tp_size()
+
+        if (
+            args.total_train_batch_size_mode == "strict"
+            and args.total_train_batch_size % (self._train_batch_size * dp_world_size) != 0
+        ):
+            raise ValueError(
+                f"Cannot find gradient_accumulation_steps to match total_train_batch_size of "
+                f"{args.total_train_batch_size} when train batch size is {self._train_batch_size} "
+                f"and dp world size is {dp_world_size}."
+            )
+        new_gradient_accumulation_steps = max(
+            1, round(args.total_train_batch_size / (self._train_batch_size * dp_world_size))
+        )
+        logger.info(
+            f"Updated gradient_accumulation_steps from {args.gradient_accumulation_steps} "
+            f"to {new_gradient_accumulation_steps}"
+        )
+        args.gradient_accumulation_steps = new_gradient_accumulation_steps
+
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
     ):
@@ -2380,6 +2401,8 @@ def _inner_training_loop(
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
+        if not args.deepspeed and args.total_train_batch_size > 0:
+            self.update_gradient_accumulation_steps(args)
         total_train_batch_size = self.get_total_train_batch_size(args)
 
         (
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5e71f2a30a6d..24d726b157a3 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -846,6 +846,30 @@ class TrainingArguments:
         default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."}
     )
 
+    total_train_batch_size_mode: str = field(
+        default="nearly",
+        metadata={
+            "help": (
+                "The mode to keep total_train_batch_size the same in distributed training. "
+                "If mode is strict, the integer parameter gradient_accumulation_steps must be found "
+                "such that gradient_accumulation_steps*per_device_train_batch_size*world_size equals "
+                "total_train_batch_size. If mode is nearly, gradient_accumulation_steps will be selected "
+                "so that gradient_accumulation_steps*per_device_train_batch_size*world_size is closest "
+                "to total_train_batch_size."
+            ),
+            "choices": ["strict", "nearly"],
+        },
+    )
+    total_train_batch_size: int = field(
+        default=-1,
+        metadata={
+            "help": (
+                "Total number of training samples per optimizer update, across all devices. "
+                "If total_train_batch_size is positive, gradient_accumulation_steps will be ignored."
+            )
+        },
+    )
+
     gradient_accumulation_steps: int = field(
         default=1,
         metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},