We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 49929cd commit 438a22fCopy full SHA for 438a22f
lmdeploy/pytorch/nn/rotary_embedding.py
@@ -261,7 +261,8 @@ def update_num_kv_heads(num_key_value_heads: int):
261
from lmdeploy.pytorch.distributed import get_dist_manager
262
dist_mgr = get_dist_manager()
263
dist_ctx = dist_mgr.current_context()
264
- tp = dist_ctx.dist_config.attn_config.tp
+ tp = dist_ctx.dist_config.attn_tp
265
+ # tp = dist_ctx.dist_config.attn_config.tp
266
if tp > 1:
267
num_key_value_heads = max(1, num_key_value_heads // tp)
268
return num_key_value_heads, tp
0 commit comments