Skip to content

Commit d633aa6

Browse files
IntType -> SymInt && SymInt == AnyTorchScalarType
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 68504aa commit d633aa6

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15213,8 +15213,8 @@ def Torch_AtenScaledDotProductFlashAttentionOp : Torch_Op<"aten._scaled_dot_prod
1521315213
AnyTorchOptionalTensorType:$logsumexp,
1521415214
AnyTorchOptionalTensorType:$cum_seq_q,
1521515215
AnyTorchOptionalTensorType:$cum_seq_k,
15216-
Torch_IntType:$max_q,
15217-
Torch_IntType:$max_k,
15216+
Torch_SymIntType:$max_q,
15217+
Torch_SymIntType:$max_k,
1521815218
AnyTorchOptionalTensorType:$rng_state,
1521915219
AnyTorchOptionalTensorType:$unused,
1522015220
AnyTorchOptionalTensorType:$debug_attn_mask

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"int[]": "AnyTorchListOfTorchIntType",
3030
"int?": "AnyTorchOptionalIntType",
3131
"int[]?": "AnyTorchOptionalListOfTorchIntType",
32-
"SymInt": "Torch_IntType",
32+
"SymInt": "AnyTorchScalarType",
3333
"bool": "Torch_BoolType",
3434
"bool[]": "AnyTorchListOfTorchBoolType",
3535
"bool?": "AnyTorchOptionalBoolType",

0 commit comments

Comments
 (0)