
Commit e20999f

replicate kv for some models when tp is divisible by kv_head_num (#2874)
* replicate kv for some models when tp is divisible by kv_head_num
* export
* update
* update
1 parent 7deb69c commit e20999f

File tree: 2 files changed, +36 −1 lines


lmdeploy/turbomind/deploy/module.py

Lines changed: 25 additions & 1 deletion
@@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel):
         self.attn_bias = model.model_config.attn_bias
 
     def _reorder_and_merge(self, qkvo):
-        q, k, v, o = map(transpose, qkvo)
+        q, k, v, o = qkvo
         # reorder output dim for tm's rotary embedding layout
         if self.model.permute_qk:
             q = permute_v2(q, self.head_dim)
@@ -202,13 +202,37 @@ def _reorder_and_merge(self, qkvo):
             o = torch.zeros_like(q)
         return qkv, o
 
+    def _repeat_kv(self, qkvo, kind: str):
+        """replicate kv."""
+        q, k, v, o = qkvo
+        head_dim = self.model.model_config.size_per_head
+        hidden_dim = self.model.model_config.hidden_units
+
+        def _repeat(x):
+            dim = hidden_dim if kind != 'bias' else 1
+            x = x.reshape(dim, -1, head_dim)
+            x = x.repeat(1, 1, self.model.repeat_kv)
+            x = x.reshape(dim, -1)
+            return x
+
+        k, v = map(_repeat, (k, v))
+        if kind == 'bias':
+            if o is None:
+                o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
+            q, k, v, o = map(torch.squeeze, (q, k, v, o))
+
+        return (q, k, v, o)
+
     def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
         if all(x is None for x in qkvo):
             return
         is_lora_a, is_lora_b = get_lora_flags(kind)
         if is_lora_a:
             qkv, o = map(transpose, qkvo)
         else:
+            qkvo = tuple(map(transpose, qkvo))
+            if self.model.repeat_kv:
+                qkvo = self._repeat_kv(qkvo, kind)
             qkv, o = self._reorder_and_merge(qkvo)
         self.model.save_split(pack_fn(qkv),
                               self._attn.format(idx, 'w_qkv', kind),
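
To make the new `_repeat_kv` step concrete, the snippet below is a minimal standalone sketch of the column tiling it applies to the already-transposed k/v projection weights. The sizes are invented for illustration and `repeat_kv` stands in for `self.model.repeat_kv`; nothing here is code from the repo.

```python
import torch

# Illustrative sizes only (not taken from any particular model config)
hidden_dim = 16                # model_config.hidden_units
head_dim = 4                   # model_config.size_per_head
kv_head_num = 2
tp = 8                         # tensor_para_size, tp > kv_head_num and tp % kv_head_num == 0
repeat_kv = tp // kv_head_num  # 4 copies of each kv head

# transposed k projection weight: [hidden_dim, kv_head_num * head_dim]
k = torch.randn(hidden_dim, kv_head_num * head_dim)

x = k.reshape(hidden_dim, kv_head_num, head_dim)  # split the columns into kv heads
x = x.repeat(1, 1, repeat_kv)                     # tile each head's columns repeat_kv times
k_replicated = x.reshape(hidden_dim, -1)          # [hidden_dim, tp * head_dim]

assert k_replicated.shape == (hidden_dim, tp * head_dim)
# adjacent replicas of kv head 0 carry identical weights
assert torch.equal(k_replicated[:, :head_dim], k_replicated[:, head_dim:2 * head_dim])
```

The intent is that a later even column-wise split across the `tp` ranks leaves each rank with its own full copy of the kv head its query heads attend to, rather than failing the `kv_head_num % tensor_para_size == 0` check.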

lmdeploy/turbomind/deploy/target_model/base.py

Lines changed: 11 additions & 0 deletions
@@ -78,6 +78,17 @@ def __init__(self,
             self.model_config.expert_inter_size = _pad_inter_size(
                 self.model_config.expert_inter_size,
                 self.model_config.group_size, self.tensor_para_size)
+
+        # head_num is divisible by tp but kv_head_num is not
+        # and tp is divisible by kv_head_num
+        assert self.model_config.head_num % self.tensor_para_size == 0
+        self.repeat_kv = 0
+        if (self.tensor_para_size > self.model_config.kv_head_num and
+                self.tensor_para_size % self.model_config.kv_head_num == 0):
+            self.repeat_kv = (self.tensor_para_size //
+                              self.model_config.kv_head_num)
+            self.model_config.kv_head_num = self.tensor_para_size
+
         self.model_config.verify()
         assert self.model_config.kv_head_num % self.tensor_para_size == 0
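
Read on its own, the bookkeeping added here reduces to a simple rule. The helper below restates it as a standalone sketch; the function name and the example sizes are illustrative and do not come from the repo.

```python
def kv_replication_factor(head_num: int, kv_head_num: int, tp: int) -> int:
    """How many copies of each kv head are needed, or 0 when no replication applies (sketch)."""
    # query heads must still split evenly across the tensor-parallel ranks
    assert head_num % tp == 0
    if tp > kv_head_num and tp % kv_head_num == 0:
        # the target model then treats kv_head_num as equal to tp
        return tp // kv_head_num
    return 0


# 4 kv heads on 8 ranks: every kv head is stored twice
assert kv_replication_factor(head_num=32, kv_head_num=4, tp=8) == 2
# 4 kv heads on 4 ranks already split evenly: no replication
assert kv_replication_factor(head_num=32, kv_head_num=4, tp=4) == 0
```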
