
Conversation

@qgallouedec (Member) commented on Nov 26, 2025

Summary

Through an in-depth investigation, I found that:

  1. Much of the model-preparation logic in the codebase is outdated, unnecessary, or incorrect.
  2. Model preparation currently varies across trainers, with no unified approach, which creates maintenance headaches.

Goals of this PR

  1. Standardize model preparation for all stable trainers (SFT, GRPO, Reward; excluding DPO, which is currently being refactored).
  2. Provide a correct, up-to-date, and well-documented model-preparation pipeline, derived from a thorough review of all the cases covered in the script below.

Script covering the various cases

import numpy as np
import torch
import pandas as pd
from datasets import Dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

data = np.random.randint(0, 1000, (16, 64)).tolist()
dataset = Dataset.from_dict({"input_ids": data, "labels": data})

def human_readable_int(value: int) -> str:
    """Short human-readable numbers (1.2M, 340k...)."""
    if value >= 1_000_000:
        return f"{value / 1_000_000:.1f}M"
    if value >= 1_000:
        return f"{value / 1_000:.0f}k"
    return str(value)


def run_scenario(name, dtype, gc, quantized, use_lora):
    """Run one short training and report what the Trainer actually observes (GC, LoRA, quantization, dtype)."""
    model_kwargs = {"device_map": "auto"}
    if quantized:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    else:
        model_kwargs["dtype"] = dtype

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", **model_kwargs)

    if use_lora:
        model = get_peft_model(model, LoraConfig())
        model.enable_input_require_grads()  # ideally not needed, but see https://github.com/huggingface/transformers/issues/42489

    trainer = Trainer(
        model=model,
        args=TrainingArguments(gradient_checkpointing=gc),
        train_dataset=dataset,
    )
    trainer.train()

    model = trainer.model
    params = list(model.named_parameters())
    total_params = sum(p.numel() for _, p in params)
    trainable_params = sum(p.numel() for _, p in params if p.requires_grad)
    sample_dtype = params[0][1].dtype if params else "n/a"
    quant_method = getattr(model, "quantization_method", "")
    actual_quant = "4bit" if getattr(model, "is_loaded_in_4bit", False) else ("8bit" if getattr(model, "is_loaded_in_8bit", False) else (quant_method or "none"))
    is_lora_model = isinstance(model, PeftModel)
    observed = f"gc:{'on' if model.is_gradient_checkpointing else 'off'}, lora:{'yes' if is_lora_model else 'no'}, quant:{actual_quant}, dtype:{sample_dtype}"
    return {
        "Scenario": name,
        "Observed": observed,
        "Trainable": f"{human_readable_int(trainable_params)} / {human_readable_int(total_params)}",
    }


def main():
    scenarios = [
        {"name": "FFT",         "dtype": "auto",        "gc": False, "quantized": False, "use_lora": False},
        {"name": "FFT + FP16",  "dtype": torch.float16, "gc": False, "quantized": False, "use_lora": False},
        {"name": "FFT + GC",    "dtype": "auto",        "gc": True,  "quantized": False, "use_lora": False},
        {"name": "LoRA",        "dtype": "auto",        "gc": False, "quantized": False, "use_lora": True},
        {"name": "LoRA + GC",   "dtype": "auto",        "gc": True,  "quantized": False, "use_lora": True},
        {"name": "Q-LoRA",      "dtype": "auto",        "gc": False, "quantized": True,  "use_lora": True},
        {"name": "Q-LoRA + GC", "dtype": "auto",        "gc": True,  "quantized": True,  "use_lora": True},
    ]
    results = [run_scenario(**scenario) for scenario in scenarios]
    df = pd.DataFrame(results)
    print(df.to_markdown(index=False))


if __name__ == "__main__":
    main()

| Scenario    | Observed                                            | Trainable       |
|:------------|:----------------------------------------------------|:----------------|
| FFT         | gc:off, lora:no, quant:none, dtype:torch.bfloat16   | 596.0M / 596.0M |
| FFT + FP16  | gc:off, lora:no, quant:none, dtype:torch.float16    | 596.0M / 596.0M |
| FFT + GC    | gc:on, lora:no, quant:none, dtype:torch.bfloat16    | 596.0M / 596.0M |
| LoRA        | gc:off, lora:yes, quant:none, dtype:torch.bfloat16  | 1.1M / 597.2M   |
| LoRA + GC   | gc:on, lora:yes, quant:none, dtype:torch.bfloat16   | 1.1M / 597.2M   |
| Q-LoRA      | gc:off, lora:yes, quant:4bit, dtype:torch.float16   | 1.1M / 377.0M   |
| Q-LoRA + GC | gc:on, lora:yes, quant:4bit, dtype:torch.float16    | 1.1M / 377.0M   |

qgallouedec marked this pull request as ready for review on November 29, 2025, 00:54.
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@require_peft
@require_bitsandbytes
def test_peft_model_with_quantization(self):
@qgallouedec (Member, Author) commented:

Here we just align this test with the other tests, to make maintenance easier.
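
For reference, a rough sketch of the kind of setup such a test exercises (assumed shape only, not the exact test from the suite): load a 4-bit quantized base model, let the trainer apply a LoraConfig, and check that only the adapter parameters end up trainable.

# Rough sketch, assumed shape only (the actual test lives in TRL's test suite)
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", quantization_config=quant)
dataset = Dataset.from_dict({"text": ["Hello world"] * 8})

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="tmp_test", max_steps=1, report_to="none"),
    train_dataset=dataset,
    peft_config=LoraConfig(),
)
# Only the LoRA adapter parameters should require gradients
trainable = [n for n, p in trainer.model.named_parameters() if p.requires_grad]
assert trainable and all("lora_" in n for n in trainable)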

TrainingArguments,
is_comet_available,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
@qgallouedec (Member, Author) commented:

According to good practice, we shouldn't import this private class, but I suggest we make an exception here: it's only used for a type hint.
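
For illustration only (load_model is a made-up example, not code from this PR), the private class would appear exclusively in annotations, never be instantiated or called:

from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass

# Hypothetical example: _BaseAutoModelClass is only referenced in the type hint
def load_model(model_id: str, auto_class: type[_BaseAutoModelClass] = AutoModelForCausalLM):
    return auto_class.from_pretrained(model_id)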

from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
- from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
+ from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
@qgallouedec (Member, Author) commented:

We drop prepare_peft_model:

This part is now done directly in the trainer's init:

trl/trl/models/utils.py

Lines 560 to 561 in 0726977

if isinstance(model, PeftModel) and peft_config is not None:
    model = model.merge_and_unload()
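
For context, a minimal sketch of the intent (apply_peft_config is a hypothetical helper, not the trainer's actual code): when a user passes a model that is already a PeftModel together with a new peft_config, the old adapter is folded into the base weights before the new one is applied.

from typing import Optional

from peft import LoraConfig, PeftModel, get_peft_model

def apply_peft_config(model, peft_config: Optional[LoraConfig]):
    # Hypothetical helper sketching the flow: merge any existing adapter first,
    # then wrap the base model with the new adapter config.
    if isinstance(model, PeftModel) and peft_config is not None:
        model = model.merge_and_unload()
    if peft_config is not None:
        model = get_peft_model(model, peft_config)
    return model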

The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, it turns out this behavior is already handled correctly by PEFT and Transformers, so this custom logic is no longer necessary. It is likely a leftover from a period when native support was incomplete, although it's difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory.

trl/trl/models/utils.py

Lines 563 to 584 in 0726977

# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
    # Check if model is sharded (FSDP/DS-Zero3)
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
            break

# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
    # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
    args.gradient_checkpointing = False
elif args.gradient_checkpointing:
    model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)
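
For reference, a minimal sketch of the native path that makes the custom handling above redundant (assuming a recent transformers/peft stack; this is not the PR's exact code). It mirrors the Q-LoRA + GC scenario from the script above: gradient checkpointing is simply requested through TrainingArguments, with no call to prepare_model_for_kbit_training.

import numpy as np
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

# Same toy dataset as in the script above
data = np.random.randint(0, 1000, (16, 64)).tolist()
dataset = Dataset.from_dict({"input_ids": data, "labels": data})

quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", quantization_config=quant, device_map="auto")
model = get_peft_model(model, LoraConfig())
model.enable_input_require_grads()  # ideally not needed, but see https://github.com/huggingface/transformers/issues/42489

trainer = Trainer(
    model=model,
    # Gradient checkpointing is requested here; PEFT/Transformers handle the rest
    args=TrainingArguments(gradient_checkpointing=True),
    train_dataset=dataset,
)
trainer.train()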

It's not obvious from the current code (again: missing comments), but autocast_adapter_dtype=False is intended to force the adapter dtype to bfloat16 when using a quantized model. However, this behavior doesn't seem to be functional at the moment. See here.
This logic has now been moved into the trainers' initialization, which is in my opinion clearer:

trl/trl/models/utils.py

Lines 586 to 599 in 0726977

# Create PEFT model
if peft_config is not None:
    if (
        version.parse(peft.__version__) >= version.parse("0.12")  # autocast_adapter_dtype introduced in 0.12
        and getattr(model, "is_loaded_in_4bit", False)
        and is_sharded_qlora
    ):
        model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
    else:
        model = get_peft_model(model, peft_config)

    # Handle bf16 casting for 4-bit models
    if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
        peft_module_casting_to_bf16(model)
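
To see what autocast_adapter_dtype actually does on a given stack, a quick illustrative check (not part of the PR) is to compare the LoRA parameter dtypes with and without it:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

for autocast in (True, False):
    base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", quantization_config=quant, device_map="auto")
    peft_model = get_peft_model(base, LoraConfig(), autocast_adapter_dtype=autocast)
    # Inspect the dtypes the LoRA weights actually end up with
    dtypes = {p.dtype for n, p in peft_model.named_parameters() if "lora_" in n}
    print(f"autocast_adapter_dtype={autocast}: {dtypes}")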
