Skip to content

Commit adef903

Browse files
committed
fix
1 parent aea352c commit adef903

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

lmdeploy/pytorch/models/deepseek_v32.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -256,11 +256,8 @@ def forward(
256 256       ):
257 257           """Rewrite of LlamaAttention.forward."""
258 258           dist_ctx = get_dist_manager().current_context()
259     -         if dist_ctx.dp > 1:
260     -             num_heads = self.num_heads
261     -         else:
262     -             world_size = dist_ctx.world_size
263     -             num_heads = self.num_heads // world_size
    259 +         tp_world_size = dist_ctx.dist_config.attn_tp
    260 +         num_heads = self.num_heads // tp_world_size
264 261           nope_size = self.kv_lora_rank
265 262           q_len = hidden_states.size(1)
266 263

lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,12 +201,12 @@ def __init__(self):
201 201       def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
202 202           """forward."""
203 203
204     -         assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
205 204           assert cos.dim() <= 3 and sin.dim() <= 3
206 205
207 206           need_reshape = False
208 207           if cos.dim() == 3:
209 208               # for fope
    209 +             assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
210 210               need_reshape = True
211 211               query_shape = query.shape
212 212               key_shape = key.shape

0 commit comments

Comments
 (0)