
Commit d6f2a1a

Enabled repeat KV heads for AWQ models (quic#183)
* Enabled repeat KV heads for AWQ/GPTQ models
* undo location
* bugfix
* fixed CI bug, simplified replication script

Signed-off-by: Onkar Chougule <[email protected]>
1 parent ef49dbd commit d6f2a1a

File tree

5 files changed: +85 -19 lines changed

QEfficient/transformers/quantizers/auto.py

Lines changed: 28 additions & 0 deletions
@@ -6,12 +6,17 @@
 # ----------------------------------------------------------------------------

 from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING
+from transformers.quantizers.quantizer_awq import AwqQuantizer
+from transformers.quantizers.quantizer_gptq import GptqHfQuantizer
+from transformers.utils.quantization_config import AwqConfig, GPTQConfig

 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
 from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer

 QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer}
 QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig, "gptq": QEffGPTQConfig}
+DUPLICATE_AUTO_QUANTIZER_MAPPING = {"awq": AwqQuantizer, "gptq": GptqHfQuantizer}
+DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": AwqConfig, "gptq": GPTQConfig}


 def with_replaced_quantizers(func):

@@ -39,3 +44,26 @@ def wrapper(*args, **kwargs):
         return out

     return wrapper
+
+
+def replace_transformers_quantizers():
+    """
+    Lets you load AWQ/GPTQ models on CPU, bypassing the transformers rule that
+    these quantized models require a GPU. Just call this method before using
+    `transformers.AutoModelForCausalLM.from_pretrained`, and any AWQ/GPTQ model
+    supported by QEfficient will be loaded on CPU.
+    """
+    AUTO_QUANTIZER_MAPPING.update(QEFF_AUTO_QUANTIZER_MAPPING)
+    AUTO_QUANTIZATION_CONFIG_MAPPING.update(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING)
+
+
+# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers.
+# When a user imports QEfficient, these are always available.
+def undo_transformers_quantizers():
+    """
+    Undoes the effect of `replace_transformers_quantizers`. After this is called,
+    the stock transformers quantizers will again be used for loading AWQ/GPTQ models.
+    """
+    AUTO_QUANTIZER_MAPPING.update(DUPLICATE_AUTO_QUANTIZER_MAPPING)
+    AUTO_QUANTIZATION_CONFIG_MAPPING.update(DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING)
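A minimal usage sketch of the two helpers added above (the checkpoint name is a placeholder, not part of this commit): call the replace helper before loading and the undo helper afterwards.

from transformers import AutoModelForCausalLM

from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers

# Swap in QEfficient's quantizers so an AWQ/GPTQ checkpoint can be loaded on CPU.
replace_transformers_quantizers()
model = AutoModelForCausalLM.from_pretrained("<awq-or-gptq-checkpoint>")
# Restore the stock transformers quantizers once loading is done.
undo_transformers_quantizers()

This mirrors the pattern used in scripts/replicate_kv_head/replicate_kv_heads.py further down in this commit.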

QEfficient/transformers/quantizers/quantizer_gptq.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 # -----------------------------------------------------------------------------

 import torch
-from transformers.quantizers.quantizer_gptq import HfQuantizer
+from transformers.quantizers import HfQuantizer
 from transformers.utils.quantization_config import GPTQConfig

 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ

QEfficient/utils/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,10 @@
 #
 # -----------------------------------------------------------------------------

+from QEfficient.transformers.quantizers.auto import (  # noqa: F401
+    replace_transformers_quantizers,
+    undo_transformers_quantizers,
+)
 from QEfficient.utils._utils import (  # noqa: F401
     check_and_assign_cache_dir,
     get_num_layers_from_config,
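With this re-export, the same helpers can also be pulled from the package's utils namespace, a convenience alias for the module shown above:

from QEfficient.utils import replace_transformers_quantizers, undo_transformers_quantizers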

scripts/replicate_kv_head/replicate_kv_heads.py

Lines changed: 51 additions & 8 deletions
@@ -11,17 +11,59 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from QEfficient import QEFFAutoModelForCausalLM, export
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
+from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
+from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+
+
+def duplicate_weights_for_linear_layer(
+    layer: torch.nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+):
+    new_kv_heads = repeat * orig_kv_heads
+    if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
+        if head_dim % 8 != 0:
+            raise ValueError(
+                f"head_dim={head_dim} is not divisible by 8, which breaks the assumption that the model is 4-bit quantized."
+            )
+        if hidden_size % layer.group_size != 0:
+            raise ValueError(f"hidden_size={hidden_size} is not divisible by layer.group_size={layer.group_size}")
+
+        # Duplicate the packed quantized weights along the KV-head dimension.
+        layer.qweight.data = torch.repeat_interleave(
+            layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
+        ).view(hidden_size, (new_kv_heads * head_dim) // 8)
+        # Duplicate the quantized zero points.
+        layer.qzeros.data = torch.repeat_interleave(
+            layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
+        # Duplicate the quantization scales.
+        layer.scales.data = torch.repeat_interleave(
+            layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
+        layer.out_features = layer.out_features * repeat
+    else:
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)


 def main(args):
     # Load the model and tokenizer
     model_name = args.model_name
     model_base_name = model_name.split("/")[-1]
+    # Replace quantizers for loading quantized AWQ/GPTQ models on CPU.
+    replace_transformers_quantizers()
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,  # num_hidden_layers=2,
+        model_name,
+        num_hidden_layers=1,
         attn_implementation="eager",
     )
-
+    # Undo the effect of replace_transformers_quantizers.
+    undo_transformers_quantizers()
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     inputs = tokenizer(args.prompt, return_tensors="pt")

@@ -44,12 +86,8 @@ def main(args):
         attn = block.self_attn
         attn.num_key_value_heads = new_kv_heads
         attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads
-        attn.k_proj.weight.data = torch.repeat_interleave(
-            attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
-        ).view(new_kv_heads * attn.head_dim, attn.hidden_size)
-        attn.v_proj.weight.data = torch.repeat_interleave(
-            attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
-        ).view(new_kv_heads * attn.head_dim, attn.hidden_size)
+        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
+        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)

     # Generate modified outputs and tokens
     with torch.inference_mode():

@@ -60,6 +98,11 @@
     print("Original:", tokenizer.batch_decode(orig_tokens))
     print("Modified:", tokenizer.batch_decode(mod_tokens))

+    if not torch.all(orig_tokens == mod_tokens):
+        raise RuntimeError(
+            "Something went wrong while duplicating the KV head weights: output tokens don't match after modification"
+        )
+
     # Export the modified model
     q_model = QEFFAutoModelForCausalLM(model, model_name)
     export(
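To make the weight reshaping concrete, here is a small self-contained sketch (toy sizes, not taken from the commit) of the unquantized branch of duplicate_weights_for_linear_layer; the AWQ/GPTQ branch applies the same repeat to qweight, qzeros, and scales, with head_dim // 8 reflecting the eight 4-bit values packed per element along the output dimension in those layers.

import torch

orig_kv_heads, repeat, head_dim, hidden_size = 2, 2, 4, 8  # toy values
new_kv_heads = orig_kv_heads * repeat

k_proj = torch.nn.Linear(hidden_size, orig_kv_heads * head_dim, bias=False)

# View the (out_features, in_features) weight as (kv_heads, head_dim, hidden_size),
# repeat each KV head along dim 0, then flatten back to a larger out_features.
duplicated = torch.repeat_interleave(
    k_proj.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)

assert duplicated.shape == (new_kv_heads * head_dim, hidden_size)
# The first two new heads are identical copies of the original head 0.
assert torch.equal(duplicated[:head_dim], duplicated[head_dim : 2 * head_dim])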

tests/transformers/models/test_causal_lm_models.py

Lines changed: 1 addition & 10 deletions
@@ -8,12 +8,10 @@
 import numpy as np
 import pytest
 from transformers import AutoModelForCausalLM
-from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
-from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
-from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
 from QEfficient.utils import hf_download
 from QEfficient.utils._utils import load_hf_tokenizer
 from QEfficient.utils.constants import Constants

@@ -41,13 +39,6 @@
 ]


-# TODO: Make this a fixture? Or better, always update the quantizer and config in transformers.
-# When a user imports QEfficient, these are always available.
-def replace_transformers_quantizers():
-    AUTO_QUANTIZER_MAPPING.update({"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer})
-    AUTO_QUANTIZATION_CONFIG_MAPPING.update({"awq": QEffAwqConfig, "gptq": QEffGPTQConfig})
-
-
 def load_causal_lm_model(model_config):
     """
     Function to load model from huggingface and transform to KV model
