Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,30 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
# First, override TrainingArguments based on DeepSpeed config to ensure compatibility
self.override_training_args_from_deepspeed(args)

# Data-parallel world size: with DeepSpeed AutoTP, tensor-parallel ranks share the
# same data shard, so only world_size // autotp_size ranks contribute distinct batches.
tp_size = self.config.get("tensor_parallel", {}).get("autotp_size", 1)
dp_world_size = args.world_size // tp_size
# When a target effective batch size is requested (> 0), derive
# gradient_accumulation_steps from it instead of using the configured value.
if args.total_train_batch_size > 0:
# "strict" mode: the target must be an exact multiple of the per-optimizer-step
# global batch (micro batch * dp ranks); otherwise no integer accumulation count exists.
if (
args.total_train_batch_size_mode == "strict"
and args.total_train_batch_size % (args.train_batch_size * dp_world_size) != 0
):
raise ValueError(
f"Can not find gradient_accumulation_steps to match total_train_batch_size of "
f"{args.total_train_batch_size} when train batch size is {args.train_batch_size} "
f"and dp world size is {dp_world_size}."
)
# "nearly" mode falls through here too: pick the closest integer count, never below 1.
# NOTE(review): this duplicates Trainer.update_gradient_accumulation_steps — presumably
# kept separate because DeepSpeed reads args before the trainer loop; confirm.
new_gradient_accumulation_steps = max(
1, round(args.total_train_batch_size / (args.train_batch_size * dp_world_size))
)
logger.info(
f"Updated gradient_accumulation_steps from {args.gradient_accumulation_steps} "
f"to {new_gradient_accumulation_steps}"
)
# Mutates the TrainingArguments in place so downstream batch-size math sees the new value.
args.gradient_accumulation_steps = new_gradient_accumulation_steps

# DeepSpeed does:
# train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
# train_batch_size = dp_world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
train_batch_size = dp_world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
self.fill_match(
"train_micro_batch_size_per_gpu",
args.per_device_train_batch_size,
Expand Down
23 changes: 23 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2350,6 +2350,27 @@ def get_total_train_batch_size(self, args) -> int:
dp_world_size = args.world_size // self.get_tp_size()
return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size

def update_gradient_accumulation_steps(self, args) -> int:
    """Derive ``gradient_accumulation_steps`` from ``args.total_train_batch_size``.

    Uses the relation ``total = micro_batch * grad_accum_steps * dp_world_size``,
    where the data-parallel world size excludes tensor-parallel ranks.

    Args:
        args: The ``TrainingArguments`` instance. Its
            ``gradient_accumulation_steps`` attribute is updated in place.

    Returns:
        int: The newly derived ``gradient_accumulation_steps`` value.

    Raises:
        ValueError: In ``"strict"`` mode, when ``total_train_batch_size`` is not
            an exact multiple of ``self._train_batch_size * dp_world_size``.

    Note:
        Callers are expected to gate on ``args.total_train_batch_size > 0``
        before calling (as ``_inner_training_loop`` does) — presumably a
        non-positive target is meaningless here; confirm at call sites.
    """
    # Only data-parallel ranks contribute distinct samples to the global batch.
    dp_world_size = args.world_size // self.get_tp_size()
    per_step_batch = self._train_batch_size * dp_world_size

    # "strict" mode demands an exact integer accumulation count.
    if args.total_train_batch_size_mode == "strict" and args.total_train_batch_size % per_step_batch != 0:
        raise ValueError(
            f"Can not find gradient_accumulation_steps to match total_train_batch_size of "
            f"{args.total_train_batch_size} when train batch size is {self._train_batch_size} "
            f"and dp world size is {dp_world_size}."
        )
    # "nearly" mode: closest integer count, clamped to at least 1.
    new_gradient_accumulation_steps = max(1, round(args.total_train_batch_size / per_step_batch))
    logger.info(
        f"Updated gradient_accumulation_steps from {args.gradient_accumulation_steps} "
        f"to {new_gradient_accumulation_steps}"
    )
    args.gradient_accumulation_steps = new_gradient_accumulation_steps
    # Fix: the signature promises ``-> int`` but the original implicitly returned None.
    return new_gradient_accumulation_steps

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
Expand Down Expand Up @@ -2380,6 +2401,8 @@ def _inner_training_loop(
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
if not args.deepspeed and args.total_train_batch_size > 0:
self.update_gradient_accumulation_steps(args)
total_train_batch_size = self.get_total_train_batch_size(args)

(
Expand Down
24 changes: 24 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,30 @@ class TrainingArguments:
default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."}
)

# How strictly `total_train_batch_size` must be honored when deriving
# gradient_accumulation_steps from it.
total_train_batch_size_mode: str = field(
    default="nearly",
    metadata={
        "help": (
            "How to honor total_train_batch_size when deriving gradient_accumulation_steps. "
            "With 'strict', an integer gradient_accumulation_steps must exist such that "
            "gradient_accumulation_steps * per_device_train_batch_size * dp_world_size equals "
            "total_train_batch_size exactly, otherwise an error is raised. With 'nearly', "
            "gradient_accumulation_steps is chosen so that this product is as close as "
            "possible to total_train_batch_size."
        ),
        "choices": ["strict", "nearly"],
    },
)
# Target effective (global) batch size; when positive, the configured
# gradient_accumulation_steps is overridden with a value derived from this target.
total_train_batch_size: int = field(
    default=-1,
    metadata={
        "help": (
            "Desired total (effective) train batch size across all data-parallel ranks and "
            "accumulation steps. If set to a positive value, gradient_accumulation_steps is "
            "derived from it and the explicitly configured gradient_accumulation_steps is ignored."
        )
    },
)

gradient_accumulation_steps: int = field(
default=1,
metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
Expand Down