 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from QEfficient import QEFFAutoModelForCausalLM, export
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
+from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
+from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+
+
+def duplicate_weights_for_linear_layer(
+    layer: torch.nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+):
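+    """Duplicate each of the layer's `orig_kv_heads` KV heads `repeat` times.
+
+    Handles 4-bit AWQ/GPTQ packed layers as well as plain torch.nn.Linear weights.
+    """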
+    new_kv_heads = repeat * orig_kv_heads
+    if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
+        if head_dim % 8 != 0:
+            raise ValueError(
+                f"head_dim={head_dim} is not divisible by 8, which violates the "
+                "assumption that the model is 4-bit quantized."
+            )
+        if hidden_size % layer.group_size != 0:
+            raise ValueError(
+                f"hidden_size={hidden_size} is not divisible by the quantization "
+                f"group_size={layer.group_size}."
+            )
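+        # AWQ/GPTQ pack eight 4-bit values into each int32, so packed tensors
+        # are indexed in units of head_dim // 8 along the output dimension.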
+
+        # Duplication of quantized weights
+        layer.qweight.data = torch.repeat_interleave(
+            layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
+        ).view(hidden_size, (new_kv_heads * head_dim) // 8)
+        # Duplication of quantized zero points
+        layer.qzeros.data = torch.repeat_interleave(
+            layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
+        # Duplication of quantization scales
+        layer.scales.data = torch.repeat_interleave(
+            layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
+            repeat,
+            1,
+        ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
+        layer.out_features = layer.out_features * repeat
+    else:
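+        # Unquantized layer: weights are stored as (out_features, in_features),
+        # so repeat along the leading (KV-head) axis.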
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
 
 
 def main(args):
     # Load the model and tokenizer
     model_name = args.model_name
     model_base_name = model_name.split("/")[-1]
+    # Replace quantizers so that quantized AWQ/GPTQ models can be loaded on CPU.
+    replace_transformers_quantizers()
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,  # num_hidden_layers=2,
+        model_name,
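+        # Only one decoder layer is loaded, which keeps this example quick to run.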
+        num_hidden_layers=1,
         attn_implementation="eager",
     )
-
+    # Undo the effect of replace_transformers_quantizers
+    undo_transformers_quantizers()
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     inputs = tokenizer(args.prompt, return_tensors="pt")
 
@@ -44,12 +86,8 @@ def main(args):
         attn = block.self_attn
         attn.num_key_value_heads = new_kv_heads
         attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads
-        attn.k_proj.weight.data = torch.repeat_interleave(
-            attn.k_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
-        ).view(new_kv_heads * attn.head_dim, attn.hidden_size)
-        attn.v_proj.weight.data = torch.repeat_interleave(
-            attn.v_proj.weight.data.view(orig_kv_heads, attn.head_dim, attn.hidden_size), repeat, 0
-        ).view(new_kv_heads * attn.head_dim, attn.hidden_size)
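+        # Duplicate K and V projections in place; handles quantized and unquantized layers alike.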
+        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
+        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
 
     # Generate modified outputs and tokens
     with torch.inference_mode():
@@ -60,6 +98,11 @@ def main(args):
     print("Original:", tokenizer.batch_decode(orig_tokens))
     print("Modified:", tokenizer.batch_decode(mod_tokens))
 
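+    # Sanity check: duplicating KV heads is output-preserving, so generations must match.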
+    if not torch.all(orig_tokens == mod_tokens):
+        raise RuntimeError(
+            "Something went wrong while duplicating the KV head weights; "
+            "output tokens don't match after modification."
+        )
+
     # Export the modified model
     q_model = QEFFAutoModelForCausalLM(model, model_name)
     export(