
Commit 07e2d59

[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp for onnx.ScatterElements

This enables the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext, and removes the incorrect AtenScatterReduce-to-linalg pass.
1 parent e632755 commit 07e2d59
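
At the Torch-dialect level, the change swaps the 5-operand torch.aten.scatter.reduce for torch.aten.scatter_reduce.two, whose extra !torch.bool operand is PyTorch's include_self flag; the lowering pins it to true so the original elements of data take part in the reduction, as onnx.ScatterElements requires. A minimal libtorch sketch of the target semantics (values are illustrative and not from this commit; assumes the Tensor::scatter_reduce overload from upstream PyTorch):

    // scatter_reduce "sum" with include_self=true: the destination element is
    // included in the reduction, as in onnx.ScatterElements reduction="add".
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto data = torch::tensor({{1.f, 2.f, 3.f, 4.f, 5.f}});
      auto indices = torch::tensor({{1, 1}}, torch::kLong);
      auto updates = torch::tensor({{10.f, 20.f}});
      // Both updates land on column 1: result[0][1] = 2 + 10 + 20 = 32.
      auto out = data.scatter_reduce(/*dim=*/1, indices, updates,
                                     /*reduce=*/"sum", /*include_self=*/true);
      std::cout << out << "\n"; // prints [[1, 32, 3, 4, 5]]
    }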

File tree

3 files changed (+8, -153 lines)


lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 4 additions & 3 deletions
@@ -645,10 +645,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
           Value cstStrReduction =
               rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
-
-          rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
+          Value cstTrue =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+          rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
               binder.op, resultType, data, constAxis, indices, updates,
-              cstStrReduction);
+              cstStrReduction, cstTrue);
           return success();
         });
     patterns.onOp(
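
A note on the hard-coded cstTrue: in PyTorch, the old target aten::scatter.reduce (Tensor.scatter_ with a reduce string) only accepts "add" and "multiply" and has no include_self flag, yet this binder emits "sum"/"prod" (see the updated tests below), so the previous lowering produced an op inconsistent with its PyTorch semantics. aten::scatter_reduce.two accepts "sum"/"prod"/"mean"/"amax"/"amin" plus include_self, and include_self = true matches onnx.ScatterElements, where the reduction is applied on top of the original data values.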

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 0 additions & 148 deletions
@@ -2124,152 +2124,6 @@ class ConvertAtenSliceScatterOp
 };
 } // namespace
 
-namespace {
-static Value
-createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc, Operation *op,
-                                      ValueRange payloadArgs, Value self,
-                                      int64_t dim, std::string reduceMode,
-                                      RankedTensorType resultType) {
-  Type resultElementType = resultType.getElementType();
-  Value index = castIntToIndex(b, loc, /*abstractindexElement*/ payloadArgs[0]);
-  // Get the element at self[index].
-  auto selfElement = b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
-  // Get the element at src[index].
-  Value srcElement = convertScalarToDtype(
-      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);
-  Value accumulatorElement;
-  // Reduce the elements based on different mode.
-  // TODO add more reduce mode here.
-  if (reduceMode == "sum") {
-    // accumulatorElement = selfElement + srcElement;
-    if (isa<mlir::FloatType>(resultElementType))
-      accumulatorElement =
-          b.create<arith::AddFOp>(loc, selfElement, srcElement);
-    else if (isa<mlir::IntegerType>(resultElementType))
-      accumulatorElement =
-          b.create<arith::AddIOp>(loc, selfElement, srcElement);
-  } else {
-    op->emitError("only sum lowering in createLinalgPayloadForReduceScatterOp");
-    return nullptr;
-  }
-  // Prepare source, indices, scatter_dims for scatter op.
-  Value accumulatorElementTensor =
-      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
-  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
-  ArrayRef<int64_t> dimArray{dim};
-  auto scatter = b.create<tensor::ScatterOp>(
-      loc, resultType, /*source*/ accumulatorElementTensor,
-      /*dest*/ self, /*indices*/ indexTensor,
-      /*scatter_dims*/ b.getDenseI64ArrayAttr(dimArray),
-      /*unique*/ b.getUnitAttr());
-  return scatter;
-}
-
-class ConvertAtenScatterReduceOp
-    : public OpConversionPattern<AtenScatterReduceOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
-    std::string reduceMode;
-    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
-      return rewriter.notifyMatchFailure(
-          op, "only support constant str reduce mode");
-    // TODO: add "prod", "mean", "amax", "amin" mode.
-    if (reduceMode != "sum")
-      return rewriter.notifyMatchFailure(
-          op, "Only support sum reduce mode for now");
-
-    // Get dim.
-    int64_t dim;
-    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
-      return rewriter.notifyMatchFailure(op, "dim must be constant");
-
-    // Prepare input.
-    auto self = adaptor.getSelf();
-    auto selfType = cast<RankedTensorType>(self.getType());
-    int64_t selfRank = selfType.getRank();
-    // TODO: add more input rank support.
-    if (selfRank > 1 || dim > selfRank - 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "Only support self rank==1 for now");
-
-    // Prepare index.
-    Value index = adaptor.getIndex();
-    auto indexType = cast<RankedTensorType>(index.getType());
-    int64_t indexRank = indexType.getRank();
-    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
-    auto abstractIndexType =
-        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
-                              indexType.getElementType());
-    Value abstractindex =
-        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);
-
-    // Prepare src.
-    Value src = adaptor.getSrc();
-    auto srcType = cast<RankedTensorType>(src.getType());
-    int64_t srcRank = srcType.getRank();
-    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
-    auto abstractSrcType = RankedTensorType::get(
-        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
-    Value abstractSrc =
-        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
-
-    // Prepare result type.
-    const TypeConverter *typeConverter = getTypeConverter();
-    RankedTensorType resultType = cast<RankedTensorType>(
-        typeConverter->convertType(op->getResult(0).getType()));
-
-    // Prepare indexingMaps and iteratorTypes.
-    SmallVector<AffineMap, 3> indexingMaps = {
-        rewriter.getMultiDimIdentityMap(indexRank),
-        rewriter.getMultiDimIdentityMap(srcRank),
-        rewriter.getMultiDimIdentityMap(selfRank),
-    };
-    // Prepare iteratorTypes.
-    SmallVector<utils::IteratorType> iteratorTypes{
-        1, utils::IteratorType::parallel};
-
-    // Implementation of scatter and reduce in linalg.generic.
-    bool err = false;
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, /*resultTensorTypes=*/self.getType(),
-                /*inputs=*/ValueRange({abstractindex, abstractSrc}),
-                /*outputs=*/self, indexingMaps, iteratorTypes,
-                [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
-                  // Scatter result after reduce accumulation.
-                  Value scatter = createLinalgPayloadForReduceScatterOp(
-                      builder, loc, op, payloadArgs, self, dim, reduceMode,
-                      resultType);
-                  // Return selfElements to itself, nothing change but a
-                  // placeholder.
-                  if (scatter) {
-                    builder.create<linalg::YieldOp>(
-                        loc, /*selfElement*/ payloadArgs[2]);
-                  }
-                  err = !scatter;
-                })
-            .getResult(0);
-
-    if (err)
-      return rewriter.notifyMatchFailure(
-          op,
-          "failed to create linalg.generic operation for reduce scatter op");
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
-
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenViewAsComplexOp
     : public OpConversionPattern<AtenViewAsComplexOp>
@@ -2810,8 +2664,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenCopyOp>(typeConverter, context);
   target.addIllegalOp<AtenSliceScatterOp>();
   patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
-  target.addIllegalOp<AtenScatterReduceOp>();
-  patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsComplexOp>();
   patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsRealOp>();
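
Beyond only supporting rank-1 "sum", the deleted pattern was wrong in a more basic way: its linalg.generic payload materialized a tensor.scatter for each element but then yielded the untouched self element (payloadArgs[2], "nothing change but a placeholder"), so under tensor value semantics every scatter result was dead and the generic simply recomputed self; the trailing tensor.cast then returned the input unmodified. With the binder now emitting AtenScatterReduceTwoOp, the op can instead flow through the existing tm_tensor/linalg_ext scatter lowering mentioned in the commit message.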

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 4 additions & 2 deletions
@@ -269,7 +269,8 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
   // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
   // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
   // CHECK: %[[STR:.*]] = torch.constant.str "sum"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR:.*]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }

@@ -302,7 +303,8 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],
   // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
   // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
   // CHECK: %[[STR:.*]] = torch.constant.str "prod"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR:.*]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }
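
The updated CHECK lines bind %[[TRUE]] for the new torch.constant.bool operand, match torch.aten.scatter_reduce.two in place of torch.aten.scatter.reduce, and replace the old literal %str reference with the %[[STR:.*]] capture. Together the two tests also pin down the reduction-string mapping in the binder: ONNX reduction = "add" lowers to "sum", and "mul" lowers to "prod". To run them, invoke the repository's lit suite on this file (for torch-mlir, typically the check-torch-mlir build target, or llvm-lit pointed at test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir).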
