@@ -596,37 +596,51 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What's worse, without knowing the kernel size we cannot even reliably
 // detect such cases, and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+
+template <> struct DimensionTraits<AtenMaxUnpool2dOp> {
+  static constexpr int64_t Dim = 2;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <> struct DimensionTraits<AtenMaxUnpool3dOp> {
+  static constexpr int64_t Dim = 3;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <typename OpTy>
+class ConvertAtenMaxUnpoolOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  LogicalResult createUnpoolOp(OpTy &op, typename OpTy::Adaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
     Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    size_t spatial = selfType.getRank() - 2;
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(spatial))
+    if (inputSize != indicesType.getShape().take_back(Dim))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -637,7 +651,7 @@ class ConvertAtenMaxUnpool3dOp final
         return rewriter.notifyMatchFailure(op,
                                            "only support constant int output");
 
-      if (inferredOutSize != ArrayRef(output).take_back(spatial))
+      if (inferredOutSize != ArrayRef(output).take_back(Dim))
         return rewriter.notifyMatchFailure(op, "Invalid output size");
     }
     SmallVector<int64_t> stride;
@@ -653,12 +667,12 @@ class ConvertAtenMaxUnpool3dOp final
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != spatial || padding.size() != spatial)
+    if (stride.size() != Dim || padding.size() != Dim)
       return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+          op, "stride and padding must be of size Dim");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - spatial;
+    int64_t NC = outRank - Dim;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -695,7 +709,7 @@ class ConvertAtenMaxUnpool3dOp final
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(spatial));
+        llvm::to_vector(resType.getShape().drop_back(Dim));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -708,7 +722,7 @@ class ConvertAtenMaxUnpool3dOp final
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(Dim))) {
       high.emplace_back(outSize - inpSize);
     }
 
@@ -827,6 +841,13 @@ class ConvertAtenMaxUnpool3dOp final
     rewriter.replaceOp(op, result);
     return success();
   }
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return createUnpoolOp(op, adaptor, rewriter);
+  }
 };
 } // namespace
 
@@ -1527,8 +1548,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
+  target.addIllegalOp<AtenMaxUnpool2dOp>();
   target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool2dOp>>(typeConverter,
+                                                          context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool3dOp>>(typeConverter,
+                                                          context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
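
The kernel-size ambiguity called out in the comment at the top of the first hunk can be reproduced numerically. Under the usual floor-mode pooling size relation, pooled_size = (out_size + 2 * pad - kernel_size) / stride + 1 with integer division, several kernel sizes can map the same output size to the same pooled size, so the unpool lowering cannot recover the kernel and instead relies on static shapes plus the user-supplied output size. Below is a minimal standalone sketch, not part of the patch; feasibleKernels is a hypothetical helper used only for illustration.

// Illustrative only: enumerate kernel sizes compatible with a given
// (pooledSize, outSize, stride, pad) tuple under floor-mode pooling:
//   pooledSize == (outSize + 2 * pad - kernel) / stride + 1
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> feasibleKernels(int64_t pooledSize, int64_t outSize,
                                            int64_t stride, int64_t pad) {
  std::vector<int64_t> kernels;
  for (int64_t k = 1; k <= outSize + 2 * pad; ++k)
    if ((outSize + 2 * pad - k) / stride + 1 == pooledSize)
      kernels.push_back(k);
  return kernels;
}

int main() {
  // The example from the comment: input_size=2, output_size=5, stride=2
  // admits both kernel_size=2 and kernel_size=3.
  for (int64_t k : feasibleKernels(2, 5, 2, 0))
    std::cout << k << "\n"; // prints 2 and 3
  return 0;
}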
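
The expectedInputShape step (ceilDiv(resSize, str) + pad * 2 per spatial dimension) covers the case where the pooled tensor is smaller than the output size implies, e.g. pooling_input_size=5, kernel_size=2, stride=2 yields output_size=2, so unpooling back to 5 can address a pooled index past the end of a size-2 tensor. A rough arithmetic sketch of that bound follows; it assumes the lowering reads the pooled element at roughly outIdx / stride, and the local ceilDiv merely mirrors the helper used in the pass.

// Illustrative only: worst-case pooled index touched while unpooling to
// outSize with the given stride, versus the padded input size the pass
// requests via ceilDiv(outSize, stride) + 2 * pad.
#include <cstdint>
#include <iostream>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  int64_t outSize = 5, stride = 2, pooledSize = 2, pad = 0;
  int64_t worstRead = (outSize - 1) / stride;              // 2
  int64_t paddedSize = ceilDiv(outSize, stride) + 2 * pad; // 3
  std::cout << (worstRead < pooledSize) << "\n"; // 0: would read out of bounds
  std::cout << (worstRead < paddedSize) << "\n"; // 1: in bounds after padding
  return 0;
}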