
Commit 16fc70c

Renamed aten.flex_attention -> hop_flex_attention; Added more lit tests
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent acc3ade commit 16fc70c

File tree: 4 files changed (+64 additions, −28 deletions)

include/torch-mlir/Dialect/Torch/IR/TorchOps.td

Lines changed: 5 additions & 5 deletions
@@ -1445,19 +1445,19 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
 //===----------------------------------------------------------------------===//
 // FlexAttention operation

-// NOTE: This op is manually defined because `aten::flex_attention` exists in
+// NOTE: This op is manually defined because flex_attention exists in
 // PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet
 // registered in PyTorch's JIT operator registry. The update_torch_ods.sh script
 // validates against the JIT registry, so it cannot auto-generate this op.
 // Once PyTorch adds flex_attention to the JIT registry, this can be moved to
 // the auto-generated section.
 //===----------------------------------------------------------------------===//
-def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
+def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::flex_attention`";
+  let summary = "Computes the flex_attention operation (1-1 with torch._higher_order_ops.flex_attention)";
   let description = [{
     FlexAttention operation with flexible block-sparse attention patterns.

@@ -1499,10 +1499,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [

   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult HigherOrderFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 6, 3);
     }
-    void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
+    void HigherOrderFlexAttentionOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 6, 3);
     }
  }];
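
For context on the Python API referenced in the NOTE above, the sketch below (not part of this commit; shapes, helper names, and the causal mask are illustrative assumptions) shows roughly how torch.nn.attention.flex_attention is called with a score_mod and a mask_mod. The new torch.hop_flex_attention op models this higher-order call at the Torch dialect level, carrying the two callables as function symbol references.

    # Sketch only: eager flex_attention with a score_mod and a mask_mod.
    # Assumes a recent PyTorch where torch.nn.attention.flex_attention is available;
    # shapes and the causal mask are illustrative.
    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    def score_mod(score, b, h, q_idx, kv_idx):
        # Per-score modification, analogous to @sdpa_score0 in the tests.
        return torch.tanh(score)

    def causal_mask(b, h, q_idx, kv_idx):
        # mask_mod: a (q, kv) pair participates in attention when this returns True.
        return q_idx >= kv_idx

    q = torch.randn(4, 8, 1024, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(4, 8, 1024, 64)
    v = torch.randn(4, 8, 1024, 64)

    block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
    out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)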

python/torch_mlir/extras/fx_importer.py

Lines changed: 2 additions & 2 deletions
@@ -1918,7 +1918,7 @@ def _import_hop_flex_attention(
       - kernel_options: Optional Dict of performance tuning options:
       - return_lse: Boolean for whether to return the log-sum-exp tensor

-    This creates a call to aten.flex_attention with function symbol references for
+    This creates a call to hop_flex_attention with function symbol references for
     score_mod and mask_mod.
     """
     # flex_attention HOP args from PyTorch:

@@ -2035,7 +2035,7 @@ def _import_hop_flex_attention(
         attributes["mask_mod_fn"] = mask_mod_ref

     operation = Operation.create(
-        "torch.aten.flex_attention",
+        "torch.hop_flex_attention",
         results=result_types,
         operands=flat_operands,
         attributes=attributes if attributes else None,
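
To see the renamed op coming out of the importer, a minimal end-to-end sketch is shown below. It assumes the fx.export_and_import helper from torch_mlir (the same entry point the fx_importer tests use) and a PyTorch build whose export path captures the flex_attention higher-order op; the module and tensor shapes are illustrative.

    # Sketch only: export a module that calls flex_attention and import it with
    # the FX importer; the printed MLIR should contain torch.hop_flex_attention.
    import torch
    from torch.nn.attention.flex_attention import flex_attention
    from torch_mlir import fx

    class FlexAttn(torch.nn.Module):
        def forward(self, q, k, v):
            def score_mod(score, b, h, q_idx, kv_idx):
                # Traced into a private func and referenced via score_mod_fn.
                return torch.tanh(score)
            return flex_attention(q, k, v, score_mod=score_mod)

    q = torch.randn(4, 8, 1024, 64)
    k = torch.randn(4, 8, 1024, 64)
    v = torch.randn(4, 8, 1024, 64)
    module = fx.export_and_import(FlexAttn(), q, k, v)
    assert "torch.hop_flex_attention" in str(module)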

test/Dialect/Torch/ops.mlir

Lines changed: 56 additions & 20 deletions
@@ -206,36 +206,72 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to
   return %1 : !torch.vtensor<[3,3],f32>
 }

-// CHECK-LABEL: func.func @torch.aten.flex_attention
-func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+
+//===----------------------------------------------------------------------===//
+// FlexAttention variant tests
+//===----------------------------------------------------------------------===//
+
+func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
+  %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %5 : !torch.vtensor<[],f32>
+}
+
+func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
+  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention
+func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
   %float1.0 = torch.constant.float 1.000000e+00
   %false_0 = torch.constant.bool false
   // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
   // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
   // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
   // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
-  %output, %logsumexp, %maxscore = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
   return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }

-func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32>
-  %float1.000000e-01 = torch.constant.float 1.000000e-01
-  %1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
-  %float1.000000e-02 = torch.constant.float 1.000000e-02
-  %2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
-  %int1_0 = torch.constant.int 1
-  %3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-  %int1_1 = torch.constant.int 1
-  %4 = torch.aten.add.Tensor %3, %1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-  %5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
-  return %5 : !torch.vtensor<[],f32>
+// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask
+func.func @torch.hop_flex_attention_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {score_mod_fn = @sdpa_score0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }

-func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
-  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
-  return %0 : !torch.vtensor<[],i1>
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore
+func.func @torch.hop_flex_attention_noscore (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {mask_mod_fn = @sdpa_mask0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask
+func.func @torch.hop_flex_attention_noscore_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }
test/python/fx_importer/basic_test.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def body(i, x):
    # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
    # CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool false
    # CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false
-   # CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.aten.flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
+   # CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.hop_flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
    # CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool
    # CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32>
    # CHECK: return %[[OUTPUT]]
