Commit 47803e3

Formatting and allowed unused unpacked vals
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent de91ca2 commit 47803e3

File tree: 1 file changed (+21, -7 lines)

test/python/fx_importer/basic_test.py

Lines changed: 21 additions & 7 deletions
@@ -255,9 +255,17 @@ def body(i, x):
 # CHECK-LABEL: test_flex_attention
 # CHECK: func.func @test_flex_attention
 def test_flex_attention():
-    from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
-    from torch.nn.attention.flex_attention import BlockMask, _LARGE_SPARSE_BLOCK_SIZE, create_block_mask, flex_attention
+    from torch._higher_order_ops.flex_attention import (
+        flex_attention as flex_attention_hop,
+    )
+    from torch.nn.attention.flex_attention import (
+        BlockMask,
+        _LARGE_SPARSE_BLOCK_SIZE,
+        create_block_mask,
+        flex_attention,
+    )
     from torch import Tensor
+
     def _create_empty_block_mask(query: Tensor, key: Tensor):
         # Default block mask for flex attention.
         device = query.device
@@ -281,11 +289,13 @@ def relative_position_bias(
     class FlexAttention(torch.nn.Module):
         def __init__(self, block_mask):
             super().__init__()
-            self.block_mask=block_mask
-
+            self.block_mask = block_mask
+
         def forward(self, q, k, v):
-            output, logsumexp = flex_attention_hop(
-                q, k, v,
+            output, logsumexp, *_ = flex_attention_hop(
+                q,
+                k,
+                v,
                 score_mod=relative_position_bias,
                 block_mask=self.block_mask,
                 scale=1.0,
@@ -299,7 +309,11 @@ def forward(self, q, k, v):
     k = torch.ones(B, Hkv, S, E)
     v = torch.ones(B, Hkv, S, Ev)
     m = fx.export_and_import(
-        FlexAttention(_create_empty_block_mask(q,k)), q,k,v, func_name="test_flex_attention"
+        FlexAttention(_create_empty_block_mask(q, k)),
+        q,
+        k,
+        v,
+        func_name="test_flex_attention",
     )
     print(m)
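The one functional change in this diff is the starred unpacking in forward: flex_attention_hop can return more values than the two the test consumes, and `output, logsumexp, *_ = ...` binds the first two names and discards any extras instead of raising a ValueError. Below is a minimal sketch of the pattern; attention_like_op is a hypothetical stand-in and does not model the real return signature of flex_attention_hop.

# Hypothetical stand-in for an op whose return arity can grow across
# library versions; it does not model the real flex_attention_hop.
def attention_like_op():
    return "output", "logsumexp", "aux_debug_value"

# Plain two-name unpacking breaks once a third value appears:
#   output, logsumexp = attention_like_op()  # ValueError: too many values to unpack
# Starred unpacking keeps the first two names and ignores the rest:
output, logsumexp, *_ = attention_like_op()
assert (output, logsumexp) == ("output", "logsumexp")

The remaining hunks are mechanical formatting: the long imports become parenthesized multi-line imports, and the call sites are split one argument per line, in the style produced by formatters such as black.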