@@ -22,7 +22,10 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -90,8 +93,8 @@ def test_env(
 
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
-                backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           block_size, False)
+                backend = get_attn_backend(16, torch.float16, None, block_size,
+                                           False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
         elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -120,7 +123,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -130,7 +133,7 @@ def test_env(
                         # Valid backend-block_size combination
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -139,7 +142,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
                     #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
+                    #   (SM 10.0+), V1 only
                     # - FLASHMLA: only supported with block_size == 64
                     # - FLASH_ATTN_MLA: V1 only
                     # - TRITON_MLA: fallback for other cases
@@ -169,12 +174,31 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
                             expected = "CUTLASS_MLA_VLLM_V1"
                             assert backend.get_name() == expected
+                    elif name == "FLASHINFER_MLA":
+                        if not use_v1:
+                            # FlashInfer MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashInfer MLA only supported on V1 engine")
+                        elif block_size not in [32, 64]:
+                            # FlashInfer MLA only supports block_size 32 or 64
+                            pytest.skip(
+                                "FlashInfer MLA only supports block_size 32 "
+                                "or 64")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       None,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "FLASHINFER_MLA"
+                            assert backend.get_name() == expected
                     elif name == "FLASHMLA":
                         if block_size != 64:
                             # FlashMLA only supports block_size == 64
@@ -189,7 +213,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -204,7 +228,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -214,7 +238,7 @@ def test_env(
                         # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -224,7 +248,7 @@ def test_env(
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -233,7 +257,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(32,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -243,7 +267,7 @@ def test_env(
                 if use_v1:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
 
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
+                backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
         elif device == "cuda":
             with patch("vllm.attention.selector.current_platform",
                        CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
+                backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert (backend.get_name() == "FLEX_ATTENTION"
                     if use_v1 else "XFORMERS")
 
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
         # Attention-free models should bypass env and use PlaceholderAttention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+        backend = get_attn_backend(16, torch.float16, None, 16, True)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
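For quick reference, the CUDA MLA constraints spelled out in the comment block above, together with the new FLASHINFER_MLA skip conditions, can be condensed into a single predicate. The sketch below is illustrative only: the helper name and the (major, minor) compute-capability tuple are assumptions made for this example, not part of vLLM's selector API, and the 32/64 block-size limit for FLASHINFER_MLA is taken from the test's skip conditions.

def cuda_mla_backend_ok(name: str, block_size: int, use_v1: bool,
                        compute_capability: tuple[int, int]) -> bool:
    """Illustrative restatement of the constraints the test checks."""
    # The sketch does not distinguish SM 10.0 from SM 10.0+.
    is_blackwell = compute_capability[0] >= 10
    if name == "CUTLASS_MLA":
        # Blackwell (SM 10.0) with block_size == 128, V1 engine only
        return use_v1 and is_blackwell and block_size == 128
    if name == "FLASHINFER_MLA":
        # Blackwell (SM 10.0+), V1 engine only; block_size 32 or 64
        return use_v1 and is_blackwell and block_size in (32, 64)
    if name == "FLASHMLA":
        # Only block_size == 64 is supported
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        # V1 engine only
        return use_v1
    # TRITON_MLA is the fallback for other cases
    return name == "TRITON_MLA"

# Example: FlashInfer MLA is valid on a hypothetical SM 10.0 device with
# block_size 64 under the V1 engine.
assert cuda_mla_backend_ok("FLASHINFER_MLA", 64, True, (10, 0))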