support async mtp #4511
Merged
+477 −83

Commits (36)
6201bfd  implement async scheduling for mtp (Ronald1995)
ba323da  fix synchronize error (Ronald1995)
01ae70c  fix indent error (Ronald1995)
add7852  fix synchronize (Ronald1995)
03755ef  fix synchronize error of repeat_interleave (Ronald1995)
16eb688  fix synchronize error (Ronald1995)
27bc0f9  fix synchronize error in _calc_spec_decode_metadata (Ronald1995)
6d70c76  delete v2 (Ronald1995)
935e0d7  fix sync error of seq_lens tolist (Ronald1995)
33c1c56  set pin_memory=True (Ronald1995)
430d371  disable mtp graph when use async scheduling (Ronald1995)
c0317c9  fix pin_memory error (Ronald1995)
d9a1b9c  fix yapf error (Ronald1995)
fb77399  revert rejection_sampler (Ronald1995)
18cd2f8  fix yapf error (Ronald1995)
d989c43  fix ut pin_memory error (Ronald1995)
ecc7c64  handle kv cache (Ronald1995)
161c500  fix hang (Ronald1995)
d1e2d13  fix prev_sampled_token_ids wrong position (Ronald1995)
8a6d9b6  fix yapf error (Ronald1995)
1d94556  _sync_metadata_across_dp with cpu group (Ronald1995)
a0da596  add e2e test for async_scheduling (Ronald1995)
8461e6c  merge main (Ronald1995)
682ab95  fix assert error of sampled_token_ids shape (Ronald1995)
a59e660  fix RejectionSampler.parse_output (Ronald1995)
9dc6e45  fix ut of test_async_scheduling (Ronald1995)
886dad8  fix yapf error (Ronald1995)
01327de  implement out of place Increment of seq_lens_cpu (Ronald1995)
b811e45  fix ruff error (Ronald1995)
b381f37  make seq_lens increment out of place (Ronald1995)
b7fbfc7  Merge branch 'main' into async_mtp3 (Ronald1995)
c58a615  Merge branch 'main' into async_mtp3 (Ronald1995)
0d030ab  Merge branch 'main' into async_mtp3 (Ronald1995)
0c009dc  Merge branch 'main' into async_mtp3 (Ronald1995)
dea5213  Merge branch 'main' into async_mtp3 (wangxiyuan)
7eb30c7  Merge branch 'main' into async_mtp3 (wangxiyuan)
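Taken together, the commits route vLLM's async scheduling path through the MTP (multi-token prediction) drafter: the synchronization points in the proposer are removed (pin_memory, out-of-place seq_lens updates, DP metadata sync over the CPU group), the MTP graph is disabled when async scheduling is on, and an e2e test is added. Purely as an illustration of how the two features might be combined from the user side, here is a minimal, hypothetical sketch; the `async_scheduling` engine argument and the speculative-config keys are assumptions inferred from the test harness below, not taken from this diff.

```python
# Minimal, hypothetical sketch of enabling async scheduling together with an
# MTP-style speculative decoder. The engine argument and the config keys are
# assumptions (inferred from this PR's test harness), not part of the diff.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # assumed: a model that ships MTP weights
    async_scheduling=True,            # assumed engine argument
    speculative_config={              # assumed keys for the MTP drafter
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
)

outputs = llm.generate(
    ["In one word, the capital of France is "],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```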
The diff adds a new end-to-end test for async scheduling (new file, `@@ -0,0 +1,213 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat
from typing import Any

import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.v1.metrics.reader import Metric

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"

first_prompt = ("The following numbers of the sequence " +
                ", ".join(str(i) for i in range(10)) + " are:")
example_prompts = [first_prompt, "In one word, the capital of France is "
                   ] + [f"Tell me about the number {i}: " for i in range(32)]

default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=23,
    min_tokens=18,
)


def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking."""
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
    ]

    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding."""

    with monkeypatch.context():
        # avoid precision errors
        outputs: list[tuple[str, list, list]] = []
        for n, (
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
            )
            outputs.append(test_results)

    baseline_config, baseline_tests, _ = outputs[0]
    _, _, baseline_acceptances = next((o for o in outputs if o[2] is not None),
                                      (None, None, None))

    print(
        f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}"
    )

    failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        for base_outs, base_acceptance_rate, test_outs, test_acceptance_rate, params in zip(
                baseline_tests,
                baseline_acceptances or repeat(None),
                test_outputs,
                test_acceptance_rates or repeat(None),
                test_sampling_params,
        ):
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                if (base_acceptance_rate is not None
                        and test_acceptance_rate is not None):
                    if "spec_mml=None" in test_config:
                        assert (test_acceptance_rate > base_acceptance_rate
                                or test_acceptance_rate == pytest.approx(
                                    base_acceptance_rate, rel=5e-2))
                    else:
                        # Currently the reported acceptance rate is expected to be
                        # lower when we sometimes skip drafting altogether.
                        assert test_acceptance_rate > 0.1
                print(f"PASSED: config=[{test_config}], params={params}"
                      f" accept_rate={test_acceptance_rate}")
            except AssertionError as e:
                print(f"FAILED: config=[{test_config}], params={params}"
                      f" accept_rate={test_acceptance_rate}")
                if failure is None:
                    failure = e

    if failure is not None:
        raise failure


def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
):
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=2) if test_preemption else dict(
            gpu_memory_utilization=0.9))
    spec_mml = (spec_config or {}).get("max_model_len")
    test_config = (f"executor={executor}, preemption={test_preemption}, "
                   f"async_sched={async_scheduling}, "
                   f"chunk_prefill={test_prefill_chunking}, "
                   f"spec_decoding={spec_decoding}, spec_mml={spec_mml}")
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)
    with VllmRunner(
            model,
            max_model_len=512,
            enable_chunked_prefill=test_prefill_chunking,
            # Force prefill chunking
            max_num_batched_tokens=48 if test_prefill_chunking else None,
            enforce_eager=True,
            async_scheduling=async_scheduling,
            distributed_executor_backend=executor,
            dtype="float16",  # avoid precision errors
            speculative_config=spec_config,
            disable_log_stats=False,
            **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            metrics_before = vllm_model.model.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params,
                                                   **override_params),
                ))
            metrics_after = vllm_model.model.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before,
                                                       metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

        if test_preemption:
            preemptions = _get_count(metrics_before, metrics_after,
                                     "vllm:num_preemptions")
            assert preemptions > 0, "preemption test had no preemptions"

        if len(results) > 1:
            # First check that the different parameter configs
            # actually result in different output.
            for other_test_outs, params in zip(results[1:],
                                               sampling_param_tests[1:]):
                with pytest.raises(AssertionError):
                    check_outputs_equal(
                        outputs_0_lst=results[0][0],
                        outputs_1_lst=other_test_outs,
                        name_0=f"baseline params={params}",
                        name_1=f"other params={params}",
                    )

    return test_config, results, acceptance_rates


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    return accept / draft if draft > 0 else 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    before_val = next(m.value for m in before if m.name == name)
    after_val = next(m.value for m in after if m.name == name)
    return after_val - before_val
```
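In this harness the fourth element of each `test_configs` tuple is `spec_config`; when it is not `None`, `run_test` reports an acceptance rate computed from the deltas of the `vllm:spec_decode_num_draft_tokens` and `vllm:spec_decode_num_accepted_tokens` counters between the before/after metric snapshots. A drafting run under async scheduling could therefore be added as sketched below; the spec-config keys are assumptions, not taken from this diff.

```python
# Hypothetical extension of test_configs above: the fourth slot is spec_config,
# so passing a non-None dict turns on speculative decoding for that run.
# The keys below are assumed, not part of this PR's diff.
mtp_spec_config = {"method": "deepseek_mtp", "num_speculative_tokens": 1}

test_configs = [
    (False, "mp", True, None, False),             # async scheduling, no drafting
    (False, "mp", True, mtp_spec_config, False),  # async scheduling + MTP drafting
]
```

`run_tests` would then compare the drafting run's outputs against the non-drafting baseline and sanity-check its acceptance rate, exactly as it does for the configurations already listed.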