
Commit 237c8ea

MatthewBonanni authored and FeiDaLI committed
Add FLASHINFER_MLA to backend selector test (vllm-project#24753)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent aaf0a1c commit 237c8ea

2 files changed: 43 additions & 19 deletions

tests/kernels/attention/test_attention_selector.py

Lines changed: 41 additions & 19 deletions
@@ -22,7 +22,10 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -90,8 +93,8 @@ def test_env(
 
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
-                backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           block_size, False)
+                backend = get_attn_backend(16, torch.float16, None, block_size,
+                                           False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
         elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -120,7 +123,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -130,7 +133,7 @@ def test_env(
                         # Valid backend-block_size combination
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -139,7 +142,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
                     #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
+                    #   (SM 10.0+), V1 only
                     # - FLASHMLA: only supported with block_size == 64
                     # - FLASH_ATTN_MLA: V1 only
                     # - TRITON_MLA: fallback for other cases
@@ -169,12 +174,31 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
                             expected = "CUTLASS_MLA_VLLM_V1"
                             assert backend.get_name() == expected
+                    elif name == "FLASHINFER_MLA":
+                        if not use_v1:
+                            # FlashInfer MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashInfer MLA only supported on V1 engine")
+                        elif block_size not in [32, 64]:
+                            # FlashInfer MLA only supports block_size 32 or 64
+                            pytest.skip(
+                                "FlashInfer MLA only supports block_size 32 "
+                                "or 64")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       None,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "FLASHINFER_MLA"
+                            assert backend.get_name() == expected
                     elif name == "FLASHMLA":
                         if block_size != 64:
                             # FlashMLA only supports block_size == 64
@@ -189,7 +213,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -204,7 +228,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -214,7 +238,7 @@ def test_env(
                         # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -224,7 +248,7 @@ def test_env(
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -233,7 +257,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(32,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -243,7 +267,7 @@ def test_env(
                     if use_v1:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
 
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
+                backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
         elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
+                backend = get_attn_backend(16, torch.float32, None, 16, False)
            assert (backend.get_name() == "FLEX_ATTENTION"
                    if use_v1 else "XFORMERS")
 
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
         # Attention-free models should bypass env and use PlaceholderAttention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+        backend = get_attn_backend(16, torch.float16, None, 16, True)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
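For reference, a minimal standalone sketch of what the new FLASHINFER_MLA branch exercises. This is not part of the commit: it assumes a Blackwell GPU (SM 10.0+), the V1 engine, a supported block_size (32 or 64), and the same import paths and positional get_attn_backend arguments (head_size, dtype, kv_cache_dtype, block_size, is_attention_free) used by the test above.

import os
from unittest.mock import patch

import torch

# Assumed import paths, mirroring the test file.
from vllm.attention.selector import get_attn_backend
from vllm.platforms.cuda import CudaPlatform

# The test drives selection through these env vars (via monkeypatch there);
# setting them directly here is just for the sketch.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

with patch("vllm.attention.selector.current_platform", CudaPlatform()):
    backend = get_attn_backend(16,              # head_size
                               torch.float16,   # dtype
                               None,            # kv_cache_dtype
                               64,              # block_size (32 or 64)
                               False,           # is_attention_free
                               use_mla=True)
    # On a supported GPU this is expected to resolve to the FlashInfer MLA
    # backend, matching the new assertion in the test.
    assert backend.get_name() == "FLASHINFER_MLA"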

tests/v1/attention/utils.py

Lines changed: 2 additions & 0 deletions
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
         _Backend.FLASH_ATTN_MLA:
         "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
+        _Backend.FLASHINFER_MLA:
+        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
         _Backend.TRITON_MLA_VLLM_V1:
         "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }
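The added entry only maps the _Backend.FLASHINFER_MLA enum value to a dotted class path. A short illustrative sketch of how such a string can be resolved to the backend class; this mirrors, but is not, the helper's actual implementation:

import importlib

# Dotted path taken verbatim from the mapping above.
qualname = "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"

# Split "package.module.ClassName" and import the class dynamically.
module_name, _, class_name = qualname.rpartition(".")
backend_cls = getattr(importlib.import_module(module_name), class_name)

# Backend classes expose get_name(), as used throughout the test above.
print(backend_cls.get_name())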
