@@ -255,9 +255,17 @@ def body(i, x):
 # CHECK-LABEL: test_flex_attention
 # CHECK: func.func @test_flex_attention
 def test_flex_attention():
-    from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
-    from torch.nn.attention.flex_attention import BlockMask, _LARGE_SPARSE_BLOCK_SIZE, create_block_mask, flex_attention
+    from torch._higher_order_ops.flex_attention import (
+        flex_attention as flex_attention_hop,
+    )
+    from torch.nn.attention.flex_attention import (
+        BlockMask,
+        _LARGE_SPARSE_BLOCK_SIZE,
+        create_block_mask,
+        flex_attention,
+    )
     from torch import Tensor
+
     def _create_empty_block_mask(query: Tensor, key: Tensor):
         # Default block mask for flex attention.
         device = query.device
@@ -281,11 +289,13 @@ def relative_position_bias(
     class FlexAttention(torch.nn.Module):
         def __init__(self, block_mask):
             super().__init__()
-            self.block_mask = block_mask
-
+            self.block_mask = block_mask
+
         def forward(self, q, k, v):
-            output, logsumexp = flex_attention_hop(
-                q, k, v,
+            output, logsumexp, *_ = flex_attention_hop(
+                q,
+                k,
+                v,
                 score_mod=relative_position_bias,
                 block_mask=self.block_mask,
                 scale=1.0,
@@ -299,7 +309,11 @@ def forward(self, q, k, v):
     k = torch.ones(B, Hkv, S, E)
     v = torch.ones(B, Hkv, S, Ev)
     m = fx.export_and_import(
-        FlexAttention(_create_empty_block_mask(q,k)), q,k,v, func_name="test_flex_attention"
+        FlexAttention(_create_empty_block_mask(q, k)),
+        q,
+        k,
+        v,
+        func_name="test_flex_attention",
     )
     print(m)
 
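Beyond the black-style reformatting, the one functional change above is unpacking the HOP result as `output, logsumexp, *_`, which presumably keeps the test tolerant of flex_attention_hop returning additional values. For orientation, here is a minimal sketch of calling the public flex_attention API that the test exercises through the HOP; the shapes, the toy relative_position_bias, and the PyTorch >= 2.5 assumption are illustrative and not part of this patch:

# Minimal sketch (assumes PyTorch >= 2.5, where torch.nn.attention.flex_attention
# is available). Tensor layout is (batch, heads, seq_len, head_dim).
import torch
from torch.nn.attention.flex_attention import flex_attention

def relative_position_bias(score, batch, head, q_idx, kv_idx):
    # Illustrative score_mod: add a relative-position term to each attention score.
    return score + (q_idx - kv_idx)

B, H, S, E = 1, 2, 8, 16
q = torch.randn(B, H, S, E)
k = torch.randn(B, H, S, E)
v = torch.randn(B, H, S, E)

# Eager execution falls back to a reference (non-fused) path; wrapping the call
# in torch.compile selects the fused kernel.
out = flex_attention(q, k, v, score_mod=relative_position_bias, scale=1.0)
print(out.shape)  # torch.Size([1, 2, 8, 16])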