Skip to content

Commit a6f3181

Browse files
authored
Get dim from the input of reshape and transpose (#3306)
Signed-off-by: Tung D. Le <[email protected]>
1 parent 263a7b2 commit a6f3181

File tree

3 files changed

+137
-0
lines changed

3 files changed

+137
-0
lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,106 @@ struct RemoveDimZeroInputInConcatPattern
16651665
}
16661666
};
16671667

1668+
// =============================================================================
1669+
// Rewrite pattern onnx.dim
1670+
// =============================================================================
1671+
1672+
/// The pattern is to replace DimOp by a dimension value in the shape of
1673+
/// ReshapeOp.
1674+
///
1675+
/// We would like to replace:
1676+
/// ```
1677+
/// %shape = onnx.Concat(%d0, %d1, %d2)
1678+
/// %reshape = onnx.Reshape(%X, %shape) {allowzero = 0}
1679+
/// %dim = onnx.Dim(%reshape) {axis = 1}
1680+
/// ```
1681+
/// with
1682+
/// ```
1683+
/// %dim = %d1
1684+
/// ```
1685+
/// We only consider `allowzero=0` in this pattern.
1686+
1687+
/// Replace `onnx.Dim(onnx.Reshape(%X, %shape)) {axis = k}` by the k-th input
/// of the Concat that produces `%shape`. Only handles `allowzero = 0`.
struct DimOpFromReshapeInputPattern : public OpRewritePattern<ONNXDimOp> {
  using OpRewritePattern<ONNXDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXDimOp dimOp, PatternRewriter &rewriter) const final {
    Value data = dimOp.getData();

    // Do not handle unranked tensors.
    if (!isRankedShapedType(data.getType()))
      return rewriter.notifyMatchFailure(dimOp, "Input is unranked");

    // Normalize axis and make sure it is in range before indexing with it.
    int64_t rank = getRank(data.getType());
    int64_t dimAxis = dimOp.getAxis();
    if (dimAxis < 0)
      dimAxis += rank;
    if (dimAxis < 0 || dimAxis >= rank)
      return rewriter.notifyMatchFailure(dimOp, "Axis is out of range");

    // Dim is from a reshape op.
    ONNXReshapeOp reshapeOp = data.getDefiningOp<ONNXReshapeOp>();
    if (!reshapeOp)
      return rewriter.notifyMatchFailure(dimOp, "Not found reshape op");
    if (reshapeOp.getAllowzero() != 0)
      return rewriter.notifyMatchFailure(dimOp, "Reshape op's allowzero != 0");

    // Shape is from a concat op of dims.
    Value shape = reshapeOp.getShape();
    ONNXConcatOp concatOp = shape.getDefiningOp<ONNXConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(dimOp, "Not found concat op");
    ValueRange dims = concatOp.getInputs();

    // Ensure that the number of concat's inputs is equal to rank.
    if (dims.size() != static_cast<uint64_t>(rank))
      return rewriter.notifyMatchFailure(
          dimOp, "Concat input size is not good");

    // Each concat input must be a single-element tensor with a static shape.
    // Otherwise the i-th concat input does not necessarily correspond to the
    // i-th output dimension (e.g. inputs of sizes {0, 2, 1, 1} also sum up to
    // rank 4 while having 4 operands).
    for (Value d : dims) {
      ShapedType dType = mlir::dyn_cast<ShapedType>(d.getType());
      if (!dType || !dType.hasStaticShape() || dType.getNumElements() != 1)
        return rewriter.notifyMatchFailure(
            dimOp, "Concat input is not a single element");
    }

    // Values in shape can be -1 or 0 according to Reshape's definition.
    // Those values are not real dimensions.
    if (isConstOf(dims[dimAxis], -1) || (isConstOf(dims[dimAxis], 0)))
      return rewriter.notifyMatchFailure(dimOp, "Dim at axis is -1 or 0");

    // replaceOp requires the replacement to have the same type as DimOp's
    // result (tensor<1xi64>).
    if (dims[dimAxis].getType() != dimOp.getType())
      return rewriter.notifyMatchFailure(dimOp, "Type mismatch at axis");

    rewriter.replaceOp(dimOp, dims[dimAxis]);
    return success();
  }
};
1732+
1733+
/// Replace `onnx.Dim(onnx.Transpose(%X)) {axis = k}` by
/// `onnx.Dim(%X) {axis = perm[k]}`, i.e. query the dimension directly on the
/// transpose's input at the permuted axis.
struct DimOpFromTransposeInputPattern : public OpRewritePattern<ONNXDimOp> {
  using OpRewritePattern<ONNXDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXDimOp dimOp, PatternRewriter &rewriter) const final {
    Value data = dimOp.getData();

    // Do not handle unranked tensors.
    if (!isRankedShapedType(data.getType()))
      return rewriter.notifyMatchFailure(dimOp, "Input is unranked");

    // Normalize axis and make sure it is in range before using it to index
    // into the permutation.
    int64_t rank = getRank(data.getType());
    int64_t dimAxis = dimOp.getAxis();
    if (dimAxis < 0)
      dimAxis += rank;
    if (dimAxis < 0 || dimAxis >= rank)
      return rewriter.notifyMatchFailure(dimOp, "Axis is out of range");

    // Dim is from a transpose op.
    ONNXTransposeOp transposeOp = data.getDefiningOp<ONNXTransposeOp>();
    if (!transposeOp)
      return rewriter.notifyMatchFailure(dimOp, "Not found transpose op");
    Value transposeData = transposeOp.getData();

    // Map the axis through the transpose permutation. `perm` is an optional
    // attribute: per the ONNX spec, when it is absent Transpose reverses the
    // dimensions, so output axis k comes from input axis (rank - 1 - k).
    ArrayAttr permAttr = transposeOp.getPermAttr();
    int64_t transposeAxis =
        permAttr ? ArrayAttrIntVal(permAttr, dimAxis) : (rank - 1 - dimAxis);

    Value replacedDim =
        OnnxBuilder(rewriter, dimOp.getLoc()).dim(transposeData, transposeAxis);

    rewriter.replaceOp(dimOp, replacedDim);
    return success();
  }
};
1767+
16681768
// =============================================================================
16691769
// Rewrite pattern LayerNormalization
16701770
// =============================================================================
@@ -2268,6 +2368,8 @@ void ONNXDropoutOp::getCanonicalizationPatterns(
22682368
void ONNXDimOp::getCanonicalizationPatterns(
22692369
RewritePatternSet &results, MLIRContext *context) {
22702370
results.insert<DimOpToConstantPattern>(context);
2371+
results.insert<DimOpFromReshapeInputPattern>(context);
2372+
results.insert<DimOpFromTransposeInputPattern>(context);
22712373
}
22722374

22732375
/// on the ONNXEqualOp.

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ WideNum asWideNum(double n, Type elemType) {
597597
/// Checks whether a constant tensor's elements are all equal to a given scalar.
598598
bool isConstOf(Value constValue, double n) {
599599
ElementsAttr constElements = getElementAttributeFromONNXValue(constValue);
600+
if (!constElements)
601+
return false;
600602
Type elemType = constElements.getElementType();
601603
assert(!elemType.isInteger(1) && "booleans are not supported");
602604
WideNum w = asWideNum(n, elemType);

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,39 @@ func.func @test_dim_to_constant(%arg0: tensor<?x256xi64>) -> (tensor<1xi64>) {
13291329

13301330
// -----
13311331

1332+
// Check that onnx.Dim on the result of a Reshape (allowzero = 0) whose shape
// comes from a Concat of single-element dims is replaced by the corresponding
// Concat input (here: the Dim of %arg0 at axis 2 that fed the Concat).
func.func @dim_from_reshape(%arg0: tensor<1x32x?x?xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<1xi64> {
1333+
%0 = onnx.Constant dense<32> : tensor<1xi64>
1334+
%1 = onnx.Constant dense<1> : tensor<1xi64>
1335+
%2 = "onnx.Dim"(%arg0) {axis = 2 : si64} : (tensor<1x32x?x?xf32>) -> tensor<1xi64>
1336+
%3 = "onnx.Dim"(%arg0) {axis = 3 : si64} : (tensor<1x32x?x?xf32>) -> tensor<1xi64>
1337+
%4 = "onnx.Concat"(%1, %0, %2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
1338+
%5 = "onnx.Reshape"(%arg1, %4) {allowzero = 0 : si64} : (tensor<32x?x?xf32>, tensor<4xi64>) -> tensor<1x32x?x?xf32>
1339+
%6 = "onnx.Dim"(%5) {axis = 2 : si64} : (tensor<1x32x?x?xf32>) -> tensor<1xi64>
1340+
return %6 : tensor<1xi64>
1341+
1342+
// CHECK-LABEL: func.func @dim_from_reshape
1343+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x32x?x?xf32>, [[PARAM_1_:%.+]]: tensor<32x?x?xf32>) -> tensor<1xi64> {
1344+
// CHECK: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 2 : si64} : (tensor<1x32x?x?xf32>) -> tensor<1xi64>
1345+
// CHECK: return [[VAR_0_]] : tensor<1xi64>
1346+
// CHECK: }
1347+
}
1348+
1349+
// -----
1350+
1351+
// Check that onnx.Dim on the result of a Transpose is rewritten to onnx.Dim
// on the Transpose input at the permuted axis (output axis 1 maps to input
// axis 2 under perm = [0, 2, 1, 3]).
func.func @dim_from_transpose(%arg0: tensor<1x32x?x?xf32>) -> tensor<1xi64> {
1352+
%0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 1, 3]} : (tensor<1x32x?x?xf32>) -> tensor<1x?x32x?xf32>
1353+
%1 = "onnx.Dim"(%0) {axis = 1 : si64} : (tensor<1x?x32x?xf32>) -> tensor<1xi64>
1354+
return %1 : tensor<1xi64>
1355+
1356+
// CHECK-LABEL: func.func @dim_from_transpose
1357+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x32x?x?xf32>) -> tensor<1xi64> {
1358+
// CHECK: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 2 : si64} : (tensor<1x32x?x?xf32>) -> tensor<1xi64>
1359+
// CHECK: return [[VAR_0_]] : tensor<1xi64>
1360+
// CHECK: }
1361+
}
1362+
1363+
// -----
1364+
13321365
func.func @test_layout_transform(%arg0: tensor<5x3x32x32xf32, #onnx.layout<{dataLayout = "NCHW4C"}>>) -> tensor<5x3x32x32xf32, #onnx.layout<{dataLayout = "NCHW4C"}>> {
13331366
%0 = "onnx.LayoutTransform"(%arg0) {target_layout = #onnx.layout<{dataLayout = "NCHW4C"}>} : (tensor<5x3x32x32xf32,#onnx.layout<{dataLayout = "NCHW4C"}>>) -> tensor<5x3x32x32xf32, #onnx.layout<{dataLayout = "NCHW4C"}>>
13341367
onnx.Return %0 : tensor<5x3x32x32xf32, #onnx.layout<{dataLayout = "NCHW4C"}>>

0 commit comments

Comments
 (0)