101 changes: 21 additions & 80 deletions src/transformers/trainer.py
@@ -208,13 +208,8 @@

 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
-    from smdistributed.modelparallel import __version__ as SMP_VERSION

-    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
-
     from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
-else:
-    IS_SAGEMAKER_MP_POST_1_10 = False


 if is_safetensors_available():
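Note: the deleted IS_SAGEMAKER_MP_POST_1_10 flag was the last remnant of the pre-1.10 compatibility layer; after this change, any installed smdistributed.modelparallel is assumed to be >= 1.10. A minimal sketch of the version-gating idiom being dropped, for reference (illustration only, not part of the diff):

from packaging import version

try:
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    # True on any smp release from 1.10 onwards; this PR removes the flag
    # and with it every branch that handled the pre-1.10 behavior.
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
except ImportError:
    IS_SAGEMAKER_MP_POST_1_10 = False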
@@ -710,23 +705,13 @@ def __init__(
             # BF16 + model parallelism in SageMaker: currently not supported, raise an error
             if args.bf16:
                 raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

-            if IS_SAGEMAKER_MP_POST_1_10:
-                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
-                if args.fp16 != smp.state.cfg.fp16:
-                    logger.warning(
-                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
-                        f"but FP16 provided in trainer argument is {args.fp16}, "
-                        f"setting to {smp.state.cfg.fp16}"
-                    )
-                    args.fp16 = smp.state.cfg.fp16
-            else:
-                # smp < 1.10 does not support fp16 in trainer.
-                if hasattr(smp.state.cfg, "fp16"):
-                    logger.warning(
-                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
-                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
-                    )
+            if args.fp16 != smp.state.cfg.fp16:
+                logger.warning(
+                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                    f"but FP16 provided in trainer argument is {args.fp16}, "
+                    f"setting to {smp.state.cfg.fp16}"
+                )
+                args.fp16 = smp.state.cfg.fp16
         if args.fp16 and args.device == torch.device("cpu") and not is_torch_greater_or_equal_than_2_3:
             raise ValueError("Tried to use `fp16` but it is not supported on cpu. You need to have torch>=2.3")

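With the version branch gone, the fp16 reconciliation always applies: when the SMP config and the trainer argument disagree, the SMP config wins. A standalone sketch of that rule, assuming SM_HP_MP_PARAMETERS is the JSON dict SageMaker injects into the job environment (the helper name reconcile_fp16 is hypothetical):

import json
import logging
import os

logger = logging.getLogger(__name__)


def reconcile_fp16(trainer_fp16: bool) -> bool:
    # Hypothetical helper mirroring the Trainer's behavior above: read the
    # model-parallel config SageMaker passes via the environment and let it
    # override the trainer argument on a mismatch.
    smp_cfg = json.loads(os.environ.get("SM_HP_MP_PARAMETERS", "{}"))
    smp_fp16 = bool(smp_cfg.get("fp16", False))
    if trainer_fp16 != smp_fp16:
        logger.warning(f"fp16 mismatch: SMP config says {smp_fp16}, trainer says {trainer_fp16}; using {smp_fp16}")
    return smp_fp16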
@@ -1230,8 +1215,8 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
             `create_scheduler`) in a subclass.
         """
         self.create_optimizer()
-        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
-            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+        if is_sagemaker_mp_enabled() and smp.state.cfg.fp16:
+            # If fp16 is enabled, we unwrap the optimizer
             optimizer = self.optimizer.optimizer
         else:
             optimizer = self.optimizer
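The condition now only asks whether SMP is enabled and fp16 is on; in that case self.optimizer is a wrapper, and the learning-rate scheduler has to be built on the inner optimizer it exposes at .optimizer. A minimal sketch of the unwrap-before-scheduling pattern, with a stand-in wrapper class in place of whatever smp actually installs:

import torch


class FP16OptimizerWrapper:
    # Stand-in for the fp16 optimizer wrapper; like the real one used above,
    # it exposes the underlying torch optimizer at `.optimizer`.
    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer


def unwrap_for_scheduler(opt):
    # Schedulers mutate param-group learning rates, so they need the inner
    # torch optimizer rather than the wrapper.
    return opt.optimizer if isinstance(opt, FP16OptimizerWrapper) else opt


params = [torch.nn.Parameter(torch.zeros(2))]
wrapped = FP16OptimizerWrapper(torch.optim.SGD(params, lr=0.1))
scheduler = torch.optim.lr_scheduler.StepLR(unwrap_for_scheduler(wrapped), step_size=10)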
@@ -2902,26 +2887,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
-                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
-                    # If the 'user_content.pt' file exists, load with the new smp api.
-                    # Checkpoint must have been saved with the new smp api.
-                    smp.resume_from_checkpoint(
-                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
-                    )
-                else:
-                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                    # Checkpoint must have been saved with the old smp api.
-                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
-                        logger.warning(
-                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
-                        )
-                    check_torch_load_is_safe()
-                    state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
-                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
-                    state_dict["_smp_is_partial"] = False
-                    load_result = model.load_state_dict(state_dict, strict=True)
-                    # release memory
-                    del state_dict
+                smp.resume_from_checkpoint(
+                    path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+                )
             elif self.is_fsdp_enabled:
                 load_fsdp_model(
                     self.accelerator.state.fsdp_plugin,
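Loading now always goes through the new smp API, with the exact call shape kept in the hunk. A hedged usage sketch (it must run inside an initialized SageMaker MP job; the checkpoint path is hypothetical):

import smdistributed.modelparallel.torch as smp  # only available inside SMP jobs
from transformers.utils import WEIGHTS_NAME

# Restore full (non-partial) model weights only; optimizer state is restored
# separately through a post-step hook in _load_optimizer_and_scheduler.
smp.resume_from_checkpoint(
    path="/opt/ml/checkpoints/checkpoint-500",  # hypothetical resume directory
    tag=WEIGHTS_NAME,
    partial=False,
    load_optimizer=False,
)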
@@ -3015,26 +2983,12 @@ def _load_best_model(self):
             ):
                 has_been_loaded = True
                 if is_sagemaker_mp_enabled():
-                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
-                        # If the 'user_content.pt' file exists, load with the new smp api.
-                        # Checkpoint must have been saved with the new smp api.
-                        smp.resume_from_checkpoint(
-                            path=self.state.best_model_checkpoint,
-                            tag=WEIGHTS_NAME,
-                            partial=False,
-                            load_optimizer=False,
-                        )
-                    else:
-                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                        # Checkpoint must have been saved with the old smp api.
-                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                        else:
-                            check_torch_load_is_safe()
-                            state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
-
-                        state_dict["_smp_is_partial"] = False
-                        load_result = model.load_state_dict(state_dict, strict=True)
+                    smp.resume_from_checkpoint(
+                        path=self.state.best_model_checkpoint,
+                        tag=WEIGHTS_NAME,
+                        partial=False,
+                        load_optimizer=False,
+                    )
                 else:
                     if _is_peft_model(model):
                         # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
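Same simplification for the best-model reload. One consequence worth flagging: checkpoints written by smp < 1.10 (those without user_content.pt) can no longer be loaded by either path. A one-off migration sketch under that assumption, using only calls that appear in the removed code (model is the smp-wrapped model, trainer the Trainer instance, and the paths are hypothetical):

import torch

# Assumes an initialized SMP job with `model` (smp-wrapped) and `trainer` in scope.
# Load the legacy full state dict exactly as the removed fallback did.
state_dict = torch.load("old-checkpoint/pytorch_model.bin", map_location="cpu", weights_only=True)
# Mark the dict as already being in smp layout so smp does not re-translate it.
state_dict["_smp_is_partial"] = False
model.load_state_dict(state_dict, strict=True)
del state_dict  # release memory
# Re-saving produces new-API files, including the user_content.pt marker.
trainer.save_model("migrated-checkpoint")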
@@ -3511,20 +3465,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
             if is_sagemaker_mp_enabled():
-                if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
-                    # Optimizer checkpoint was saved with smp >= 1.10
-                    def opt_load_hook(mod, opt):
-                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
-
-                else:
-                    # Optimizer checkpoint was saved with smp < 1.10
-                    def opt_load_hook(mod, opt):
-                        if IS_SAGEMAKER_MP_POST_1_10:
-                            opt.load_state_dict(
-                                smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
-                            )
-                        else:
-                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+                def opt_load_hook(mod, opt):
+                    opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                 self.model_wrapped.register_post_step_hook(opt_load_hook)
             else:
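Optimizer state is still restored lazily: the load is wrapped in a hook that smp runs after an optimizer step, once the distributed model and optimizer are fully set up. A sketch of that pattern using only the two smp calls visible in the hunk (checkpoint, OPTIMIZER_NAME, and model_wrapped stand in for the Trainer's own names):

import os

import smdistributed.modelparallel.torch as smp  # only available inside SMP jobs

OPTIMIZER_NAME = "optimizer.pt"  # stand-in for the Trainer's constant
checkpoint = "/opt/ml/checkpoints/checkpoint-500"  # hypothetical directory


def opt_load_hook(mod, opt):
    # partial=True: each rank loads only its own shard of the optimizer state.
    opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))


# model_wrapped stands in for the smp-wrapped model the Trainer holds.
model_wrapped.register_post_step_hook(opt_load_hook)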
@@ -4138,9 +4081,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             state_dict = self.model_wrapped.state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
-            if IS_SAGEMAKER_MP_POST_1_10:
-                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
-                Path(os.path.join(output_dir, "user_content.pt")).touch()
+            Path(os.path.join(output_dir, "user_content.pt")).touch()
         # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
         elif getattr(self.accelerator, "parallelism_config", None) is not None:
             if self.accelerator.should_save_model:
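With the version flag gone, user_content.pt is now touched on every SMP save; the read-side checks that used to branch on it were removed above, so the marker presumably remains for compatibility with readers that still look for it. A short sketch of the marker-file convention for reference (directory name hypothetical):

import os
from pathlib import Path

output_dir = "checkpoint-500"  # hypothetical save directory
os.makedirs(output_dir, exist_ok=True)
# Writer side: drop an empty marker next to the weights, as save_model does.
Path(os.path.join(output_dir, "user_content.pt")).touch()
# Reader side (the behavior this PR removes): branch on the marker's presence.
is_new_api_checkpoint = os.path.isfile(os.path.join(output_dir, "user_content.pt"))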