  ################
  # Model & Tokenizer
  ################
- quantization_config = get_quantization_config(model_args)
  model_kwargs = dict(
      revision=model_args.model_revision,
      trust_remote_code=model_args.trust_remote_code,
      attn_implementation=model_args.attn_implementation,
      dtype=model_args.dtype,
      use_cache=False if training_args.gradient_checkpointing else True,
-     device_map=get_kbit_device_map() if quantization_config is not None else None,
-     quantization_config=quantization_config,
  )
+ quantization_config = get_quantization_config(model_args)
+ if quantization_config is not None:
+     # Passing None is not treated the same as omitting the argument, so set these kwargs only when quantization is enabled.
+     model_kwargs["device_map"] = get_kbit_device_map()
+     model_kwargs["quantization_config"] = quantization_config
+
  training_args.model_init_kwargs = model_kwargs

  teacher_model_kwargs = dict(
      revision=model_args.model_revision,
      trust_remote_code=model_args.trust_remote_code,
      attn_implementation=model_args.attn_implementation,
      dtype=model_args.dtype,
      use_cache=True,
-     device_map=get_kbit_device_map() if quantization_config is not None else None,
-     quantization_config=quantization_config,
  )
+ if quantization_config is not None:
+     # Passing None is not treated the same as omitting the argument, so set these kwargs only when quantization is enabled.
+     teacher_model_kwargs["device_map"] = get_kbit_device_map()
+     teacher_model_kwargs["quantization_config"] = quantization_config
+
  training_args.teacher_model_init_kwargs = teacher_model_kwargs

  tokenizer = AutoTokenizer.from_pretrained(
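
The conditional blocks above exist because a keyword explicitly passed as None is not guaranteed to behave like one that was never passed: the callee can tell the two apart whenever its default is anything other than None. A minimal sketch of that distinction in plain Python, where load_model and _UNSET are hypothetical names and not part of transformers:

# Sentinel default lets the callee distinguish "not passed" from "passed None".
_UNSET = object()

def load_model(name, device_map=_UNSET):
    if device_map is _UNSET:
        return f"{name}: device_map omitted, library default applies"
    if device_map is None:
        return f"{name}: device_map=None passed explicitly, which may take a different code path"
    return f"{name}: device_map={device_map!r}"

print(load_model("demo"))                     # keyword omitted
print(load_model("demo", device_map=None))    # explicit None: a distinguishable call
print(load_model("demo", device_map="auto"))

Under that assumption, mutating the kwargs dict only when quantization_config is not None sidesteps the question of how the downstream loader treats an explicit None.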