Commit 68504aa

Added _sdpa_flash_attention op
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 244f4b6 commit 68504aa

2 files changed: +43 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 37 additions & 0 deletions

@@ -15193,6 +15193,43 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
   }];
 }
 
+def Torch_AtenScaledDotProductFlashAttentionOp : Torch_Op<"aten._scaled_dot_product_flash_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, SymInt, SymInt, Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    Torch_FloatType:$dropout_p,
+    Torch_BoolType:$is_causal,
+    Torch_BoolType:$return_debug_mask,
+    AnyTorchOptionalFloatType:$scale
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp,
+    AnyTorchOptionalTensorType:$cum_seq_q,
+    AnyTorchOptionalTensorType:$cum_seq_k,
+    Torch_IntType:$max_q,
+    Torch_IntType:$max_k,
+    AnyTorchOptionalTensorType:$rng_state,
+    AnyTorchOptionalTensorType:$unused,
+    AnyTorchOptionalTensorType:$debug_attn_mask
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScaledDotProductFlashAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 9);
+    }
+    void AtenScaledDotProductFlashAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 9);
+    }
+  }];
+}
+
 def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
     AllowsTypeRefinement,
     HasValueSemantics,
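
For reference, the ATen overload this op mirrors can be exercised directly in eager PyTorch. The sketch below is illustrative only: it assumes a PyTorch build (roughly 2.1 or later) in which a flash-attention kernel is registered for the device and dtype in use, and the tensor shapes are made up. It unpacks the 9 results matching the op's 7 operands and 9 results (the parseDefaultTorchOp(parser, result, 7, 9) call above).

import torch

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# 7 inputs (query, key, value, dropout_p, is_causal, return_debug_mask, scale)
# and 9 outputs, matching the operand/result counts of the generated op.
# Note: this assumes a flash-attention kernel is available for this
# device/dtype; on other builds the call may raise.
(output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
 rng_state, unused, debug_attn_mask) = (
    torch.ops.aten._scaled_dot_product_flash_attention(
        q, k, v, 0.0, False, False, scale=None
    )
)

print(output.shape)  # same shape as q: (2, 8, 128, 64)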

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 6 additions & 0 deletions

@@ -29,6 +29,7 @@
     "int[]": "AnyTorchListOfTorchIntType",
     "int?": "AnyTorchOptionalIntType",
     "int[]?": "AnyTorchOptionalListOfTorchIntType",
+    "SymInt": "Torch_IntType",
     "bool": "Torch_BoolType",
     "bool[]": "AnyTorchListOfTorchBoolType",
     "bool?": "AnyTorchOptionalBoolType",
@@ -1087,9 +1088,14 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
     )
     emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")
+
+    # Attention ops.
     emit(
         "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
     )
+    emit(
+        "aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, SymInt, SymInt, Tensor, Tensor, Tensor)"
+    )
     emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
     emit(
         "aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
