-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[WIP] Clean up model preparation #4577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
5d3b8d5
0b9283a
2967e31
f2b1ae6
9c952ab
b3f9317
577ed6a
56bfcaa
b72e75c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,15 +28,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.utils.data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import transformers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from accelerate import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from datasets import Dataset, IterableDataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch import nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import DataLoader, Sampler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AutoConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AutoModelForSequenceClassification, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AutoProcessor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AutoTokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -61,13 +59,14 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..extras.profiling import profiling_context, profiling_decorator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..extras.vllm_client import VLLMClient | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..import_utils import is_liger_kernel_available, is_vllm_available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we drop This part is now done directly the the trainer init: Lines 560 to 561 in 0726977
The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, this behavior is already correctly handled by PEFT and Transformers, so this custom logic is no longer necessary. It is likely a leftover from a period when native support was incomplete, although it’s difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory. Lines 563 to 584 in 0726977
It’s not obvious from the current code (again: missing comments), but Lines 586 to 599 in 0726977
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.utils import _ForwardRedirection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .base_trainer import BaseTrainer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .callbacks import SyncRefModelCallback | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .grpo_config import GRPOConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RepeatSampler, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| create_model_from_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| disable_dropout_in_model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ensure_master_addr_port, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| entropy_from_logits, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -87,7 +86,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_peft_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from peft import PeftConfig, PeftModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from peft import PeftConfig, PeftModel, get_peft_model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_liger_kernel_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -254,28 +253,14 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_name = model_name.split("/")[-1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = GRPOConfig(f"{model_name}-GRPO") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Trained model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs = args.model_init_kwargs or {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(model, str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_id = model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype = model_init_kwargs.get("dtype", "auto") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pass # dtype is already a torch.dtype or "auto" or None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(dtype, str): # it's a str, but not "auto" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype = getattr(torch, dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs["dtype"] = dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = AutoConfig.from_pretrained(model_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| architecture = getattr(transformers, config.architectures[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = architecture.from_pretrained(model_id, **model_init_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs = args.model_init_kwargs or {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Special case for DeepSpeed: requires device_map=None ("auto" fails) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.distributed_state.distributed_type == "DEEPSPEED": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs["device_map"] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = create_model_from_path(model, **model_init_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_id = get_config_model_id(model.config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.model_init_kwargs is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -290,9 +275,6 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else inspect.signature(model.get_base_model().forward).parameters.keys() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = prepare_peft_model(model, peft_config, args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Processing class | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if processing_class is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -312,6 +294,30 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.pad_token_id = tokenizer.pad_token_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.eos_token_id = tokenizer.eos_token_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If the model is already a PeftModel, we need to merge and unload it. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = model.merge_and_unload() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create PEFT model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if peft_config is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = get_peft_model(model, peft_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model.enable_input_require_grads() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # quantized models. See: https://github.com/huggingface/peft/issues/2889 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for param in model.parameters(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if param.requires_grad: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Reward functions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(reward_funcs, list): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reward_funcs = [reward_funcs] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -470,9 +476,11 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.ref_model = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For deepspeed, fsdp or non-distributed models, create a reference model from scratch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = AutoConfig.from_pretrained(model_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| architecture = getattr(transformers, config.architectures[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs = args.model_init_kwargs or {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Special case for DeepSpeed: requires device_map=None ("auto" fails) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.args.distributed_state.distributed_type == "DEEPSPEED": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_init_kwargs["device_map"] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Disable dropout in the models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.disable_dropout: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we just align the test with the other tests, to make maintenance easier