@@ -2124,6 +2124,152 @@ class ConvertAtenSliceScatterOp
21242124};
21252125} // namespace
21262126
2127+ namespace {
2128+ static Value
2129+ createLinalgPayloadForReduceScatterOp (OpBuilder &b, Location loc, Operation *op,
2130+ ValueRange payloadArgs, Value self,
2131+ int64_t dim, std::string reduceMode,
2132+ RankedTensorType resultType) {
2133+ Type resultElementType = resultType.getElementType ();
2134+ Value index = castIntToIndex (b, loc, /* abstractindexElement*/ payloadArgs[0 ]);
2135+ // Get the element at self[index].
2136+ auto selfElement = b.create <tensor::ExtractOp>(loc, self, ValueRange{index});
2137+ // Get the element at src[index].
2138+ Value srcElement = convertScalarToDtype (
2139+ b, loc, /* abstractSrcElement*/ payloadArgs[1 ], resultElementType);
2140+ Value accumulatorElement;
2141+ // Reduce the elements based on different mode.
2142+ // TODO add more reduce mode here.
2143+ if (reduceMode == " sum" ) {
2144+ // accumulatorElement = selfElement + srcElement;
2145+ if (isa<mlir::FloatType>(resultElementType))
2146+ accumulatorElement =
2147+ b.create <arith::AddFOp>(loc, selfElement, srcElement);
2148+ else if (isa<mlir::IntegerType>(resultElementType))
2149+ accumulatorElement =
2150+ b.create <arith::AddIOp>(loc, selfElement, srcElement);
2151+ } else {
2152+ op->emitError (" only sum lowering in createLinalgPayloadForReduceScatterOp" );
2153+ return nullptr ;
2154+ }
2155+ // Prepare source, indices, scatter_dims for scatter op.
2156+ Value accumulatorElementTensor =
2157+ b.create <tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
2158+ Value indexTensor = b.create <tensor::FromElementsOp>(loc, ValueRange{index});
2159+ ArrayRef<int64_t > dimArray{dim};
2160+ auto scatter = b.create <tensor::ScatterOp>(
2161+ loc, resultType, /* source*/ accumulatorElementTensor,
2162+ /* dest*/ self, /* indices*/ indexTensor,
2163+ /* scatter_dims*/ b.getDenseI64ArrayAttr (dimArray),
2164+ /* unique*/ b.getUnitAttr ());
2165+ return scatter;
2166+ }
2167+
2168+ class ConvertAtenScatterReduceOp
2169+ : public OpConversionPattern<AtenScatterReduceOp> {
2170+ public:
2171+ using OpConversionPattern::OpConversionPattern;
2172+ LogicalResult
2173+ matchAndRewrite (AtenScatterReduceOp op, OpAdaptor adaptor,
2174+ ConversionPatternRewriter &rewriter) const override {
2175+ Location loc = op.getLoc ();
2176+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
2177+ return failure ();
2178+
2179+ // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
2180+ std::string reduceMode;
2181+ if (!matchPattern (op.getReduce (), m_TorchConstantStr (reduceMode)))
2182+ return rewriter.notifyMatchFailure (
2183+ op, " only support constant str reduce mode" );
2184+ // TODO: add "prod", "mean", "amax", "amin" mode.
2185+ if (reduceMode != " sum" )
2186+ return rewriter.notifyMatchFailure (
2187+ op, " Only support sum reduce mode for now" );
2188+
2189+ // Get dim.
2190+ int64_t dim;
2191+ if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim)))
2192+ return rewriter.notifyMatchFailure (op, " dim must be constant" );
2193+
2194+ // Prepare input.
2195+ auto self = adaptor.getSelf ();
2196+ auto selfType = cast<RankedTensorType>(self.getType ());
2197+ int64_t selfRank = selfType.getRank ();
2198+ // TODO: add more input rank support.
2199+ if (selfRank > 1 || dim > selfRank - 1 )
2200+ return rewriter.notifyMatchFailure (op,
2201+ " Only support self rank==1 for now" );
2202+
2203+ // Prepare index.
2204+ Value index = adaptor.getIndex ();
2205+ auto indexType = cast<RankedTensorType>(index.getType ());
2206+ int64_t indexRank = indexType.getRank ();
2207+ SmallVector<int64_t > indexAbstractSizes (indexRank, kUnknownSize );
2208+ auto abstractIndexType =
2209+ RankedTensorType::get (makeShapeLLVMCompatible (indexAbstractSizes),
2210+ indexType.getElementType ());
2211+ Value abstractindex =
2212+ rewriter.create <tensor::CastOp>(loc, abstractIndexType, index);
2213+
2214+ // Prepare src.
2215+ Value src = adaptor.getSrc ();
2216+ auto srcType = cast<RankedTensorType>(src.getType ());
2217+ int64_t srcRank = srcType.getRank ();
2218+ SmallVector<int64_t > srcAbstractSizes (srcRank, kUnknownSize );
2219+ auto abstractSrcType = RankedTensorType::get (
2220+ makeShapeLLVMCompatible (srcAbstractSizes), srcType.getElementType ());
2221+ Value abstractSrc =
2222+ rewriter.create <tensor::CastOp>(loc, abstractSrcType, src);
2223+
2224+ // Prepare result type.
2225+ const TypeConverter *typeConverter = getTypeConverter ();
2226+ RankedTensorType resultType = cast<RankedTensorType>(
2227+ typeConverter->convertType (op->getResult (0 ).getType ()));
2228+
2229+ // Prepare indexingMaps and iteratorTypes.
2230+ SmallVector<AffineMap, 3 > indexingMaps = {
2231+ rewriter.getMultiDimIdentityMap (indexRank),
2232+ rewriter.getMultiDimIdentityMap (srcRank),
2233+ rewriter.getMultiDimIdentityMap (selfRank),
2234+ };
2235+ // Prepare iteratorTypes.
2236+ SmallVector<utils::IteratorType> iteratorTypes{
2237+ 1 , utils::IteratorType::parallel};
2238+
2239+ // Implementation of scatter and reduce in linalg.generic.
2240+ bool err = false ;
2241+ Value result =
2242+ rewriter
2243+ .create <linalg::GenericOp>(
2244+ loc, /* resultTensorTypes=*/ self.getType (),
2245+ /* inputs=*/ ValueRange ({abstractindex, abstractSrc}),
2246+ /* outputs=*/ self, indexingMaps, iteratorTypes,
2247+ [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
2248+ // Scatter result after reduce accumulation.
2249+ Value scatter = createLinalgPayloadForReduceScatterOp (
2250+ builder, loc, op, payloadArgs, self, dim, reduceMode,
2251+ resultType);
2252+ // Return selfElements to itself, nothing change but a
2253+ // placeholder.
2254+ if (scatter) {
2255+ builder.create <linalg::YieldOp>(
2256+ loc, /* selfElement*/ payloadArgs[2 ]);
2257+ }
2258+ err = !scatter;
2259+ })
2260+ .getResult (0 );
2261+
2262+ if (err)
2263+ return rewriter.notifyMatchFailure (
2264+ op,
2265+ " failed to create linalg.generic operation for reduce scatter op" );
2266+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
2267+
2268+ return success ();
2269+ }
2270+ };
2271+ } // namespace
2272+
21272273namespace {
21282274class ConvertAtenViewAsComplexOp
21292275 : public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2664,6 +2810,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
26642810 patterns.add <ConvertAtenCopyOp>(typeConverter, context);
26652811 target.addIllegalOp <AtenSliceScatterOp>();
26662812 patterns.add <ConvertAtenSliceScatterOp>(typeConverter, context);
2813+ target.addIllegalOp <AtenScatterReduceOp>();
2814+ patterns.add <ConvertAtenScatterReduceOp>(typeConverter, context);
26672815 target.addIllegalOp <AtenViewAsComplexOp>();
26682816 patterns.add <ConvertAtenViewAsComplexOp>(typeConverter, context);
26692817 target.addIllegalOp <AtenViewAsRealOp>();
0 commit comments