From b68730d6f5d9c5fe5a369bc8f6bf1a2a275c6214 Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Tue, 7 Oct 2025 19:03:28 +0000
Subject: [PATCH 1/3] set viewShapeInts[dim] = size; in DecomposeComplexOps

---
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  3 +-
 .../test_suite/reshape_like.py                | 40 ++++++++++
 test/Conversion/TorchToLinalg/unflatten.mlir  | 74 +++++++++++++++++++
 3 files changed, 115 insertions(+), 2 deletions(-)
 create mode 100644 test/Conversion/TorchToLinalg/unflatten.mlir

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cc36ceeb953b..e2265c4974b6 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -12855,8 +12855,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
       Value index = rewriter.create<AtenArangeOp>(
          loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
 
-      // Set the current dimension to -1 for broadcasting
-      viewShapeInts[dim] = -1;
+      viewShapeInts[dim] = size;
       viewShapeListElems[dim] = cstMinusOne;
 
      Value viewShapeList = rewriter.create<PrimListConstructOp>(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index d1ddc42b39b1..1441eb1890f7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1281,6 +1281,46 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 12, 3))
 
 
+class UnflattenIntDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [3, 4])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicModule())
+def UnflattenIntDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 12))
+
+
+class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [4, -1])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule())
+def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 20))
+
+
 # ==============================================================================
 
 
diff --git a/test/Conversion/TorchToLinalg/unflatten.mlir b/test/Conversion/TorchToLinalg/unflatten.mlir
new file mode 100644
index 000000000000..01049d4fac29
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/unflatten.mlir
@@ -0,0 +1,74 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int-2 = torch.constant.int -2
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
+  return %1 : !torch.vtensor<[3,2,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
+  return %1 : !torch.vtensor<[?,2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.from_elements
+// CHECK: tensor.reshape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
+  %0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+  return %1 : !torch.vtensor<[?,?,?],f32>
+}

From d83ae9a625d92576322139dc0673ec8db7aa3545 Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Tue, 7 Oct 2025 19:36:42 +0000
Subject: [PATCH 2/3] add a test for aten.as_strided in decompose-complex-ops

---
 test/Dialect/Torch/decompose-complex-ops.mlir | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 8fe502a7d686..24f8bf053b7b 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -934,3 +934,19 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
   %0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
   return %0 : !torch.vtensor<[1,8,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.as_strided$static_shapes
+func.func @torch.aten.as_strided$static_shapes(%arg0: !torch.vtensor<[4,8],f32>) -> !torch.vtensor<[2,3],f32> {
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int1 = torch.constant.int 1
+  %size = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[2,1],si64>
+  // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[1,3],si64>
+  %0 = torch.aten.as_strided %arg0, %size, %stride, %int0 : !torch.vtensor<[4,8],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3],f32>
+  return %0 : !torch.vtensor<[2,3],f32>
+}

From 95d4bbf106c1b6019298bc2378ad407e946a37d7 Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Tue, 7 Oct 2025 23:52:17 +0000
Subject: [PATCH 3/3] remove unflatten tests

---
 .../test_suite/reshape_like.py                | 40 ----------
 test/Conversion/TorchToLinalg/unflatten.mlir  | 74 -------------------
 2 files changed, 114 deletions(-)
 delete mode 100644 test/Conversion/TorchToLinalg/unflatten.mlir

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index 1441eb1890f7..d1ddc42b39b1 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1281,46 +1281,6 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 12, 3))
 
 
-class UnflattenIntDynamicModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args(
-        [
-            None,
-            ([-1, 12], torch.float32, True),
-        ]
-    )
-    def forward(self, inputs):
-        return torch.ops.aten.unflatten(inputs, 1, [3, 4])
-
-
-@register_test_case(module_factory=lambda: UnflattenIntDynamicModule())
-def UnflattenIntDynamicModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 12))
-
-
-class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args(
-        [
-            None,
-            ([-1, 20], torch.float32, True),
-        ]
-    )
-    def forward(self, inputs):
-        return torch.ops.aten.unflatten(inputs, 1, [4, -1])
-
-
-@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule())
-def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 20))
-
-
 # ==============================================================================
 
 
diff --git a/test/Conversion/TorchToLinalg/unflatten.mlir b/test/Conversion/TorchToLinalg/unflatten.mlir
deleted file mode 100644
index 01049d4fac29..000000000000
--- a/test/Conversion/TorchToLinalg/unflatten.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
-// CHECK: torch_c.to_builtin_tensor
-// CHECK: tensor.expand_shape
-// CHECK: torch_c.from_builtin_tensor
-func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
-  %int1 = torch.constant.int 1
-  %int2 = torch.constant.int 2
-  %int3 = torch.constant.int 3
-  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
-  return %1 : !torch.vtensor<[2,2,3,4],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
-// CHECK: torch_c.to_builtin_tensor
-// CHECK: tensor.expand_shape
-// CHECK: torch_c.from_builtin_tensor
-func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
-  %int-2 = torch.constant.int -2
-  %int2 = torch.constant.int 2
-  %int3 = torch.constant.int 3
-  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
-  return %1 : !torch.vtensor<[2,2,3,4],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
-// CHECK: torch_c.to_builtin_tensor
-// CHECK: tensor.expand_shape
-// CHECK: torch_c.from_builtin_tensor
-func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
-  %int1 = torch.constant.int 1
-  %int2 = torch.constant.int 2
-  %int-1 = torch.constant.int -1
-  %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
-  return %1 : !torch.vtensor<[3,2,6],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
-// CHECK: torch_c.to_builtin_tensor
-// CHECK: tensor.expand_shape
-// CHECK: torch_c.from_builtin_tensor
-func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
-  %int1 = torch.constant.int 1
-  %int2 = torch.constant.int 2
-  %int3 = torch.constant.int 3
-  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
-  return %1 : !torch.vtensor<[?,2,3],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
-// CHECK: torch_c.to_builtin_tensor
-// CHECK: tensor.from_elements
-// CHECK: tensor.reshape
-// CHECK: torch_c.from_builtin_tensor
-func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
-  %int1 = torch.constant.int 1
-  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
-  %0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
-  return %1 : !torch.vtensor<[?,?,?],f32>
-}