Skip to content

Commit 1a6534d

Browse files
authored
[tiny] Fix gradient checkpointing for Oumi trainer (#1778)
1 parent 15a6c68 commit 1a6534d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/oumi/core/trainers/oumi_trainer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import torch.utils.tensorboard as tensorboard
3131

3232
import mlflow # isort: skip
33-
import transformers
3433

3534
import wandb # isort: skip
3635
from torch.distributed.checkpoint.state_dict import (
@@ -161,11 +160,11 @@ def __init__(
161160
# Prepare model for training
162161
# ----------------------------------
163162
if args.enable_gradient_checkpointing:
164-
if not isinstance(model, transformers.PreTrainedModel):
163+
if not hasattr(model, "gradient_checkpointing_enable"):
165164
raise ValueError(
166-
"Gradient checkpointing is only supported for transformers models."
165+
"Gradient checkpointing is only supported for Hugging Face models."
167166
)
168-
model.gradient_checkpointing_enable(args.gradient_checkpointing_kwargs)
167+
model.gradient_checkpointing_enable(args.gradient_checkpointing_kwargs) # type: ignore
169168
model = cast(torch.nn.Module, model)
170169
model.to(self.device)
171170
if is_distributed():

0 commit comments

Comments (0)