File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change 30
30
import torch .utils .tensorboard as tensorboard
31
31
32
32
import mlflow # isort: skip
33
- import transformers
34
33
35
34
import wandb # isort: skip
36
35
from torch .distributed .checkpoint .state_dict import (
@@ -161,11 +160,11 @@ def __init__(
161
160
# Prepare model for training
162
161
# ----------------------------------
163
162
if args .enable_gradient_checkpointing :
164
- if not isinstance (model , transformers . PreTrainedModel ):
163
+ if not hasattr (model , "gradient_checkpointing_enable" ):
165
164
raise ValueError (
166
- "Gradient checkpointing is only supported for transformers models."
165
+ "Gradient checkpointing is only supported for Hugging Face models."
167
166
)
168
- model .gradient_checkpointing_enable (args .gradient_checkpointing_kwargs )
167
+ model .gradient_checkpointing_enable (args .gradient_checkpointing_kwargs ) # type: ignore
169
168
model = cast (torch .nn .Module , model )
170
169
model .to (self .device )
171
170
if is_distributed ():
You can’t perform that action at this time.
0 commit comments