Skip to content

Commit d7284a2

Browse files
[Core] Rename PassConfig flags as per RFC #27995 (#29646)
Signed-off-by: arpitkh101 <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent 506ed87 commit d7284a2

22 files changed

+318
-123
lines changed

tests/compile/distributed/test_async_tp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
326326
vllm_config = VllmConfig()
327327
vllm_config.compilation_config = CompilationConfig(
328328
pass_config=PassConfig(
329-
enable_async_tp=True,
329+
fuse_gemm_comms=True,
330330
),
331331
)
332332
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
413413
"mode": CompilationMode.VLLM_COMPILE,
414414
"compile_sizes": [2, 4, 8],
415415
"splitting_ops": [],
416-
"pass_config": {"enable_async_tp": async_tp_enabled},
416+
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
417417
}
418418

419419
async_tp_args = [

tests/compile/distributed/test_fusion_all_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
295295
)
296296
)
297297
vllm_config.compilation_config.pass_config = PassConfig(
298-
enable_fi_allreduce_fusion=True, enable_noop=True
298+
fuse_allreduce_rms=True, eliminate_noops=True
299299
)
300300
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
301301
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path

tests/compile/distributed/test_fusions_e2e.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_attn_quant(
192192
splitting_ops=splitting_ops,
193193
# Common
194194
mode=CompilationMode.VLLM_COMPILE,
195-
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
195+
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
196196
# Inductor caches custom passes by default as well via uuid
197197
inductor_compile_config={"force_disable_caches": True},
198198
)
@@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
282282
# Common
283283
mode=CompilationMode.VLLM_COMPILE,
284284
pass_config=PassConfig(
285-
enable_attn_fusion=True,
286-
enable_noop=True,
287-
enable_fi_allreduce_fusion=True,
285+
fuse_attn_quant=True,
286+
eliminate_noops=True,
287+
fuse_allreduce_rms=True,
288288
),
289289
# Inductor caches custom passes by default as well via uuid
290290
inductor_compile_config={"force_disable_caches": True},
@@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
384384
# Common
385385
level=CompilationMode.VLLM_COMPILE,
386386
pass_config=PassConfig(
387-
enable_attn_fusion=True,
388-
enable_noop=True,
389-
enable_sequence_parallelism=True,
390-
enable_async_tp=True,
387+
fuse_attn_quant=True,
388+
eliminate_noops=True,
389+
enable_sp=True,
390+
fuse_gemm_comms=True,
391391
),
392392
# Inductor caches custom passes by default as well via uuid
393393
inductor_compile_config={"force_disable_caches": True},

tests/compile/distributed/test_sequence_parallelism.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def ops_in_model_before(self):
153153
]
154154

155155
def ops_in_model(self):
156-
if self.vllm_config.compilation_config.pass_config.enable_fusion:
156+
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
157157
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
158158
elif RMSNorm.enabled():
159159
return [
@@ -183,7 +183,7 @@ def ops_in_model(self):
183183
@pytest.mark.parametrize("seq_len", [16])
184184
@pytest.mark.parametrize("hidden_size", [16])
185185
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
186-
@pytest.mark.parametrize("enable_fusion", [True, False])
186+
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
187187
@pytest.mark.parametrize("dynamic", [False, True])
188188
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
189189
def test_sequence_parallelism_pass(
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
193193
seq_len: int,
194194
hidden_size: int,
195195
dtype: torch.dtype,
196-
enable_fusion: bool,
196+
fuse_norm_quant: bool,
197197
dynamic: bool,
198198
):
199199
num_processes = 2
@@ -211,7 +211,7 @@ def run_torch_spawn(fn, nprocs):
211211
seq_len,
212212
hidden_size,
213213
dtype,
214-
enable_fusion,
214+
fuse_norm_quant,
215215
dynamic,
216216
),
217217
nprocs=nprocs,
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
229229
seq_len: int,
230230
hidden_size: int,
231231
dtype: torch.dtype,
232-
enable_fusion: bool,
232+
fuse_norm_quant: bool,
233233
dynamic: bool,
234234
):
235235
current_platform.seed_everything(0)
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
260260
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
261261
custom_ops=custom_ops_list,
262262
pass_config=PassConfig(
263-
enable_sequence_parallelism=True,
264-
enable_fusion=enable_fusion,
265-
enable_noop=True,
263+
enable_sp=True,
264+
fuse_norm_quant=fuse_norm_quant,
265+
eliminate_noops=True,
266266
),
267267
) # NoOp needed for fusion
268268
device_config = DeviceConfig(device=torch.device("cuda"))
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
297297
sequence_parallelism_pass,
298298
]
299299

