Merged
Commits
36 commits
6201bfd
implement async scheduling for mtp
Ronald1995 Nov 11, 2025
ba323da
fix synchronize error
Ronald1995 Nov 27, 2025
01ae70c
fix indent error
Ronald1995 Nov 27, 2025
add7852
fix synchronize
Ronald1995 Nov 27, 2025
03755ef
fix synchronize error of repeat_interleave
Ronald1995 Nov 27, 2025
16eb688
fix synchronize error
Ronald1995 Nov 27, 2025
27bc0f9
fix synchronize error in _calc_spec_decode_metadata
Ronald1995 Nov 27, 2025
6d70c76
delete v2
Ronald1995 Nov 27, 2025
935e0d7
fix sync error of seq_lens tolist
Ronald1995 Nov 28, 2025
33c1c56
set pin_memory=True
Ronald1995 Nov 28, 2025
430d371
disable mtp graph when using async scheduling
Ronald1995 Nov 28, 2025
c0317c9
fix pin_memory error
Ronald1995 Nov 28, 2025
d9a1b9c
fix yapf error
Ronald1995 Nov 28, 2025
fb77399
revert rejection_sampler
Ronald1995 Nov 28, 2025
18cd2f8
fix yapf error
Ronald1995 Nov 28, 2025
d989c43
fix ut pin_memory error
Ronald1995 Nov 29, 2025
ecc7c64
handle kv cache
Ronald1995 Nov 29, 2025
161c500
fix hang
Ronald1995 Nov 29, 2025
d1e2d13
fix prev_sampled_token_ids wrong position
Ronald1995 Nov 30, 2025
8a6d9b6
fix yapf error
Ronald1995 Dec 1, 2025
1d94556
_sync_metadata_across_dp with cpu group
Ronald1995 Dec 2, 2025
a0da596
add e2e test for async_scheduling
Ronald1995 Dec 3, 2025
8461e6c
merge main
Ronald1995 Dec 3, 2025
682ab95
fix assert error of sampled_token_ids shape
Ronald1995 Dec 3, 2025
a59e660
fix RejectionSampler.parse_output
Ronald1995 Dec 3, 2025
9dc6e45
fix ut of test_async_scheduling
Ronald1995 Dec 4, 2025
886dad8
fix yapf error
Ronald1995 Dec 4, 2025
01327de
implement out-of-place increment of seq_lens_cpu
Ronald1995 Dec 4, 2025
b811e45
fix ruff error
Ronald1995 Dec 5, 2025
b381f37
make seq_lens increment out of place
Ronald1995 Dec 5, 2025
b7fbfc7
Merge branch 'main' into async_mtp3
Ronald1995 Dec 5, 2025
c58a615
Merge branch 'main' into async_mtp3
Ronald1995 Dec 5, 2025
0d030ab
Merge branch 'main' into async_mtp3
Ronald1995 Dec 5, 2025
0c009dc
Merge branch 'main' into async_mtp3
Ronald1995 Dec 5, 2025
dea5213
Merge branch 'main' into async_mtp3
wangxiyuan Dec 6, 2025
7eb30c7
Merge branch 'main' into async_mtp3
wangxiyuan Dec 6, 2025
213 changes: 213 additions & 0 deletions tests/e2e/singlecard/test_async_scheduling.py
@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat
from typing import Any

import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.v1.metrics.reader import Metric

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"

first_prompt = ("The following numbers of the sequence " +
", ".join(str(i) for i in range(10)) + " are:")
example_prompts = [first_prompt, "In one word, the capital of France is "
] + [f"Tell me about the number {i}: " for i in range(32)]

default_params = dict(
temperature=0.0, # greedy
max_tokens=23,
min_tokens=18,
)


def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, prefill chunking."""
test_sampling_params: list[dict[str, Any]] = [
dict(),
]

# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
(False, "mp", False, None, False),
(False, "mp", True, None, False),
(False, "uni", True, None, False),
]

run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


@dynamo_config.patch(cache_size_limit=16)
def run_tests(
monkeypatch: pytest.MonkeyPatch,
model: str,
test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]],
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""

with monkeypatch.context():
# avoid precision errors
outputs: list[tuple[str, list, list]] = []
for n, (
test_preemption,
executor,
async_scheduling,
spec_config,
test_prefill_chunking,
) in enumerate(test_configs, 1):
test_str = f"{n}/{len(test_configs)}"
test_results = run_test(
model,
test_str,
test_sampling_params,
test_preemption,
executor,
async_scheduling,
spec_config,
test_prefill_chunking=test_prefill_chunking,
)
outputs.append(test_results)

baseline_config, baseline_tests, _ = outputs[0]
_, _, baseline_acceptances = next((o for o in outputs if o[2] is not None),
(None, None, None))

print(
f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}"
)

failure = None
for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
for base_outs, base_acceptance_rate, test_outs, test_acceptance_rate, params in zip(
baseline_tests,
baseline_acceptances or repeat(None),
test_outputs,
test_acceptance_rates or repeat(None),
test_sampling_params,
):
try:
check_outputs_equal(
outputs_0_lst=base_outs,
outputs_1_lst=test_outs,
name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}",
)

if (base_acceptance_rate is not None
and test_acceptance_rate is not None):
if "spec_mml=None" in test_config:
assert (test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate == pytest.approx(
base_acceptance_rate, rel=5e-2))
else:
# Currently the reported acceptance rate is expected to be
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.1
print(f"PASSED: config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}")
except AssertionError as e:
print(f"FAILED: config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}")
if failure is None:
failure = e

if failure is not None:
raise failure


def run_test(
model: str,
test_str: str,
sampling_param_tests: list[dict[str, Any]],
test_preemption: bool,
executor: str,
async_scheduling: bool,
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
):
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=2) if test_preemption else dict(
gpu_memory_utilization=0.9))
spec_mml = (spec_config or {}).get("max_model_len")
test_config = (f"executor={executor}, preemption={test_preemption}, "
f"async_sched={async_scheduling}, "
f"chunk_prefill={test_prefill_chunking}, "
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}")
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80)
with VllmRunner(
model,
max_model_len=512,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
enforce_eager=True,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
dtype="float16", # avoid precision errors
speculative_config=spec_config,
disable_log_stats=False,
**cache_arg,
) as vllm_model:
results = []
acceptance_rates: list[float] | None = [] if spec_decoding else None
for override_params in sampling_param_tests:
metrics_before = vllm_model.model.get_metrics()
print(f"----------- RUNNING PARAMS: {override_params}")
results.append(
vllm_model.generate(
example_prompts,
sampling_params=SamplingParams(**default_params,
**override_params),
))
metrics_after = vllm_model.model.get_metrics()
if acceptance_rates is not None:
acceptance_rate = _get_acceptance_rate(metrics_before,
metrics_after)
acceptance_rates.append(acceptance_rate)
print(f"ACCEPTANCE RATE {acceptance_rate}")

if test_preemption:
preemptions = _get_count(metrics_before, metrics_after,
"vllm:num_preemptions")
assert preemptions > 0, "preemption test had no preemptions"

if len(results) > 1:
# First check that the different parameter configs
# actually result in different output.
for other_test_outs, params in zip(results[1:],
sampling_param_tests[1:]):
with pytest.raises(AssertionError):
check_outputs_equal(
outputs_0_lst=results[0][0],
outputs_1_lst=other_test_outs,
name_0=f"baseline params={params}",
name_1=f"other params={params}",
)

return test_config, results, acceptance_rates


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
return accept / draft if draft > 0 else 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
before_val = next(m.value for m in before if m.name == name)
after_val = next(m.value for m in after if m.name == name)
return after_val - before_val
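
A minimal way to exercise the new end-to-end suite locally might be the usual pytest invocation below (illustrative only; Ascend device and driver setup are outside the scope of this diff):

    pytest -s tests/e2e/singlecard/test_async_scheduling.py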
1 change: 1 addition & 0 deletions tests/ut/attention/test_attention_v1.py
@@ -87,6 +87,7 @@ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
self.mock_device = 'cpu:0'
torch.Tensor.pin_memory = lambda x: x # noqa
self.builder = AscendAttentionMetadataBuilder(None, None,
self.mock_vllm_config,
self.mock_device)
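
The torch.Tensor.pin_memory = lambda x: x stub added above (and repeated in the test_mla_v1.py hunks below) is needed because the metadata builders now call .pin_memory() before their non-blocking host-to-device copies, and page-locked allocation is unavailable on the CPU-only hosts that run these unit tests. A minimal sketch of the same idea with a scoped pytest monkeypatch — the fixture and test names are hypothetical and not part of this PR:

import pytest
import torch


@pytest.fixture
def no_pin_memory(monkeypatch: pytest.MonkeyPatch):
    # Replace pinned-memory allocation with a no-op so builder code that
    # calls .pin_memory() still runs on CPU-only CI hosts.
    monkeypatch.setattr(torch.Tensor, "pin_memory", lambda self: self)


def test_tensor_survives_without_pinning(no_pin_memory):
    t = torch.arange(4, dtype=torch.int32)
    assert t.pin_memory() is t  # the identity stub hands the tensor back unchanged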
8 changes: 8 additions & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -299,6 +299,7 @@ def test_ascend_mla_metadata_builder_build_full_graph(
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
torch.Tensor.pin_memory = lambda x: x # noqa

mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
@@ -534,6 +535,7 @@ def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -599,6 +601,7 @@ def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -660,6 +663,8 @@ def test_build_decode_only_metadata(self, mock_get_ascend_config,
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -713,6 +718,8 @@ def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -767,6 +774,7 @@ def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
4 changes: 2 additions & 2 deletions vllm_ascend/attention/attention_v1.py
@@ -317,8 +317,8 @@ def build(
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])

query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)

attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
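
The change above motivates the rest of this PR's pin_memory additions: a non_blocking host-to-device copy generally only overlaps with other work when the source CPU tensor is page-locked; from ordinary pageable memory the flag is effectively ignored and the copy can synchronize, which defeats async scheduling. A minimal sketch of the pattern (general PyTorch behaviour, illustrated with CUDA here; the PR itself targets Ascend NPU devices):

import torch


def async_h2d(cpu_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Stage the tensor in pinned (page-locked) host memory so the copy engine
    # can stream it directly, then enqueue the transfer without blocking the host.
    return cpu_tensor.pin_memory().to(device, non_blocking=True)


if torch.cuda.is_available():
    query_start_loc_cpu = torch.tensor([0, 4, 9, 16], dtype=torch.int32)
    query_start_loc = async_h2d(query_start_loc_cpu, torch.device("cuda"))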
51 changes: 30 additions & 21 deletions vllm_ascend/attention/mla_v1.py
@@ -556,35 +556,43 @@ def build(
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=local_chunk_starts.to(device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=local_chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(
dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
device, non_blocking=True
),
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
.tolist(),
local_context_lens_allranks=local_context_lens_allranks
.tolist(),
padded_local_cu_seq_lens=
padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = \
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(
dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
))
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
@@ -616,7 +624,8 @@ def build(
cos = common_attn_metadata.cos
sin = common_attn_metadata.sin
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
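
The final hunk replaces a device-tensor read (query_start_loc[1:num_decodes + 1].tolist()) with the CPU copy the builder already keeps: .tolist() on a device tensor blocks the host until the values have been copied back, while reading query_start_loc_cpu costs nothing, which is what keeps the decode path free of hidden synchronizations under async scheduling. A minimal sketch of the CPU-side read (assumed behaviour, written against plain PyTorch; the PR operates on NPU tensors):

import torch


def seq_lengths_from_cpu(query_start_loc_cpu: torch.Tensor,
                         num_decodes: int) -> list[int]:
    # Slicing and .tolist() on the host-side copy require no device sync.
    return query_start_loc_cpu[1:num_decodes + 1].tolist()


query_start_loc_cpu = torch.tensor([0, 2, 4, 7, 11], dtype=torch.int32)
print(seq_lengths_from_cpu(query_start_loc_cpu, num_decodes=3))  # [2, 4, 7]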