Commit 302ef40

Authored by MengqingCao
[DSA][MLA] Tiny refactor on DeepSeek to make it reusable for different backends (vllm-project#26656)
Signed-off-by: MengqingCao <[email protected]>
1 parent 8865da1 commit 302ef40

File tree

3 files changed: +12 -3 lines changed

vllm/attention/layer.py

Lines changed: 2 additions & 0 deletions
@@ -587,6 +587,7 @@ def __init__(
         prefix: str = "",
         use_sparse: bool = False,
         indexer: object | None = None,
+        **extra_impl_args,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -639,6 +640,7 @@ def __init__(
             v_head_dim=self.v_head_dim,
             kv_b_proj=kv_b_proj,
             indexer=indexer,
+            **extra_impl_args,
         )

         self.use_direct_call = not current_platform.opaque_attention_op()
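With the new **extra_impl_args parameter, the shared MLA attention layer no longer has to enumerate every backend-specific option: any extra keyword arguments a caller passes are forwarded verbatim to the implementation the layer constructs. A minimal, self-contained sketch of that pattern, using hypothetical ToyMLALayer/ToyMLAImpl names and a made-up use_fused_kernel flag (none of these are vLLM API):

# Toy illustration of the **extra_impl_args pattern; the class names and the
# use_fused_kernel flag are hypothetical, not vLLM API.
class ToyMLAImpl:
    def __init__(self, num_heads, indexer=None, **backend_kwargs):
        # Backend-specific options land here without the shared layer
        # having to enumerate them.
        self.num_heads = num_heads
        self.indexer = indexer
        self.backend_kwargs = backend_kwargs


class ToyMLALayer:
    def __init__(self, num_heads, indexer=None, **extra_impl_args):
        # The layer consumes only what it understands and forwards the
        # rest verbatim to the backend implementation it constructs.
        self.impl = ToyMLAImpl(num_heads=num_heads, indexer=indexer, **extra_impl_args)


layer = ToyMLALayer(num_heads=16, use_fused_kernel=True)
print(layer.impl.backend_kwargs)  # -> {'use_fused_kernel': True}

The design choice is the usual one for plugin-style backends: the common layer's signature stays stable while each implementation declares only the keywords it actually understands.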

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 8 additions & 2 deletions
@@ -17,9 +17,13 @@
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    get_spec_layer_idx_from_weight_name,
+)
 from .interfaces import SupportsPP
 from .utils import maybe_prefix

@@ -56,14 +60,16 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

+        self.device = current_platform.device_type
+
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
             topk_tokens = config.index_topk
             topk_indices_buffer = torch.empty(
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None

vllm/model_executor/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
@@ -1165,6 +1165,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.device = current_platform.device_type

         self.vocab_size = config.vocab_size
         self.is_v32 = hasattr(config, "index_topk")
@@ -1174,7 +1175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
 )
         else:
             topk_indices_buffer = None
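Both model files make the same substitution: the top-k indices buffer is allocated on whatever device the active platform reports instead of unconditionally on "cuda". A small sketch of the effect, assuming only what the diff shows, namely that current_platform.device_type resolves to a device string (e.g. "cuda" on NVIDIA GPUs, other values on other backends); the sizes below are placeholders, not vLLM defaults:

# Sketch of the backend-agnostic buffer allocation; the sizes below are
# placeholders, not vLLM defaults.
import torch

from vllm.platforms import current_platform

max_num_batched_tokens = 8192  # stands in for scheduler_config.max_num_batched_tokens
topk_tokens = 2048             # stands in for config.index_topk

device = current_platform.device_type  # e.g. "cuda"; depends on the active backend
topk_indices_buffer = torch.empty(
    max_num_batched_tokens,
    topk_tokens,
    dtype=torch.int32,
    device=device,
)
print(topk_indices_buffer.device)

On a platform without CUDA, the previous hardcoded device="cuda" would fail at model construction time; resolving the device through current_platform keeps this code path usable by other backends, which is the point of the refactor.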
