@@ -2124,152 +2124,6 @@ class ConvertAtenSliceScatterOp
21242124};
21252125} // namespace
21262126
2127- namespace {
2128- static Value
2129- createLinalgPayloadForReduceScatterOp (OpBuilder &b, Location loc, Operation *op,
2130- ValueRange payloadArgs, Value self,
2131- int64_t dim, std::string reduceMode,
2132- RankedTensorType resultType) {
2133- Type resultElementType = resultType.getElementType ();
2134- Value index = castIntToIndex (b, loc, /* abstractindexElement*/ payloadArgs[0 ]);
2135- // Get the element at self[index].
2136- auto selfElement = b.create <tensor::ExtractOp>(loc, self, ValueRange{index});
2137- // Get the element at src[index].
2138- Value srcElement = convertScalarToDtype (
2139- b, loc, /* abstractSrcElement*/ payloadArgs[1 ], resultElementType);
2140- Value accumulatorElement;
2141- // Reduce the elements based on different mode.
2142- // TODO add more reduce mode here.
2143- if (reduceMode == " sum" ) {
2144- // accumulatorElement = selfElement + srcElement;
2145- if (isa<mlir::FloatType>(resultElementType))
2146- accumulatorElement =
2147- b.create <arith::AddFOp>(loc, selfElement, srcElement);
2148- else if (isa<mlir::IntegerType>(resultElementType))
2149- accumulatorElement =
2150- b.create <arith::AddIOp>(loc, selfElement, srcElement);
2151- } else {
2152- op->emitError (" only sum lowering in createLinalgPayloadForReduceScatterOp" );
2153- return nullptr ;
2154- }
2155- // Prepare source, indices, scatter_dims for scatter op.
2156- Value accumulatorElementTensor =
2157- b.create <tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
2158- Value indexTensor = b.create <tensor::FromElementsOp>(loc, ValueRange{index});
2159- ArrayRef<int64_t > dimArray{dim};
2160- auto scatter = b.create <tensor::ScatterOp>(
2161- loc, resultType, /* source*/ accumulatorElementTensor,
2162- /* dest*/ self, /* indices*/ indexTensor,
2163- /* scatter_dims*/ b.getDenseI64ArrayAttr (dimArray),
2164- /* unique*/ b.getUnitAttr ());
2165- return scatter;
2166- }
2167-
2168- class ConvertAtenScatterReduceOp
2169- : public OpConversionPattern<AtenScatterReduceOp> {
2170- public:
2171- using OpConversionPattern::OpConversionPattern;
2172- LogicalResult
2173- matchAndRewrite (AtenScatterReduceOp op, OpAdaptor adaptor,
2174- ConversionPatternRewriter &rewriter) const override {
2175- Location loc = op.getLoc ();
2176- if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
2177- return failure ();
2178-
2179- // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
2180- std::string reduceMode;
2181- if (!matchPattern (op.getReduce (), m_TorchConstantStr (reduceMode)))
2182- return rewriter.notifyMatchFailure (
2183- op, " only support constant str reduce mode" );
2184- // TODO: add "prod", "mean", "amax", "amin" mode.
2185- if (reduceMode != " sum" )
2186- return rewriter.notifyMatchFailure (
2187- op, " Only support sum reduce mode for now" );
2188-
2189- // Get dim.
2190- int64_t dim;
2191- if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim)))
2192- return rewriter.notifyMatchFailure (op, " dim must be constant" );
2193-
2194- // Prepare input.
2195- auto self = adaptor.getSelf ();
2196- auto selfType = cast<RankedTensorType>(self.getType ());
2197- int64_t selfRank = selfType.getRank ();
2198- // TODO: add more input rank support.
2199- if (selfRank > 1 || dim > selfRank - 1 )
2200- return rewriter.notifyMatchFailure (op,
2201- " Only support self rank==1 for now" );
2202-
2203- // Prepare index.
2204- Value index = adaptor.getIndex ();
2205- auto indexType = cast<RankedTensorType>(index.getType ());
2206- int64_t indexRank = indexType.getRank ();
2207- SmallVector<int64_t > indexAbstractSizes (indexRank, kUnknownSize );
2208- auto abstractIndexType =
2209- RankedTensorType::get (makeShapeLLVMCompatible (indexAbstractSizes),
2210- indexType.getElementType ());
2211- Value abstractindex =
2212- rewriter.create <tensor::CastOp>(loc, abstractIndexType, index);
2213-
2214- // Prepare src.
2215- Value src = adaptor.getSrc ();
2216- auto srcType = cast<RankedTensorType>(src.getType ());
2217- int64_t srcRank = srcType.getRank ();
2218- SmallVector<int64_t > srcAbstractSizes (srcRank, kUnknownSize );
2219- auto abstractSrcType = RankedTensorType::get (
2220- makeShapeLLVMCompatible (srcAbstractSizes), srcType.getElementType ());
2221- Value abstractSrc =
2222- rewriter.create <tensor::CastOp>(loc, abstractSrcType, src);
2223-
2224- // Prepare result type.
2225- const TypeConverter *typeConverter = getTypeConverter ();
2226- RankedTensorType resultType = cast<RankedTensorType>(
2227- typeConverter->convertType (op->getResult (0 ).getType ()));
2228-
2229- // Prepare indexingMaps and iteratorTypes.
2230- SmallVector<AffineMap, 3 > indexingMaps = {
2231- rewriter.getMultiDimIdentityMap (indexRank),
2232- rewriter.getMultiDimIdentityMap (srcRank),
2233- rewriter.getMultiDimIdentityMap (selfRank),
2234- };
2235- // Prepare iteratorTypes.
2236- SmallVector<utils::IteratorType> iteratorTypes{
2237- 1 , utils::IteratorType::parallel};
2238-
2239- // Implementation of scatter and reduce in linalg.generic.
2240- bool err = false ;
2241- Value result =
2242- rewriter
2243- .create <linalg::GenericOp>(
2244- loc, /* resultTensorTypes=*/ self.getType (),
2245- /* inputs=*/ ValueRange ({abstractindex, abstractSrc}),
2246- /* outputs=*/ self, indexingMaps, iteratorTypes,
2247- [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
2248- // Scatter result after reduce accumulation.
2249- Value scatter = createLinalgPayloadForReduceScatterOp (
2250- builder, loc, op, payloadArgs, self, dim, reduceMode,
2251- resultType);
2252- // Return selfElements to itself, nothing change but a
2253- // placeholder.
2254- if (scatter) {
2255- builder.create <linalg::YieldOp>(
2256- loc, /* selfElement*/ payloadArgs[2 ]);
2257- }
2258- err = !scatter;
2259- })
2260- .getResult (0 );
2261-
2262- if (err)
2263- return rewriter.notifyMatchFailure (
2264- op,
2265- " failed to create linalg.generic operation for reduce scatter op" );
2266- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
2267-
2268- return success ();
2269- }
2270- };
2271- } // namespace
2272-
22732127namespace {
22742128class ConvertAtenViewAsComplexOp
22752129 : public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2810,8 +2664,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
28102664 patterns.add <ConvertAtenCopyOp>(typeConverter, context);
28112665 target.addIllegalOp <AtenSliceScatterOp>();
28122666 patterns.add <ConvertAtenSliceScatterOp>(typeConverter, context);
2813- target.addIllegalOp <AtenScatterReduceOp>();
2814- patterns.add <ConvertAtenScatterReduceOp>(typeConverter, context);
28152667 target.addIllegalOp <AtenViewAsComplexOp>();
28162668 patterns.add <ConvertAtenViewAsComplexOp>(typeConverter, context);
28172669 target.addIllegalOp <AtenViewAsRealOp>();
0 commit comments