[WIP] Clean up model preparation #4577
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@require_peft
@require_bitsandbytes
def test_peft_model_with_quantization(self):
Here we just align this test with the other tests to make maintenance easier.
    TrainingArguments,
    is_comet_available,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
According to good practices, we shouldn't import this private class, but I suggest we make a special case here: it's only used for type hints.
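For illustration only (not the PR's actual code), here is the kind of usage this enables; the helper name load_with and the TYPE_CHECKING guard are my assumptions, the point is just that the private class only ever appears in an annotation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers.models.auto.auto_factory import _BaseAutoModelClass


def load_with(auto_cls: "type[_BaseAutoModelClass]", model_id: str):
    # auto_cls could be AutoModelForCausalLM, AutoModelForSequenceClassification, ...
    return auto_cls.from_pretrained(model_id)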
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
-from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
+from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
We drop prepare_peft_model.
This part is now done directly in the trainer init:
Lines 560 to 561 in 0726977
if isinstance(model, PeftModel) and peft_config is not None:
    model = model.merge_and_unload()
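For context, a minimal illustration of what this step does (the helper name merge_then_wrap is hypothetical, not from the PR): when a PeftModel is passed together with a new peft_config, the existing adapter is first merged back into the base weights before the new adapter is attached.

from peft import PeftModel, get_peft_model


def merge_then_wrap(model, peft_config):
    # Hypothetical helper, for illustration only.
    if isinstance(model, PeftModel) and peft_config is not None:
        model = model.merge_and_unload()  # fold the current adapter into the base weights
    if peft_config is not None:
        model = get_peft_model(model, peft_config)  # attach the new adapter
    return model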
The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, it turns out this behavior is already correctly handled by PEFT and Transformers, so this custom logic is no longer necessary; see the sketch after the snippet. It is likely a leftover from a period when native support was incomplete, although it's difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory.
Lines 563 to 584 in 0726977
# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
    # Check if model is sharded (FSDP/DS-Zero3)
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
            break

# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
    # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
    args.gradient_checkpointing = False
elif args.gradient_checkpointing:
    model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)
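For comparison, here is a hedged sketch (not code from this PR) of the built-in path that makes the block above redundant: gradient checkpointing is simply requested through the training arguments, and the Trainer enables it on the model itself, while PEFT takes care of the quantized-model specifics when the adapter is attached.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # the Trainer calls gradient_checkpointing_enable itself
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
)
# No manual prepare_model_for_kbit_training / enable_gradient_checkpointing call is needed here.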
It's not obvious from the current code (again: missing comments), but autocast_adapter_dtype=False is intended to force the adapter dtype to bfloat16 when using a quantized model. However, this behavior doesn't seem to be functional at the moment. See here
This logic has now been moved into the trainers' initialization, which is, in my opinion, clearer.
Lines 586 to 599 in 0726977
# Create PEFT model
if peft_config is not None:
    if (
        version.parse(peft.__version__) >= version.parse("0.12")  # autocast_adapter_dtype introduced in 0.12
        and getattr(model, "is_loaded_in_4bit", False)
        and is_sharded_qlora
    ):
        model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
    else:
        model = get_peft_model(model, peft_config)

# Handle bf16 casting for 4-bit models
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
    peft_module_casting_to_bf16(model)
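As a rough sketch of where this ends up (my assumption about the shape of the new code, not the actual diff), the wrapping can live in the trainer __init__ with sharded 4-bit QLoRA kept as the single special case:

# Sketch only, assuming the same conditions as above are preserved.
if peft_config is not None:
    # For sharded QLoRA (FSDP / DeepSpeed ZeRO-3), keep autocast_adapter_dtype=False
    # so the adapter dtype is not autocast away from bf16.
    sharded_qlora_4bit = getattr(model, "is_loaded_in_4bit", False) and is_sharded_qlora
    model = get_peft_model(model, peft_config, autocast_adapter_dtype=not sharded_qlora_4bit)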
Summary
Through an in-depth investigation, I found that most of this custom model-preparation logic is either already handled natively by PEFT and Transformers, or clearer when done directly in the trainers' init.
Goals of this PR
Script used that covers the various cases