Changes from 6 commits
102 changes: 24 additions & 78 deletions tests/test_sft_trainer.py
@@ -1600,97 +1600,43 @@ def test_prompt_tuning(self):

@require_peft
@require_bitsandbytes
def test_peft_model_with_quantization(self):
Member Author commented:

Here we just align the test with the other tests to make maintenance easier.

"""SFTTrainer should not freeze layers of existing PeftModel.

This test simulates a realistic QLoRA scenario where a quantized base model is first converted to a PeftModel,
then passed to SFTTrainer. The issue was that prepare_model_for_kbit_training would freeze all parameters
including the LoRA adapters, making training impossible.
"""
def test_peft_with_quantization(self):
# Get the base model
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)

# Simulate a realistic QLoRA setup by mocking quantization attributes
# This mimics what happens when loading a model with load_in_4bit=True
model.is_loaded_in_4bit = True
model.is_loaded_in_8bit = False

# Verify that this triggers the is_qlora condition
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
assert is_qlora, "Model should be detected as QLoRA (quantized)"

# Create LoRA configuration suitable for QLoRA
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=["q_proj", "v_proj"],
r=16,
lora_alpha=32,
lora_dropout=0.1,
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)

# Convert the quantized model to a PeftModel (typical QLoRA workflow)
peft_model = get_peft_model(model, lora_config)

# Verify the quantization attributes are preserved on the PeftModel
assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag"
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

# Analyze parameters before SFTTrainer initialization
trainable_params_before = []
base_params_before = []
lora_params_before = []

for name, param in peft_model.named_parameters():
if param.requires_grad:
trainable_params_before.append(name)
if "lora" in name.lower():
lora_params_before.append(name)
else:
base_params_before.append(name)

# Ensure we have the expected parameter distribution for QLoRA
assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially"
assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters"
assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel"

# Initialize the trainer with the already configured PeftModel
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1)
trainer = SFTTrainer(model=peft_model, args=training_args, train_dataset=dataset)

# Analyze parameters after SFTTrainer initialization
trainable_params_after = []
lora_params_after = []

for name, param in trainer.model.named_parameters():
if param.requires_grad:
trainable_params_after.append(name)
if "lora" in name.lower():
lora_params_after.append(name)
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig())

# LoRA parameters should remain trainable
assert len(trainable_params_after) > 0, (
f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
f"Found {len(trainable_params_after)} trainable params. "
f"This test fails without the fix for issue #3926."
)
# Save initial parameters to check they change during training
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

assert len(lora_params_after) > 0, (
f"LoRA adapter parameters should remain trainable. "
f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original."
)
trainer.train()

# Ensure the parameter counts are preserved (no additional freezing occurred)
assert len(trainable_params_before) == len(trainable_params_after), (
"Number of trainable parameters should not change after SFTTrainer initialization"
)
# Check that training completed successfully
assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

# Verify that all original LoRA parameters are still trainable
assert set(lora_params_before) == set(lora_params_after), (
"All original LoRA parameters should remain trainable after SFTTrainer initialization"
)
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "lora" not in n: # We expect the base model parameters to be the same
assert torch.allclose(param, new_param), f"Parameter {n} has changed"
elif "lora" in n: # We expect the peft parameters to be different
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
else:
raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")

@require_peft
def test_prompt_tuning_peft_model(self):
2 changes: 1 addition & 1 deletion trl/data_utils.py
@@ -602,7 +602,7 @@ class _SegmentTree:
A segment tree data structure that, when initialized as `_SegmentTree(maxval)`, efficiently finds the next larger
value for a given input within the range [1, maxval].

See [Fewer Truncations Improve Language Modeling](https://arxiv.org/abs/2404.10830) for more details.
See [Fewer Truncations Improve Language Modeling](https://huggingface.co/papers/2404.10830) for more details.
"""

def __init__(self, maxval: int):
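
As an aside on the data structure touched in this hunk: _SegmentTree backs the best-fit packing strategy from the paper linked above. Below is a minimal, naive sketch of the query it answers; the class and method names are illustrative only, not TRL's actual API.

# Illustrative only: a naive stand-in for the "next larger value" query that the real
# segment tree answers in O(log maxval). Names here are hypothetical, not TRL's API.
class NaiveNextLarger:
    def __init__(self, maxval: int):
        self.maxval = maxval
        self.values = set()

    def add(self, value: int) -> None:
        assert 1 <= value <= self.maxval
        self.values.add(value)

    def remove(self, value: int) -> None:
        self.values.discard(value)

    def search(self, value: int) -> int:
        # Smallest stored value >= `value`; in best-fit packing this picks the bin
        # whose remaining capacity most tightly fits the incoming sequence.
        candidates = [v for v in self.values if v >= value]
        return min(candidates) if candidates else -1  # -1: nothing large enough

# Usage: bins with remaining capacities 3 and 7; a sequence of length 5 goes into the
# 7-capacity bin, i.e. the next larger value.
tree = NaiveNextLarger(maxval=8)
tree.add(3)
tree.add(7)
assert tree.search(5) == 7
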
68 changes: 38 additions & 30 deletions trl/trainer/grpo_trainer.py
@@ -28,15 +28,13 @@
import pandas as pd
import torch
import torch.utils.data
import transformers
from accelerate import logging
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
@@ -61,13 +59,14 @@
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
Member Author commented:

We drop prepare_peft_model:

This part is now done directly in the trainer init:

trl/trl/models/utils.py

Lines 560 to 561 in 0726977

if isinstance(model, PeftModel) and peft_config is not None:
    model = model.merge_and_unload()

The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, this behavior is already correctly handled by PEFT and Transformers, so this custom logic is no longer necessary. It is likely a leftover from a period when native support was incomplete, although it’s difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory.

trl/trl/models/utils.py

Lines 563 to 584 in 0726977

# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
    # Check if model is sharded (FSDP/DS-Zero3)
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
            break
# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
    # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
    args.gradient_checkpointing = False
elif args.gradient_checkpointing:
    model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)

It’s not obvious from the current code (again: missing comments), but autocast_adapter_dtype=False is intended to force the adapter dtype to bfloat16 when using a quantized model; however, this behavior doesn’t seem to be functional at the moment. See here.
This logic has now been moved into the trainers’ initialization, which is in my opinion clearer.

trl/trl/models/utils.py

Lines 586 to 599 in 0726977

# Create PEFT model
if peft_config is not None:
    if (
        version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
        and getattr(model, "is_loaded_in_4bit", False)
        and is_sharded_qlora
    ):
        model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
    else:
        model = get_peft_model(model, peft_config)
# Handle bf16 casting for 4-bit models
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
    peft_module_casting_to_bf16(model)
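
For comparison, here is a condensed sketch of the replacement logic that now lives in each trainer's __init__, assembled from the grpo_trainer.py diff below. model, peft_config, and args are assumed to be the trainer's usual init arguments, and the is_peft_available() guards are omitted; treat this as an illustration rather than the exact code.

import torch
from peft import PeftModel, get_peft_model

if isinstance(model, PeftModel) and peft_config is not None:
    # An existing PeftModel combined with a new peft_config: merge and unload first
    model = model.merge_and_unload()

# Create PEFT model
if peft_config is not None:
    model = get_peft_model(model, peft_config)

# Gradient checkpointing with PEFT needs input gradients enabled
# (workaround for https://github.com/huggingface/transformers/issues/42489)
if isinstance(model, PeftModel) and args.gradient_checkpointing:
    model.enable_input_require_grads()

# QLoRA: cast trainable adapter weights to bf16 (https://huggingface.co/papers/2305.14314)
if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.bfloat16)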

from ..models.utils import _ForwardRedirection
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import (
RepeatSampler,
create_model_from_path,
disable_dropout_in_model,
ensure_master_addr_port,
entropy_from_logits,
@@ -87,7 +86,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel
from peft import PeftConfig, PeftModel, get_peft_model

if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
@@ -254,28 +253,14 @@ def __init__(
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")

# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
# Model
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str): # it's a str, but not "auto"
dtype = getattr(torch, dtype)
model_init_kwargs["dtype"] = dtype
else:
raise ValueError(
"Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
model = create_model_from_path(model, **model_init_kwargs)
else:
model_id = get_config_model_id(model.config)
if args.model_init_kwargs is not None:
logger.warning(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
@@ -290,9 +275,6 @@
else inspect.signature(model.get_base_model().forward).parameters.keys()
)

if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)

# Processing class
if processing_class is None:
processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
@@ -312,6 +294,30 @@
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id

if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
# If the model is already a PeftModel, we need to merge and unload it.
# Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
model = model.merge_and_unload()

# Create PEFT model
if peft_config is not None:
model = get_peft_model(model, peft_config)

# When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
# handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
model.enable_input_require_grads()

# When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
# original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
# passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
# quantized models. See: https://github.com/huggingface/peft/issues/2889
# Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
for param in model.parameters():
if param.requires_grad:
param.data = param.data.to(torch.bfloat16)

# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
@@ -470,9 +476,11 @@ def __init__(
self.ref_model = None
else:
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if self.args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)

# Disable dropout in the models
if args.disable_dropout:
45 changes: 22 additions & 23 deletions trl/trainer/reward_trainer.py
@@ -25,7 +25,6 @@

import torch
import torch.nn as nn
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from datasets import Dataset, IterableDataset
@@ -42,14 +41,14 @@
from transformers.utils import is_peft_available

from ..data_utils import is_conversational
from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
from ..models import clone_chat_template, get_act_offloading_ctx_manager
from .base_trainer import BaseTrainer
from .reward_config import RewardConfig
from .utils import disable_dropout_in_model, get_config_model_id, pad, remove_none_values
from .utils import create_model_from_path, disable_dropout_in_model, get_config_model_id, pad, remove_none_values


if is_peft_available():
from peft import PeftConfig, PeftModel
from peft import PeftConfig, PeftModel, get_peft_model


logger = get_logger(__name__)
@@ -279,24 +278,13 @@ def __init__(
args = RewardConfig(f"{model_name}-Reward")

# Model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
model_init_kwargs["dtype"] = getattr(torch, dtype)
else:
raise ValueError(
"Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs)
else:
model_id = get_config_model_id(model.config)
if args.model_init_kwargs is not None:
logger.warning(
"You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
@@ -305,7 +293,7 @@

# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model_id)
processing_class = AutoTokenizer.from_pretrained(get_config_model_id(model.config))

# Handle pad token for processors or tokenizers
if args.eos_token is not None:
@@ -356,8 +344,19 @@
else:
peft_config.modules_to_save.append("lm_head")

if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)
if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
# If the model is already a PeftModel, we need to merge and unload it.
# Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
model = model.merge_and_unload()

# Create PEFT model
if peft_config is not None:
model = get_peft_model(model, peft_config)

# When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
# handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
model.enable_input_require_grads()

# Disable dropout in the model
if args.disable_dropout:
22 changes: 17 additions & 5 deletions trl/trainer/sft_trainer.py
@@ -48,7 +48,7 @@
prepare_multimodal_messages,
truncate_dataset,
)
from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
from ..models import clone_chat_template, get_act_offloading_ctx_manager
from .base_trainer import BaseTrainer
from .sft_config import SFTConfig
from .utils import (
@@ -63,7 +63,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel, PeftType
from peft import PeftConfig, PeftModel, PeftType, get_peft_model


logger = logging.get_logger(__name__)
@@ -693,12 +693,24 @@ def __init__(
else:
peft_config.modules_to_save.append("lm_head")

if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
# If the model is already a PeftModel, we need to merge and unload it.
# Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
model = model.merge_and_unload()

# Create PEFT model
if peft_config is not None:
model = get_peft_model(model, peft_config)

# When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
# handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
model.enable_input_require_grads()

# In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the
# input. We store the number of these tokens so we can account for them correctly when calculating accuracy.
self.num_virtual_tokens = 0

if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)
if is_peft_available() and isinstance(model, PeftModel):
if model.active_adapter in model.peft_config:
peft_model_config = model.peft_config[model.active_adapter]
self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)