diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index cc36ceeb953b..e2265c4974b6 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -12855,8 +12855,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { Value index = rewriter.create( loc, arangeType, end, cstNone, cstNone, cstNone, cstNone); - // Set the current dimension to -1 for broadcasting - viewShapeInts[dim] = -1; + viewShapeInts[dim] = size; viewShapeListElems[dim] = cstMinusOne; Value viewShapeList = rewriter.create( diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 8fe502a7d686..24f8bf053b7b 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -934,3 +934,19 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens %0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32> return %0 : !torch.vtensor<[1,8,4,4],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.as_strided$static_shapes +func.func @torch.aten.as_strided$static_shapes(%arg0: !torch.vtensor<[4,8],f32>) -> !torch.vtensor<[2,3],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int1 = torch.constant.int 1 + %size = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[2,1],si64> + // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[1,3],si64> + %0 = torch.aten.as_strided %arg0, %size, %stride, %int0 : !torch.vtensor<[4,8],f32>, !torch.list, !torch.list, !torch.int -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +}