diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6f02a94768d0..2472689daa30 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13668,6 +13668,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
   }];
 }
 
+def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dimension,
+    Torch_IntType:$size,
+    Torch_IntType:$step
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenUnfoldOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 5542e0fc642f..5b4d6382b2bd 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -2573,6 +2573,167 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto self = adaptor.getSelf();
+    RankedTensorType selfType = cast<RankedTensorType>(self.getType());
+
+    int64_t dimension;
+    if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "only support constant int dimension");
+    }
+    int64_t size;
+    if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
+      return rewriter.notifyMatchFailure(op, "only support constant int size");
+    }
+    int64_t step;
+    if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
+      return rewriter.notifyMatchFailure(op, "only support constant int step");
+    }
+
+    if (step <= 0) {
+      return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
+    }
+
+    int64_t selfRank = selfType.getRank();
+
+    // Zero-Rank case
+    if (selfRank == 0) {
+      // Empty tensor
+      if (size == 0) {
+        RankedTensorType resultType =
+            RankedTensorType::get({0}, selfType.getElementType());
+        Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+            loc, resultType.getShape(), resultType.getElementType());
+
+        rewriter.replaceOp(op, emptyTensor);
+        return success();
+      }
+
+      Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
+          loc, RankedTensorType::get({1}, selfType.getElementType()), self,
+          ArrayRef<ReassociationIndices>{});
+      rewriter.replaceOp(op, unsqueezedSelf);
+      return success();
+    }
+
+    auto shape = selfType.getShape();
+
+    if (dimension < 0) {
+      dimension = toPositiveDim(dimension, selfRank);
+    }
+    if (!isValidDim(dimension, selfRank)) {
+      return rewriter.notifyMatchFailure(op, "dimension out of range");
+    }
+
+    Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);
+
+    Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
+    Value sizeCheck = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
+    rewriter.create<cf::AssertOp>(
+        loc, sizeCheck,
+        rewriter.getStringAttr("size must be <= target dimension"));
+
+    /* Calculate output shape of unfold op:
+     * https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
+     * outputShape[dimension] is set to numBlocks, with size appended as an
+     * additional dimension
+     */
+    SmallVector<OpFoldResult> outputShape;
+    for (int64_t i = 0; i < selfRank; i++) {
+      if (i == dimension) {
+        outputShape.push_back(getDynamicOrStaticNumBlocks(
+            rewriter, loc, shape[dimension], dimSize, size, step));
+      } else if (shape[i] == ShapedType::kDynamic) {
+        outputShape.push_back(
+            OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
+      } else {
+        outputShape.push_back(rewriter.getIndexAttr(shape[i]));
+      }
+    }
+    outputShape.push_back(rewriter.getIndexAttr(size));
+
+    // Empty tensor to insert values into
+    Value outputTensor = rewriter.create<tensor::EmptyOp>(
+        loc, outputShape, selfType.getElementType());
+
+    /**
+     * Use reindexing to map output indices to input indices
+     * i.e. In output of rank 3 case:
+     *   (i, j, k) => (i', j') where i' = i * step + k and j' = j
+     *     if dimension == 0
+     *   (i, j, k) => (i', j') where i' = i and j' = j * step + k
+     *     if dimension == 1
+     */
+    MLIRContext *context = rewriter.getContext();
+    SmallVector<AffineExpr> outputExprs;
+    for (int dim = 0; dim < selfRank; ++dim) {
+      if (dim == dimension) {
+        auto idxLast = getAffineDimExpr(selfRank, context);
+        auto idxDimension = getAffineDimExpr(dimension, context);
+
+        AffineExpr dimIdx =
+            idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
+        outputExprs.push_back(dimIdx);
+      } else {
+        outputExprs.push_back(getAffineDimExpr(dim, context));
+      }
+    }
+
+    int64_t outputRank = selfRank + 1;
+    auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
+    auto outputAffineMap =
+        AffineMap::getMultiDimIdentityMap(outputRank, context);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        outputRank, utils::IteratorType::parallel);
+
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outputTensor.getType(), self, outputTensor,
+                ArrayRef<AffineMap>({inputAffineMap, outputAffineMap}),
+                iteratorTypes,
+                [](OpBuilder &b, Location nestedLoc, ValueRange args) {
+                  b.create<linalg::YieldOp>(nestedLoc, args[0]);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
+                                           int64_t shapeDim, Value dimSize,
+                                           int64_t size, int64_t step) const {
+    /**
+     * numBlocks = (shape[dimension] - size) // step + 1
+     */
+    if (shapeDim == ShapedType::kDynamic) {
+      Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
+          loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
+      Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
+          loc, numBlocksSubOp,
+          rewriter.create<arith::ConstantIndexOp>(loc, step));
+      Value numBlocks = rewriter.create<arith::AddIOp>(
+          loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
+      return OpFoldResult(numBlocks);
+    }
+
+    int64_t staticNumBlocks = (shapeDim - size) / step + 1;
+    return rewriter.getIndexAttr(staticNumBlocks); // Use static value
+  }
+};
+} // namespace
+
 namespace {
 class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
 public:
@@ -2641,7 +2802,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
       /*benefit=*/200);
   patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
                                            /*benefit=*/100);
-
+  target.addIllegalOp<AtenUnfoldOp>();
+  patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
   target.addIllegalOp<AtenSqueezeOp>();
   patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
   target.addIllegalOp<AtenSqueezeDimOp>();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 445d4e459013..559726f20659 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -15588,6 +15588,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %str = torch.constant.str \"size must be less than or equal to {}\"\n"
+"    %false = torch.constant.bool false\n"
+"    %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: \"\n"
+"    %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If %3 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
+"        %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
+"        torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If %4 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
+"        torch.prim.If.yield %15 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %arg1 : !torch.int\n"
+"      }\n"
+"      %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+"        %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %15 : !torch.bool\n"
+"      } else {\n"
+"        torch.prim.If.yield %false : !torch.bool\n"
+"      }\n"
+"      torch.prim.If %6 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
+"        %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
+"        torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If %8 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
+"        %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
+"        torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"      %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
+"      %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"      %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.If.yield %12 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "}\n"
 "";
   // clang-format on
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 3d842f44aee0..664bbb2d5d8e 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -278,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
              AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
              AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
              PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
-             AtenPixelShuffleOp, AtenDiagonalOp>(op);
+             AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 2852611fe01b..ec57942e6069 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -915,6 +915,11 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "Unfold_Module_basic",
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_Rank_Zero_Size_Zero_basic",
+    "Unfold_Module_Dynamic_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3127,6 +3132,10 @@
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
     "UnfoldModule_basic",
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_Rank_Zero_Size_Zero_basic",
+    "Unfold_Module_Dynamic_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index d3ec25bcea70..2b7db059bb42 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -5559,7 +5559,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
         return torch.qint8
     return torch.qint32
 
+@check_shape_function([
+    Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
+    Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
+    Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
+    Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
+    Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
+    Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
+])
+def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
+    ndim = len(self)
+
+    # Rank zero tensor
+    if ndim == 0:
+        assert dimension == 0, f"dimension out of range of {ndim}"
+        assert size <= 1, "size must be less than or equal to 1"
+        return [size]
+
+    dim = dimension
+    if dim < 0:
+        dim += ndim
+
+    assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"
+
+    size_dim = self[dim]
+    assert size <= size_dim, f"size must be less than or equal to {size_dim}"
+
+    num_blocks = (size_dim - size) // step + 1
+
+    out = upstream_shape_functions._copy(self)
+    out[dim] = num_blocks
+    out.append(size)
+    return out
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
+)
+def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index ea5c504284eb..72f09cc56d11 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -991,6 +991,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
     emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
     emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
+    emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
     emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index 5524b2a79bf1..ee9cbbf05888 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1648,3 +1648,103 @@ def forward(self, a):
 @register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
 def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(6, 5, 1, 7, 3))
+
+
+class Unfold_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([6, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return x.unfold(0, 2, 2)
+
+
+@register_test_case(module_factory=lambda: Unfold_Module())
+def Unfold_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 4))
+
+
+class Unfold_Module_Negative_Dim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([6, 4, 4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return x.unfold(-1, 2, 1)
+
+
+@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim())
+def Unfold_Module_Rank_4(module, tu: TestUtils):
+    module.forward(tu.rand(6, 4, 4, 4))
+
+
+class Unfold_Module_Rank_Zero(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return x.unfold(0, 1, 1)
+
+
+@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
+def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils):
+    module.forward(tu.rand())
+
+
+class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return x.unfold(0, 0, 1)
+
+
+@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero())
+def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
+    module.forward(tu.rand())
+
+
+class Unfold_Module_Dynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return x.unfold(1, 2, 1)
+
+
+@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
+def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 4, 4, 4))
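
Not part of the patch: a minimal PyTorch-only sketch of the semantics the lowering and shape function above are expected to implement, based on the documented behavior of `torch.Tensor.unfold` and the shapes used in the e2e tests.

```python
import torch

# numBlocks = (dimSize - size) // step + 1, with `size` appended as a new
# trailing dimension -- the same formula used by the shape function and by
# getDynamicOrStaticNumBlocks in the linalg lowering.
x = torch.arange(24, dtype=torch.float32).reshape(6, 4)
y = x.unfold(0, 2, 2)            # dimension=0, size=2, step=2
assert y.shape == (3, 4, 2)      # (6 - 2) // 2 + 1 = 3 blocks of size 2

# Each output index (i, j, k) reads input element (i * step + k, j), which is
# the reindexing expressed by the affine map built in ConvertAtenUnfoldOp.
assert torch.equal(y[1, :, 0], x[2, :])
assert torch.equal(y[1, :, 1], x[3, :])

# Rank-zero edge cases exercised by the new e2e tests.
z = torch.tensor(5.0)
assert z.unfold(0, 1, 1).shape == (1,)
assert z.unfold(0, 0, 1).shape == (0,)
```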