Commit 46cbbca

[CI][DCP][Perf] reduce DCP CI execution time (#29858)
Signed-off-by: QiuChunshuo <[email protected]>
1 parent b286a31 commit 46cbbca

File tree

2 files changed (+100, -94 lines)


tests/distributed/test_context_parallel.py

Lines changed: 95 additions & 93 deletions
@@ -16,16 +16,35 @@
 import pytest
 import torch
 
+from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
+from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
 
 from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, create_new_process_for_each_test
 
 logger = init_logger("test_context_parallel")
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
+CP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "Qwen/Qwen2.5-1.5B-Instruct",
+]
+
+# GSM8K eval configuration
+NUM_QUESTIONS = 256  # Fast eval for CI
+NUM_SHOTS = 5  # Few-shot examples
+# tp accuracy with 2% buffer
+MIN_ACCURACY = {
+    # .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
+    # .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
+    "Qwen/Qwen2.5-1.5B-Instruct": 0.52,
+}
+
 
 class ParallelSetup(NamedTuple):
     tp_size: int
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
 
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
-    load_format: str | None = None
     attn_backend: str | None = None
 
 
@@ -54,17 +72,20 @@ def detailed(
         *,
         tp_base: int = 4,
         pp_base: int = 1,
-        dcp_base: int = 1,
+        dcp_multipliers: list[float] | None = None,
         cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
-        load_format: str | None = None,
         attn_backend: str | None = None,
     ):
         parallel_setups = []
+        if dcp_multipliers is None:
+            dcp_multipliers = [
+                0.5,
+            ]
         for eager_mode_val in [False]:
             for pp_multiplier in [1]:
-                for dcp_multiplier in [0.5, 1]:
+                for dcp_multiplier in dcp_multipliers:
                     for chunked_prefill_val in [True]:
                         parallel_setups.append(
                             ParallelSetup(
@@ -82,7 +103,6 @@ def detailed(
             runner=runner,
             test_options=CPTestOptions(
                 multi_node_only=multi_node_only,
-                load_format=load_format,
                 attn_backend=attn_backend,
             ),
         )
@@ -101,7 +121,24 @@ def iter_params(self, model_id: str):
             )
 
 
-def _compare_cp_with_tp(
+CP_TEXT_GENERATION_MODELS = {
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(
+            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+        ),
+    ],
+    "Qwen/Qwen2.5-1.5B-Instruct": [
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
+        ),
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
+        ),
+    ],
+}
+
+
+def _test_cp_gsm8k(
     model_id: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup
 
-    multi_node_only, load_format, attn_backend = test_options
+    multi_node_only, attn_backend = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
     tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides
 
-    if load_format == "dummy":
-        # Avoid OOM
-        text_overrides = {
-            "num_hidden_layers": 4,
-            "hidden_size": 512,
-            "intermediate_size": 800,
-            "num_attention_heads": 4,
-            "num_key_value_heads": 1,
-        }
-
-        if is_multimodal:
-            hf_overrides.update({"text_config": text_overrides})
-        else:
-            hf_overrides.update(text_overrides)
-    else:
-        model_info.check_available_online(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
     if multi_node_only and not VLLM_MULTI_NODE:
         pytest.skip("Not in multi-node setting")
 
-    common_args = [
+    server_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "2048",
+        "4096",
         "--max-num-seqs",
-        "8",
+        "64",
     ]
     if chunked_prefill:
-        common_args.append("--enable-chunked-prefill")
+        server_args.append("--enable-chunked-prefill")
     if eager_mode:
-        common_args.append("--enforce-eager")
+        server_args.append("--enforce-eager")
     if runner != "auto":
-        common_args.extend(["--runner", runner])
+        server_args.extend(["--runner", runner])
     if trust_remote_code:
-        common_args.append("--trust-remote-code")
+        server_args.append("--trust-remote-code")
     if tokenizer_mode:
-        common_args.extend(["--tokenizer-mode", tokenizer_mode])
-    if load_format:
-        common_args.extend(["--load-format", load_format])
+        server_args.extend(["--tokenizer-mode", tokenizer_mode])
     if hf_overrides:
-        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
-
-    if not attn_backend:
-        cp_env = tp_env = {}
-    else:
-        cp_env = tp_env = {
-            "VLLM_ATTENTION_BACKEND": attn_backend,
-        }
-
-    cp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--decode-context-parallel-size",
-        str(dcp_size),
-        "--dcp-kv-cache-interleave-size",
-        str(cp_kv_cache_interleave_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
+        server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+
+    server_args.extend(
+        [
+            "--tensor-parallel-size",
+            str(tp_size),
+            "--pipeline-parallel-size",
+            str(pp_size),
+            "--decode-context-parallel-size",
+            str(dcp_size),
+            "--dcp-kv-cache-interleave-size",
+            str(cp_kv_cache_interleave_size),
+            "--distributed-executor-backend",
+            distributed_backend,
+        ]
+    )
 
-    tp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
+    server_env = {}
+    if attn_backend:
+        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
 
-    compare_two_settings(
+    with RemoteOpenAIServer(
         model_id,
-        cp_args,
-        tp_args,
-        cp_env,
-        tp_env,
-        method=method,
+        server_args,
+        env_dict=server_env,
         max_wait_seconds=720,
-    )
-
-
-CP_TEXT_GENERATION_MODELS = {
-    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
-    ],
-    "bigcode/gpt_bigcode-santacoder": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-    ],
-}
+    ) as remote_server:
+        host = f"http://{remote_server.host}"
+        port = remote_server.port
+
+        # Run GSM8K evaluation
+        results = evaluate_gsm8k(
+            num_questions=NUM_QUESTIONS,
+            num_shots=NUM_SHOTS,
+            host=host,
+            port=port,
+        )
 
-CP_TEST_MODELS = [
-    # TODO support other models
-    # [LANGUAGE GENERATION]
-    "deepseek-ai/DeepSeek-V2-Lite-Chat",
-    "bigcode/gpt_bigcode-santacoder",
-]
+        # Validate accuracy is reasonable
+        accuracy = results["accuracy"]
+        min_accuracy = MIN_ACCURACY[model_id]
+        assert accuracy >= min_accuracy, (
+            f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
+        )
 
 
 @pytest.mark.parametrize(
@@ -274,12 +276,12 @@ def test_cp_generation(
     ):
         pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
     if (
-        model_id == "bigcode/gpt_bigcode-santacoder"
+        model_id == "Qwen/Qwen2.5-1.5B-Instruct"
         and torch.cuda.get_device_capability() != (9, 0)
    ):
         pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
 
-    _compare_cp_with_tp(
+    _test_cp_gsm8k(
         model_id,
         parallel_setup,
         distributed_backend,
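
A minimal sketch of how the reworked GSM8K-based test could be exercised locally. The pytest invocation and the -k filter are assumptions for illustration, not part of this commit, and a node with enough GPUs is still required:

    # hypothetical local runner for the reduced DCP CI cases
    import sys

    import pytest

    if __name__ == "__main__":
        # Select only the DeepSeek cases from the reworked test file; each case
        # launches one vLLM server and runs the 256-question GSM8K eval against it.
        sys.exit(
            pytest.main(
                [
                    "tests/distributed/test_context_parallel.py",
                    "-k", "DeepSeek",
                    "-v", "-x",
                ]
            )
        )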

tests/models/registry.py

Lines changed: 5 additions & 1 deletion
@@ -416,7 +416,11 @@ def check_available_online(
         trust_remote_code=True,
     ),
     "Qwen2ForCausalLM": _HfExamplesInfo(
-        "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
+        "Qwen/Qwen2-0.5B-Instruct",
+        extras={
+            "2.5": "Qwen/Qwen2.5-0.5B-Instruct",
+            "2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
+        },
     ),
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
