@@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel):
191191 self .attn_bias = model .model_config .attn_bias
192192
193193 def _reorder_and_merge (self , qkvo ):
194- q , k , v , o = map ( transpose , qkvo )
194+ q , k , v , o = qkvo
195195 # reorder output dim for tm's rotary embedding layout
196196 if self .model .permute_qk :
197197 q = permute_v2 (q , self .head_dim )
@@ -202,13 +202,37 @@ def _reorder_and_merge(self, qkvo):
202202 o = torch .zeros_like (q )
203203 return qkv , o
204204
205+ def _repeat_kv (self , qkvo , kind : str ):
206+ """replicate kv."""
207+ q , k , v , o = qkvo
208+ head_dim = self .model .model_config .size_per_head
209+ hidden_dim = self .model .model_config .hidden_units
210+
211+ def _repeat (x ):
212+ dim = hidden_dim if kind != 'bias' else 1
213+ x = x .reshape (dim , - 1 , head_dim )
214+ x = x .repeat (1 , 1 , self .model .repeat_kv )
215+ x = x .reshape (dim , - 1 )
216+ return x
217+
218+ k , v = map (_repeat , (k , v ))
219+ if kind == 'bias' :
220+ if o is None :
221+ o = torch .zeros (hidden_dim , dtype = q .dtype , device = q .device )
222+ q , k , v , o = map (torch .squeeze , (q , k , v , o ))
223+
224+ return (q , k , v , o )
225+
205226 def _export (self , idx : int , qkvo , kind : str , pack_fn , ** kwargs ):
206227 if all (x is None for x in qkvo ):
207228 return
208229 is_lora_a , is_lora_b = get_lora_flags (kind )
209230 if is_lora_a :
210231 qkv , o = map (transpose , qkvo )
211232 else :
233+ qkvo = tuple (map (transpose , qkvo ))
234+ if self .model .repeat_kv :
235+ qkvo = self ._repeat_kv (qkvo , kind )
212236 qkv , o = self ._reorder_and_merge (qkvo )
213237 self .model .save_split (pack_fn (qkv ),
214238 self ._attn .format (idx , 'w_qkv' , kind ),
0 commit comments