
Commit 14049e7

[Feat] shared expert dp for deepseek_mtp and e2e
Signed-off-by: chenmenglong <[email protected]>
1 parent 3ac76fd commit 14049e7

8 files changed (+154, -9 lines changed)
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+import os
+
+import pytest
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal
+
+MODELS = [
+    "vllm-ascend/DeepSeek-V2-Lite",
+]
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_models_with_enable_shared_expert_dp(model: str) -> None:
+
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+
+    prompts = [
+        "Hello, my name is", "The capital of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+            tensor_parallel_size=2,
+            enable_expert_parallel=True,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+            tensor_parallel_size=2,
+            enable_expert_parallel=True,
+            additional_config={
+                "enable_shared_expert_dp": True,
+            },
+    ) as runner:
+        shared_expert_dp_eager_outputs = runner.model.generate(
+            prompts, sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+            compilation_config={
+                "cudagraph_capture_sizes": [1, 4, 8, 16],
+                "cudagraph_mode": "FULL_DECODE_ONLY",
+            },
+            additional_config={
+                "enable_shared_expert_dp": True,
+            },
+    ) as runner:
+        shared_expert_dp_aclgraph_outputs = runner.model.generate(
+            prompts, sampling_params)
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    shared_expert_dp_eager_outputs_list = []
+    for output in shared_expert_dp_eager_outputs:
+        shared_expert_dp_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    shared_expert_dp_aclgraph_outputs_list = []
+    for output in shared_expert_dp_aclgraph_outputs:
+        shared_expert_dp_aclgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=shared_expert_dp_eager_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="shared_expert_dp_eager_outputs",
+    )
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=shared_expert_dp_aclgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="shared_expert_dp_aclgraph_outputs",
+    )

tests/ut/ops/test_layernorm.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import patch

 import pytest
 import torch
@@ -41,7 +42,9 @@ def context(self, mocker: MockerFixture):
     # Test case for the most common and basic scenario
    @pytest.mark.parametrize(
        "residual", [None, torch.randn(4, 8, dtype=torch.float16)])
-    def test_forward_oot_basic(self, residual):
+    @patch("torch.ops.vllm.maybe_chunk_residual")
+    def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual):
+        mock_maybe_chunk_residual.side_effect = lambda x, residual: residual
         layer = RMSNorm(hidden_size=8, eps=1e-05)
         x = torch.randn(4, 8, dtype=torch.float16)
         if residual is not None:
@@ -105,6 +108,8 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
         mock_forward_context.num_hidden_layers = num_hidden_layers
         mock_forward_context.fusion_linear = "gate_up_dense"
         mock_forward_context.weight_prefetch_method = None
+        mocker.patch("torch.ops.vllm.maybe_chunk_residual",
+                     lambda x, residual: residual)

         # Ensure fusion and layer_idx increment are handled correctly
         x = torch.randn(4, 8, dtype=torch.float16)

vllm_ascend/ascend_config.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ def __init__(self, vllm_config):
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
+        if self.enable_shared_expert_dp:
+            from vllm_ascend.utils import enable_sp
+            assert enable_sp(
+                vllm_config), "shared_expert_dp requires enable_sp=True."
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
         self.recompute_scheduler_enable = additional_config.get(
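
Note: with this check, `enable_shared_expert_dp` only takes effect when expert parallelism is on and `enable_sp(vllm_config)` holds. Below is a minimal offline-inference sketch of a configuration that satisfies those constraints, mirroring the e2e test above; it assumes, as that test does, that setting `VLLM_ASCEND_ENABLE_FLASHCOMM1=1` is what makes `enable_sp` return True.

# Sketch only: user-side configuration for shared expert DP, based on the
# combination exercised by the e2e test above.
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"  # assumed to turn enable_sp on

llm = LLM(
    model="vllm-ascend/DeepSeek-V2-Lite",
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    additional_config={"enable_shared_expert_dp": True},
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)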

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1456,6 +1456,8 @@ def forward(
         forward_context = get_forward_context()
         if (self.enable_mlapo and
                 (attn_metadata is None or not forward_context.with_prefill)):
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
                 hidden_states, kv_cache, attn_metadata)
         else:

vllm_ascend/ops/layernorm.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ def forward_oot(
         import torch_npu

         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, residual = _addrmsnorm_forward_oot(
                 self, x, residual, self.next_need_quant_fusion_linear,

vllm_ascend/ops/register_custom_ops.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 import torch_npu
 from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
@@ -20,6 +21,27 @@
 from vllm.utils.torch_utils import direct_register_custom_op


+def _maybe_chunk_residual_impl(x: torch.Tensor,
+                               residual: torch.Tensor) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        return residual
+
+    if x.size(0) != residual.size(0):
+        sp_enabled = forward_context.sp_enabled
+        assert sp_enabled is True, ("Currently, this situation only occurs "
+                                    "when sp is enabled")
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            residual = F.pad(residual, (0, 0, 0, pad_size))
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
+
+    return residual
+
+
 def _maybe_all_gather_and_maybe_unpad_impl(
     x: torch.Tensor,
     label: bool,
@@ -264,6 +286,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
     return output


+direct_register_custom_op(op_name="maybe_chunk_residual",
+                          op_func=_maybe_chunk_residual_impl,
+                          fake_impl=lambda x, residual: x,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
                           fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
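
For reference, `_maybe_chunk_residual_impl` only reshapes: when sequence parallelism has already padded and split `x` across TP ranks, it pads the full-length residual to the same number of rows and keeps the local chunk so both tensors line up again. A standalone sketch of that pad-and-chunk arithmetic with made-up sizes (no forward context or distributed state involved):

# Illustration only: how a length-10 residual is padded and chunked for tp_size=4.
import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1          # hypothetical parallel layout
residual = torch.randn(10, 8)    # [num_tokens, hidden_size]

pad_size = (-residual.size(0)) % tp_size          # 2 rows of padding -> 12 rows
residual = F.pad(residual, (0, 0, 0, pad_size))   # pad the token dim only
local = torch.chunk(residual, tp_size, dim=0)[tp_rank]

print(local.shape)  # torch.Size([3, 8]) -- matches the SP-split hidden states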

vllm_ascend/platform.py

Lines changed: 1 addition & 3 deletions
@@ -364,7 +364,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config and parallel_config.worker_cls == "auto":
             # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
             os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
-            if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
+            if ascend_config.torchair_graph_config.enabled:
                 parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -435,8 +435,6 @@ def get_attn_backend_cls(
         ascend_config = get_ascend_config()

         if use_mla and ascend_config.enable_shared_expert_dp:
-            if use_mla and not use_sparse:
-                return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
             if use_mla and use_sparse:
                 return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 19 additions & 5 deletions
@@ -211,6 +211,11 @@ def dummy_run(self,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor):
+            positions = positions.unsqueeze(-1)
+            positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+            positions = positions.squeeze(-1)
+            previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                previous_hidden_states)
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
@@ -528,11 +533,20 @@ def _propose(
         with ProfileExecuteDuration().capture_async('mtp_forward'):
             model_kwargs = {}
             model_kwargs["attn_metadata"] = attn_metadata
-
-            hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens])
+            input_ids = self.input_ids[:num_input_tokens]
+            positions = self.positions[:num_input_tokens]
+            hidden_states = self.hidden_states[:num_input_tokens]
+
+            # positions [N] -> [N, 1] for padding
+            positions = positions.unsqueeze(-1)
+            positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+            positions = positions.squeeze(-1)
+
+            hidden_states = self.model(input_ids=input_ids,
+                                       positions=positions,
+                                       hidden_states=hidden_states)
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states.contiguous(), True)

             num_indices = last_token_indices.shape[0]
             if lmhead_tp_enable():
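
A note on the `unsqueeze`/`squeeze` dance around `maybe_pad_and_reduce`: the op pads and splits along the token dimension and is otherwise applied to 2-D [num_tokens, hidden_size] activations, so the 1-D positions vector temporarily gains a trailing dimension before the call and drops it afterwards (matching the `# positions [N] -> [N, 1] for padding` comment). The sketch below only mimics that shape behaviour with a hypothetical stand-in; the real custom op also performs the collective communication.

# Sketch only: why positions go [N] -> [N, 1] -> [N', 1] -> [N'].
# fake_pad_and_reduce is a hypothetical stand-in for
# torch.ops.vllm.maybe_pad_and_reduce that only reproduces the shape change.
import torch
import torch.nn.functional as F


def fake_pad_and_reduce(t: torch.Tensor, tp_size: int = 2, tp_rank: int = 0):
    pad = (-t.size(0)) % tp_size
    t = F.pad(t, (0, 0, 0, pad))                 # pad the token dim
    return torch.chunk(t, tp_size, dim=0)[tp_rank]


positions = torch.arange(5)                      # [N] = [5]
positions = positions.unsqueeze(-1)              # [5, 1]
positions = fake_pad_and_reduce(positions)       # [3, 1] on rank 0
positions = positions.squeeze(-1)                # back to 1-D for the model
print(positions.shape)  # torch.Size([3])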
