@@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
 
   // Max pooling
   if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
-          AtenMaxPool2dWithIndicesOp>(op)) {
+          AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+// AtenMaxPool1dWithIndicesOp
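+// Lowers aten.max_pool1d_with_indices to a variadic stablehlo.reduce_window
+// over the (input, index) pair, computing the max and its argmax in one pass.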
+template <>
+LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
+    AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto inputElemTy = inputTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto inputRank = inputTy.getRank();
+
+  auto outValTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
+  auto outIdxTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
+
+  if (inputRank <= 1) {
+    return op.emitError(
+        "max_pooling1d only supports inputs with rank higher than 1");
+  }
+
+  SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
+  bool ceilMode = false;
+
+  if (!(matchPattern(op.getKernelSize(),
+                     m_TorchListOfConstantInts(kernelSize)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const int kernel size unsupported!");
+  }
+  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
+    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+  }
+  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int padding unsupported!");
+  }
+  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int dilation unsupported!");
+  }
+  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const bool ceil_mode unsupported!");
+  }
+
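+  // Broadcast the 1-d pooling parameters to the full input rank; only the
+  // last (pooled) dimension gets a non-trivial window size, stride, dilation,
+  // and padding.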
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+
+  std::copy(stride.begin(), stride.end(),
+            stablehloStride.begin() + inputRank - 1);
+  std::copy(dilation.begin(), dilation.end(),
+            stablehloDilation.begin() + inputRank - 1);
+  std::copy(kernelSize.begin(), kernelSize.end(),
+            stablehloKernelSize.begin() + inputRank - 1);
+  stablehloPadding[stablehloPadding.size() - 1] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 2] = padding[0];
+
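+  // The init value is the identity of max: -inf for floats and the minimum
+  // value for integers (see createInitialValueForAtenPoolingOp above).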
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      stablehloPadding);
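+  // baseDilations is deliberately left null; reduce_window treats a missing
+  // base dilation as all ones, i.e. the input itself is not dilated.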
+  DenseI64ArrayAttr baseDilations;
+
+  auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
+  if (failed(inputShapeInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  }
+  auto inputShapeVec = *inputShapeInfo;
+  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+      op->getLoc(), inputShapeVec);
+
+  // Unlike the 2-d case, no reshape is needed here: the iota runs along the
+  // pooled (last) dimension, i.e. dim = inputRank - 1.
+  auto indexTensor =
+      rewriter
+          .create<stablehlo::DynamicIotaOp>(
+              op->getLoc(),
+              RankedTensorType::get(inputShape, rewriter.getI64Type()),
+              inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
+          .getResult();
+  Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
+
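+  // Reduce the (value, index) pair with one variadic reduce_window so the
+  // pooled maxima and their argmax indices are produced together.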
+  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
+      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
+      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
+      windowDimensions, windowStrides, baseDilations, windowDilations, pad);
+
+  // Build the reduction body; each step folds a (value, index) pair into the
+  // running (max, argmax) accumulator.
+  Block &block = reduceWindowOp.getBody().emplaceBlock();
+  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
+  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
+  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  auto *firstValArg = block.args_begin();
+  auto *firstIdxArg = std::next(firstValArg);
+  auto *secondValArg = std::next(firstIdxArg);
+  auto *secondIdxArg = std::next(secondValArg);
+
+  stablehlo::ComparisonTypeAttr compareTypeAttr;
+  if (isa<mlir::FloatType>(inputTy.getElementType())) {
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
+  } else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
+  }
+
+  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+
+    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareGeDirectionAttr, compareTypeAttr);
+    Value retValResult = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
+
+    // Return the smaller index if the compared values are equal.
+    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareEqDirectionAttr, compareTypeAttr);
+    Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
+                                                     *secondIdxArg);
+    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
+    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
+
+    rewriter.create<stablehlo::ReturnOp>(
+        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
+  }
+
+  rewriter.replaceOp(op, reduceWindowOp.getResults());
+  return success();
+}
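+// For example, pooling a [N, C, L] input with kernel_size=3 yields window
+// dimensions [1, 1, 3]: values and indices are both reduced along L.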
+
 // AtenMaxPool2dWithIndicesOp
 template <>
 LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
 #define INSERT_ATEN_POOLING_PATTERN(AtenOp)                                    \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
   INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
   INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
 #undef INSERT_ATEN_POOLING_PATTERN