
Commit cb59a6e

zengzengrancml authored and committed
Repair shared expert dp
Signed-off-by: zengran <[email protected]>
1 parent 91e2820 commit cb59a6e

3 files changed: +29 -7 lines changed

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 2 deletions

@@ -70,8 +70,8 @@ def __init__(self, vllm_config):
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         if self.enable_shared_expert_dp:
             from vllm_ascend.utils import enable_sp
-            assert enable_sp(
-                vllm_config), "shared_expert_dp requires enable_sp=True."
+            assert enable_sp(vllm_config=vllm_config,
+                             enable_shared_expert_dp=True)
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
         self.recompute_scheduler_enable = additional_config.get(
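For context, the old call asserted that sequence parallelism was already enabled and failed otherwise; the new keyword form lets enable_sp turn sequence parallelism on itself when shared expert DP is requested (see the vllm_ascend/utils.py hunk below). A minimal sketch of the two call shapes, with enable_sp stubbed out since the real function inspects the global config:

def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
    # Stubbed stand-in, only to contrast the two call shapes; the real
    # function lives in vllm_ascend/utils.py and inspects vllm_config/env vars.
    sp_configured = False  # pretend sequence parallelism was not configured
    return sp_configured or enable_shared_expert_dp

vllm_config = object()  # placeholder for the real VllmConfig

# Old call: failed unless SP had already been enabled elsewhere:
#   assert enable_sp(vllm_config), "shared_expert_dp requires enable_sp=True."
# New call: enable_sp itself turns SP on when shared expert DP is requested.
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)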

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 19 additions & 4 deletions

@@ -37,6 +37,8 @@
 
 if prefill_context_parallel_enable():
     from vllm.distributed import get_pcp_group
+if shared_expert_dp_enabled():
+    from vllm.distributed import get_tensor_model_parallel_world_size
 
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import set_default_torch_dtype
@@ -298,6 +300,10 @@ def dummy_run(self,
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
+            if self.enable_shared_expert_dp:
+                positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions, True)
+                previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    previous_hidden_states, True)
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                     not forward_context.capturing:
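The dummy run now re-gathers positions and the previous hidden states across the TP group when shared expert DP is on, so graph capture sees the same post-gather shapes as the real path. A single-process shape sketch, where toy_all_gather is a stand-in (not the real torch.ops.vllm op, which also unpads):

import torch

# Single-process stand-in for torch.ops.vllm.maybe_all_gather_and_maybe_unpad:
# restore the full token dimension from a per-rank shard so the dummy run
# exercises the same shapes the real shared-expert-dp path will use.
def toy_all_gather(x: torch.Tensor, tp_world_size: int) -> torch.Tensor:
    return torch.cat([x] * tp_world_size, dim=0)

tp_world_size = 2
positions = torch.arange(4)                 # per-rank shard of positions
previous_hidden_states = torch.zeros(4, 8)  # per-rank shard, hidden size 8
print(toy_all_gather(positions, tp_world_size).shape)               # torch.Size([8])
print(toy_all_gather(previous_hidden_states, tp_world_size).shape)  # torch.Size([8, 8])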
@@ -690,6 +696,12 @@ def _propose(
                 (self.num_speculative_tokens + 1))
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=uniform_decode)
+            # Enabling sp/shared_expert_dp will perform a reduce_scatter operation.
+            if self.enable_shared_expert_dp:
+                tp_world_size = get_tensor_model_parallel_world_size()
+                reduce_num_input_tokens = num_input_tokens // tp_world_size
+                batch_descriptor = BatchDescriptor(num_tokens=reduce_num_input_tokens,
+                                                   uniform_decode=uniform_decode)
         else:
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=False)
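The point of the new branch: with sequence parallelism the reduce_scatter splits the token dimension across TP ranks, so the batch descriptor must be keyed on the per-rank token count rather than the global one. A self-contained sketch of the arithmetic (names mirror the diff; values are illustrative):

# Illustrative arithmetic only; real values come from the scheduler and TP group.
num_input_tokens = 256  # padded global token count for this step
tp_world_size = 4       # get_tensor_model_parallel_world_size() in the real code

# After the reduce_scatter performed under sp/shared_expert_dp, each TP rank
# holds only its shard of the tokens, so the graph batch descriptor is keyed
# on the smaller per-rank count.
reduce_num_input_tokens = num_input_tokens // tp_world_size
assert reduce_num_input_tokens == 64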
@@ -741,12 +753,15 @@ def _propose(
                 positions = torch.ops.vllm.maybe_pad_and_reduce(
                     positions)
                 positions = positions.squeeze(-1)
+                hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                    hidden_states)
 
             hidden_states = self.model(input_ids=input_ids,
                                        positions=positions,
                                        hidden_states=hidden_states)
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), True)
+
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
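hidden_states now takes the same pad-and-reduce path as positions before the draft model call and is all-gathered (and unpadded) again afterwards. A toy, single-process sketch of that round trip, assuming the custom ops split and restore the first (token) dimension; the helpers below are stand-ins, not the real torch.ops.vllm kernels, and they only show the shape bookkeeping (the real reduce also sums partial results across TP ranks):

import torch

def toy_pad_and_scatter(x: torch.Tensor, tp_world_size: int, rank: int) -> torch.Tensor:
    # Pad the token dim to a multiple of tp_world_size, then keep this rank's shard.
    num_tokens = x.shape[0]
    padded = (num_tokens + tp_world_size - 1) // tp_world_size * tp_world_size
    x = torch.nn.functional.pad(x, (0, 0, 0, padded - num_tokens))
    return x.chunk(tp_world_size, dim=0)[rank]

def toy_all_gather_and_unpad(shards, num_tokens: int) -> torch.Tensor:
    return torch.cat(shards, dim=0)[:num_tokens]

hidden_states = torch.randn(10, 16)  # 10 tokens, hidden size 16
shards = [toy_pad_and_scatter(hidden_states, 4, r) for r in range(4)]
print(shards[0].shape)                              # torch.Size([3, 16]) per-rank shard
print(toy_all_gather_and_unpad(shards, 10).shape)   # torch.Size([10, 16])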
@@ -821,20 +836,20 @@ def _propose(
                     batch_size,
                     attn_metadata_i.decode.actual_seq_lengths_q)
                 attn_metadata_i.decode.cos = builder.cos_cache[
-                    positions].unsqueeze(1).unsqueeze(2)
+                    positions[:batch_size]].unsqueeze(1).unsqueeze(2)
                 attn_metadata_i.decode.sin = builder.sin_cache[
-                    positions].unsqueeze(1).unsqueeze(2)
+                    positions[:batch_size]].unsqueeze(1).unsqueeze(2)
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
             # but adjust the position ids and slot mappings to avoid the
             # out-of-range access during the model execution. The draft tokens
             # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
+            exceeds_max_model_len = positions[:batch_size] >= self.runner.model_config.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
             clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+                                            positions[:batch_size])
             # Increment the sequence lengths.
             attn_metadata_i.seq_lens[:batch_size] += 1
             # For the requests that exceed the max model length, we set the
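Because positions is now padded and gathered along the token dimension, it can be longer than the decode batch, so the metadata builder indexes only the first batch_size entries when building the RoPE caches and the max-model-len mask. A tiny sketch of that guard with a toy max_model_len:

import torch

max_model_len = 8  # toy value; the real one is self.runner.model_config.max_model_len
batch_size = 3
# After the all-gather, positions can carry padding beyond batch_size.
positions = torch.tensor([5, 7, 9, 0, 0, 0])

exceeds_max_model_len = positions[:batch_size] >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size])
print(exceeds_max_model_len)  # tensor([False, False,  True])
print(clamped_positions)      # tensor([5, 7, 0])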

vllm_ascend/utils.py

Lines changed: 8 additions & 1 deletion

@@ -751,7 +751,8 @@ def dense_optim_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
 
 
-def enable_sp(vllm_config=None) -> bool:
+def enable_sp(vllm_config=None,
+              enable_shared_expert_dp: bool = False) -> bool:
     global _ENABLE_SP
     if _ENABLE_SP is None:
         if vllm_config is None:

@@ -765,6 +766,12 @@ def enable_sp(vllm_config=None) -> bool:
             # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
             or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
 
+        if not _ENABLE_SP and enable_shared_expert_dp:
+            _ENABLE_SP = True
+            logger.info(
+                "shared_expert_dp requires enable_sp=True; enable_sp has been set to True."
+            )
+
     if not _ENABLE_SP:
         return _ENABLE_SP
 
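The net effect of the utils.py change: when shared expert DP is requested, enable_sp no longer just reports whether sequence parallelism is configured, it switches it on and logs that it did so. A self-contained toy of that precedence logic (the globals and flags below are stand-ins, not the real module state):

import logging

logger = logging.getLogger(__name__)

_ENABLE_SP = None  # stand-in for the module-level cache in vllm_ascend/utils.py

def toy_enable_sp(sp_configured: bool, enable_shared_expert_dp: bool = False) -> bool:
    """Toy version: sp_configured stands in for the ascend-config/env checks."""
    global _ENABLE_SP
    if _ENABLE_SP is None:
        _ENABLE_SP = sp_configured
        if not _ENABLE_SP and enable_shared_expert_dp:
            # shared_expert_dp needs sequence parallelism, so force it on.
            _ENABLE_SP = True
            logger.info("shared_expert_dp requires enable_sp=True; "
                        "enable_sp has been set to True.")
    return _ENABLE_SP

assert toy_enable_sp(sp_configured=False, enable_shared_expert_dp=True) is True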
