Skip to content

Commit 9d2d561

Browse files
zhyajie
and a co-author authored
[Bugfix] Fix precision corruption when shared_experts_stream=None (vllm-project#28942)
Signed-off-by: zhyajie <[email protected]> Co-authored-by: zhyajie <[email protected]>
1 parent fe69f33 commit 9d2d561

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,8 @@ def __init__(
371371
logger.info_once("Disabling MoE shared_experts cuda stream")
372372
self.shared_experts_stream = None
373373
else:
374-
# TODO(rob): enable shared expert overlap with non-cuda.
375-
# aux_stream() returns None on non-cuda platforms.
374+
# TODO(rob): enable shared expert overlap with non-cuda-alike.
375+
# aux_stream() returns None on non-cuda-alike platforms.
376376
self.shared_experts_stream = aux_stream()
377377
if self.shared_experts_stream is not None:
378378
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
@@ -1865,6 +1865,11 @@ def forward_impl(
18651865
hidden_states_combined, router_logits = get_ep_group().dispatch(
18661866
hidden_states, router_logits, self.is_sequence_parallel
18671867
)
1868+
# Run shared experts before the matrix multiply,
1869+
# because the matrix multiply may modify the hidden_states.
1870+
if has_separate_shared_experts and not use_shared_experts_stream:
1871+
assert self.shared_experts is not None
1872+
shared_output = self.shared_experts(hidden_states)
18681873

18691874
# Matrix multiply.
18701875
final_hidden_states = self.quant_method.apply(
@@ -1908,8 +1913,6 @@ def forward_impl(
19081913
# conflict with the main stream
19091914
shared_output = self.shared_experts(hidden_states_clone)
19101915
current_stream().wait_stream(self.shared_experts_stream)
1911-
else:
1912-
shared_output = self.shared_experts(hidden_states)
19131916

19141917
final_hidden_states = (
19151918
shared_output,

vllm/utils/torch_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:
426426

427427
from vllm.platforms import current_platform
428428

429-
# TODO: validate this works properly on ROCm platform.
430-
if _aux_stream is None and current_platform.is_cuda():
429+
if _aux_stream is None and current_platform.is_cuda_alike():
431430
_aux_stream = torch.cuda.Stream()
432431

433432
return _aux_stream

0 commit comments

Comments
 (0)