300-
if enable_fusion:
300+
if fuse_norm_quant:
301301
fusion_pass = RMSNormQuantFusionPass(vllm_config)
302302
passes_for_backend.append(fusion_pass)
303303

tests/compile/fullgraph/test_full_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def test_full_graph(
122122
CompilationConfig(
123123
mode=CompilationMode.VLLM_COMPILE,
124124
custom_ops=["+rms_norm"],
125-
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
125+
pass_config=PassConfig(
126+
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
127+
),
126128
),
127129
*model_info,
128130
)

tests/compile/test_config.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import copy
4+
import logging
45
from contextlib import nullcontext
56
from unittest.mock import patch
67

@@ -10,8 +11,9 @@
1011
from vllm.compilation.counter import compilation_counter
1112
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
1213
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
13-
from vllm.config.compilation import CompilationMode
14+
from vllm.config.compilation import CompilationMode, PassConfig
1415
from vllm.engine.arg_utils import EngineArgs
16+
from vllm.logger import _print_warning_once
1517
from vllm.platforms import current_platform
1618
from vllm.utils.torch_utils import _is_torch_equal_or_newer
1719

@@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
191193
config = VllmConfig(
192194
compilation_config=CompilationConfig(
193195
mode=CompilationMode.VLLM_COMPILE,
194-
pass_config={"enable_attn_fusion": True, "enable_noop": True},
196+
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
195197
custom_ops=["+quant_fp8"],
196198
cudagraph_mode=CUDAGraphMode.PIECEWISE,
197199
)
@@ -206,7 +208,7 @@ def test_splitting_ops_dynamic():
206208
config = VllmConfig(
207209
compilation_config=CompilationConfig(
208210
mode=CompilationMode.VLLM_COMPILE,
209-
pass_config={"enable_attn_fusion": True, "enable_noop": True},
211+
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
210212
custom_ops=["+quant_fp8"],
211213
cudagraph_mode=CUDAGraphMode.PIECEWISE,
212214
# work around for accessing all attntion ops
@@ -219,15 +221,15 @@ def test_splitting_ops_dynamic():
219221
compilation_config=CompilationConfig(
220222
mode=CompilationMode.VLLM_COMPILE,
221223
use_inductor_graph_partition=True,
222-
pass_config={"enable_attn_fusion": True, "enable_noop": True},
224+
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
223225
custom_ops=["+quant_fp8"],
224226
cudagraph_mode=CUDAGraphMode.PIECEWISE,
225227
)
226228
)
227229
# With inductor graph partition, attn_fusion and splitting_ops
228230
# work together. Default splitting_ops include attention ops.
229231
assert config.compilation_config.splitting_ops_contain_attention()
230-
# enable_attn_fusion is directly supported under
232+
# fuse_attn_quant is directly supported under
231233
# use_inductor_graph_partition=True, and cudagraph_mode
232234
# is unchanged.
233235
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
@@ -301,7 +303,7 @@ def test_should_split():
301303
"cudagraph_capture_sizes",
302304
"max_cudagraph_capture_size",
303305
"tp_size",
304-
"enable_sequence_parallelism",
306+
"enable_sp",
305307
"max_num_batched_tokens",
306308
"cudagraph_mode",
307309
"expected_max_size",
@@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
339341
cudagraph_capture_sizes,
340342
max_cudagraph_capture_size,
341343
tp_size,
342-
enable_sequence_parallelism,
344+
enable_sp,
343345
max_num_batched_tokens,
344346
cudagraph_mode,
345347
expected_max_size,
@@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
355357
compilation_config = CompilationConfig(
356358
cudagraph_capture_sizes=cudagraph_capture_sizes,
357359
max_cudagraph_capture_size=max_cudagraph_capture_size,
358-
pass_config={
359-
"enable_sequence_parallelism": enable_sequence_parallelism,
360-
"enable_fusion": True,
361-
"enable_noop": True,
362-
},
360+
pass_config=PassConfig(
361+
enable_sp=enable_sp,
362+
fuse_norm_quant=True,
363+
fuse_act_quant=True,
364+
eliminate_noops=True,
365+
),
363366
cudagraph_mode=cudagraph_mode,
364367
)
365368
engine_args = EngineArgs(
@@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
375378
vllm_config.compilation_config.max_cudagraph_capture_size
376379
== expected_max_size
377380
)
381+
382+
383+
def test_pass_config_deprecation(caplog_vllm):
384+
caplog_vllm.set_level(logging.WARNING)
385+
386+
# Clear cache to ensure warnings are re-issued
387+
_print_warning_once.cache_clear()
388+
389+
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
390+
caplog_vllm.clear()
391+
config = PassConfig(enable_fusion=True)
392+
assert "enable_fusion is deprecated" in caplog_vllm.text
393+
assert config.fuse_norm_quant is True
394+
assert config.fuse_act_quant is True
395+
assert config.enable_fusion is None
396+
397+
# Test enable_attn_fusion -> fuse_attn_quant
398+
caplog_vllm.clear()
399+
config = PassConfig(enable_attn_fusion=True)
400+
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
401+
assert config.fuse_attn_quant is True
402+
assert config.enable_attn_fusion is None
403+
404+
# Test enable_noop -> eliminate_noops
405+
caplog_vllm.clear()
406+
config = PassConfig(enable_noop=True)
407+
assert "enable_noop is deprecated" in caplog_vllm.text
408+
assert config.eliminate_noops is True
409+
assert config.enable_noop is None
410+
411+
# Test enable_sequence_parallelism -> enable_sp
412+
caplog_vllm.clear()
413+
config = PassConfig(enable_sequence_parallelism=True)
414+
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
415+
assert config.enable_sp is True
416+
assert config.enable_sequence_parallelism is None
417+
418+
# Test enable_async_tp -> fuse_gemm_comms
419+
caplog_vllm.clear()
420+
config = PassConfig(enable_async_tp=True)
421+
assert "enable_async_tp is deprecated" in caplog_vllm.text
422+
assert config.fuse_gemm_comms is True
423+
assert config.enable_async_tp is None
424+
425+
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
426+
caplog_vllm.clear()
427+
config = PassConfig(enable_fi_allreduce_fusion=True)
428+
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
429+
assert config.fuse_allreduce_rms is True
430+
assert config.enable_fi_allreduce_fusion is None

tests/compile/test_functionalization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,11 @@ def test_fix_functionalization(
223223
model_config=ModelConfig(dtype=dtype),
224224
compilation_config=CompilationConfig(
225225
custom_ops=["all"],
226-
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
226+
pass_config=PassConfig(
227+
fuse_norm_quant=do_fusion,
228+
fuse_act_quant=do_fusion,
229+
eliminate_noops=True,
230+
),
227231
),
228232
)
229233

tests/compile/test_fusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
159159
compilation_config=CompilationConfig(
160160
mode=CompilationMode.VLLM_COMPILE,
161161
custom_ops=custom_ops,
162-
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
162+
pass_config=PassConfig(
163+
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
164+
),
163165
),
164166
)
165167
with vllm.config.set_current_vllm_config(vllm_config):

tests/compile/test_fusion_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def test_attention_quant_pattern(
373373

374374
# Run model with attn fusion enabled
375375
vllm_config.compilation_config.pass_config = PassConfig(
376-
enable_attn_fusion=True, enable_noop=True
376+
fuse_attn_quant=True, eliminate_noops=True
377377
)
378378
with (
379379
set_current_vllm_config(vllm_config),

tests/compile/test_noop_elimination.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def forward(self, x):
5151
vllm_config = VllmConfig(
5252
compilation_config=CompilationConfig(
5353
mode=CompilationMode.VLLM_COMPILE,
54-
pass_config=PassConfig(enable_noop=True),
54+
pass_config=PassConfig(eliminate_noops=True),
5555
)
5656
)
5757
with vllm.config.set_current_vllm_config(vllm_config):
@@ -99,7 +99,7 @@ def forward(self, x):
9999
vllm_config = VllmConfig(
100100
compilation_config=CompilationConfig(
101101
mode=CompilationMode.VLLM_COMPILE,
102-
pass_config=PassConfig(enable_noop=True),
102+
pass_config=PassConfig(eliminate_noops=True),
103103
)
104104
)
105105
with vllm.config.set_current_vllm_config(vllm_config):

0 commit comments

Comments
 (0)