 import pytest
 import torch

+from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
+from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger

 from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, create_new_process_for_each_test

 logger = init_logger("test_context_parallel")

 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

+CP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "Qwen/Qwen2.5-1.5B-Instruct",
+]
+
+# GSM8K eval configuration
+NUM_QUESTIONS = 256  # Fast eval for CI
+NUM_SHOTS = 5  # Few-shot examples
+# Minimum accuracy per model: the TP baseline accuracy minus a 2% buffer
+MIN_ACCURACY = {
+    # .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
+    # .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
+    "Qwen/Qwen2.5-1.5B-Instruct": 0.52,
+}
+

 class ParallelSetup(NamedTuple):
     tp_size: int
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):

 class CPTestOptions(NamedTuple):
     multi_node_only: bool
-    load_format: str | None = None
     attn_backend: str | None = None


@@ -54,17 +72,20 @@ def detailed(
         *,
         tp_base: int = 4,
         pp_base: int = 1,
-        dcp_base: int = 1,
+        dcp_multipliers: list[float] | None = None,
         cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
-        load_format: str | None = None,
         attn_backend: str | None = None,
     ):
         parallel_setups = []
+        if dcp_multipliers is None:
+            dcp_multipliers = [
+                0.5,
+            ]
         for eager_mode_val in [False]:
             for pp_multiplier in [1]:
-                for dcp_multiplier in [0.5, 1]:
+                for dcp_multiplier in dcp_multipliers:
                     for chunked_prefill_val in [True]:
                         parallel_setups.append(
                             ParallelSetup(
@@ -82,7 +103,6 @@ def detailed(
             runner=runner,
             test_options=CPTestOptions(
                 multi_node_only=multi_node_only,
-                load_format=load_format,
                 attn_backend=attn_backend,
             ),
         )
@@ -101,7 +121,24 @@ def iter_params(self, model_id: str):
                 )


-def _compare_cp_with_tp(
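+# Per-model context-parallel test settings (DCP multipliers, KV-cache interleave
+# size, attention backend) for the models listed in CP_TEST_MODELS.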
+CP_TEXT_GENERATION_MODELS = {
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(
+            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+        ),
+    ],
+    "Qwen/Qwen2.5-1.5B-Instruct": [
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
+        ),
+        CPTestSettings.detailed(
+            cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
+        ),
+    ],
+}
+
+
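+# Spins up a vLLM OpenAI-compatible server with the given TP/PP/DCP setup and
+# asserts that GSM8K accuracy stays above the per-model threshold in MIN_ACCURACY.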
+def _test_cp_gsm8k(
     model_id: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup

-    multi_node_only, load_format, attn_backend = test_options
+    multi_node_only, attn_backend = test_options

     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides

-    if load_format == "dummy":
-        # Avoid OOM
-        text_overrides = {
-            "num_hidden_layers": 4,
-            "hidden_size": 512,
-            "intermediate_size": 800,
-            "num_attention_heads": 4,
-            "num_key_value_heads": 1,
-        }
-
-        if is_multimodal:
-            hf_overrides.update({"text_config": text_overrides})
-        else:
-            hf_overrides.update(text_overrides)
-    else:
-        model_info.check_available_online(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")

     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size}x{pp_size} GPUs")
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
     if multi_node_only and not VLLM_MULTI_NODE:
         pytest.skip("Not in multi-node setting")

-    common_args = [
+    server_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "2048",
+        "4096",
         "--max-num-seqs",
-        "8",
+        "64",
     ]
     if chunked_prefill:
-        common_args.append("--enable-chunked-prefill")
+        server_args.append("--enable-chunked-prefill")
     if eager_mode:
-        common_args.append("--enforce-eager")
+        server_args.append("--enforce-eager")
     if runner != "auto":
-        common_args.extend(["--runner", runner])
+        server_args.extend(["--runner", runner])
     if trust_remote_code:
-        common_args.append("--trust-remote-code")
+        server_args.append("--trust-remote-code")
     if tokenizer_mode:
-        common_args.extend(["--tokenizer-mode", tokenizer_mode])
-    if load_format:
-        common_args.extend(["--load-format", load_format])
+        server_args.extend(["--tokenizer-mode", tokenizer_mode])
     if hf_overrides:
-        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
-
-    if not attn_backend:
-        cp_env = tp_env = {}
-    else:
-        cp_env = tp_env = {
-            "VLLM_ATTENTION_BACKEND": attn_backend,
-        }
-
-    cp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--decode-context-parallel-size",
-        str(dcp_size),
-        "--dcp-kv-cache-interleave-size",
-        str(cp_kv_cache_interleave_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
+        server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+
+    server_args.extend(
+        [
+            "--tensor-parallel-size",
+            str(tp_size),
+            "--pipeline-parallel-size",
+            str(pp_size),
+            "--decode-context-parallel-size",
+            str(dcp_size),
+            "--dcp-kv-cache-interleave-size",
+            str(cp_kv_cache_interleave_size),
+            "--distributed-executor-backend",
+            distributed_backend,
+        ]
+    )

-    tp_args = [
-        *common_args,
-        "--tensor-parallel-size",
-        str(tp_size),
-        "--pipeline-parallel-size",
-        str(pp_size),
-        "--distributed-executor-backend",
-        distributed_backend,
-    ]
+    server_env = {}
+    if attn_backend:
+        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend

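+    # Launch a single OpenAI-compatible server with DCP enabled and evaluate it
+    # on GSM8K directly, instead of diffing CP vs. TP outputs as before.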
-    compare_two_settings(
+    with RemoteOpenAIServer(
         model_id,
-        cp_args,
-        tp_args,
-        cp_env,
-        tp_env,
-        method=method,
+        server_args,
+        env_dict=server_env,
         max_wait_seconds=720,
-    )
-
-
-CP_TEXT_GENERATION_MODELS = {
-    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
-    ],
-    "bigcode/gpt_bigcode-santacoder": [
-        CPTestSettings.detailed(),
-        CPTestSettings.detailed(tp_base=2),
-    ],
-}
+    ) as remote_server:
+        host = f"http://{remote_server.host}"
+        port = remote_server.port
+
+        # Run GSM8K evaluation
+        results = evaluate_gsm8k(
+            num_questions=NUM_QUESTIONS,
+            num_shots=NUM_SHOTS,
+            host=host,
+            port=port,
+        )

-CP_TEST_MODELS = [
-    # TODO support other models
-    # [LANGUAGE GENERATION]
-    "deepseek-ai/DeepSeek-V2-Lite-Chat",
-    "bigcode/gpt_bigcode-santacoder",
-]
+        # Validate accuracy is reasonable
+        accuracy = results["accuracy"]
+        min_accuracy = MIN_ACCURACY[model_id]
+        assert accuracy >= min_accuracy, (
+            f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
+        )


 @pytest.mark.parametrize(
@@ -274,12 +276,12 @@ def test_cp_generation(
     ):
         pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
     if (
-        model_id == "bigcode/gpt_bigcode-santacoder"
+        model_id == "Qwen/Qwen2.5-1.5B-Instruct"
         and torch.cuda.get_device_capability() != (9, 0)
     ):
         pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")

-    _compare_cp_with_tp(
+    _test_cp_gsm8k(
         model_id,
         parallel_setup,
         distributed_backend,