@@ -1665,6 +1665,106 @@ struct RemoveDimZeroInputInConcatPattern
16651665 }
16661666};
16671667
1668+ // =============================================================================
1669+ // Rewrite pattern onnx.dim
1670+ // =============================================================================
1671+
1672+ // / The pattern is to replace DimOp by a dimension value in the shape of
1673+ // / ReshapeOp.
1674+ // /
1675+ // / We would like to replace:
1676+ // / ```
1677+ // / %shape = onnx.Concat(%d0, %d1, %d2)
1678+ // / %reshape = onnx.Reshape(%X, %shape1) {allowzero = 0}
1679+ // / %dim = onnx.Dim(%reshape) {axis = 1}
1680+ // / ```
1681+ // / with
1682+ // / ```
1683+ // / %dim = %d1
1684+ // / ```
1685+ // / We only consider `allowzero=0` in this pattern.
1686+
1687+ struct DimOpFromReshapeInputPattern : public OpRewritePattern <ONNXDimOp> {
1688+ using OpRewritePattern<ONNXDimOp>::OpRewritePattern;
1689+
1690+ LogicalResult matchAndRewrite (
1691+ ONNXDimOp dimOp, PatternRewriter &rewriter) const final {
1692+ Value data = dimOp.getData ();
1693+
1694+ // Donot handle unranked tensors.
1695+ if (!isRankedShapedType (data.getType ()))
1696+ return rewriter.notifyMatchFailure (dimOp, " Input is unranked" );
1697+
1698+ // Normalize axis.
1699+ int64_t rank = getRank (data.getType ());
1700+ int64_t dimAxis = dimOp.getAxis ();
1701+ if (dimAxis < 0 )
1702+ dimAxis += rank;
1703+
1704+ // Dim is from a reshape op.
1705+ ONNXReshapeOp reshapeOp = data.getDefiningOp <ONNXReshapeOp>();
1706+ if (!reshapeOp)
1707+ return rewriter.notifyMatchFailure (dimOp, " Not found reshape op" );
1708+ if (reshapeOp.getAllowzero () != 0 )
1709+ return rewriter.notifyMatchFailure (dimOp, " Reshape op's allowzero != 0" );
1710+
1711+ // Shape is from a concat op of dims.
1712+ Value shape = reshapeOp.getShape ();
1713+ ONNXConcatOp concatOp = shape.getDefiningOp <ONNXConcatOp>();
1714+ if (!concatOp)
1715+ return rewriter.notifyMatchFailure (dimOp, " Not found concat op" );
1716+ ValueRange dims = concatOp.getInputs ();
1717+
1718+ // Ensure that the number of concat's inputs is equal to rank.
1719+ if (dims.size () != static_cast <uint64_t >(rank))
1720+ return rewriter.notifyMatchFailure (
1721+ dimOp, " Concat input size is not good" );
1722+
1723+ // Values in shape can be -1 or 0 according to Reshape's definition.
1724+ // Those values are not real dimensions.
1725+ if (isConstOf (dims[dimAxis], -1 ) || (isConstOf (dims[dimAxis], 0 )))
1726+ return rewriter.notifyMatchFailure (dimOp, " Dim at axis is -1 or 0" );
1727+
1728+ rewriter.replaceOp (dimOp, dims[dimAxis]);
1729+ return success ();
1730+ }
1731+ };
1732+
1733+ struct DimOpFromTransposeInputPattern : public OpRewritePattern <ONNXDimOp> {
1734+ using OpRewritePattern<ONNXDimOp>::OpRewritePattern;
1735+
1736+ LogicalResult matchAndRewrite (
1737+ ONNXDimOp dimOp, PatternRewriter &rewriter) const final {
1738+ Value data = dimOp.getData ();
1739+
1740+ // Donot handle unranked tensors.
1741+ if (!isRankedShapedType (data.getType ()))
1742+ return rewriter.notifyMatchFailure (dimOp, " Input is unranked" );
1743+
1744+ // Normalize axis.
1745+ int64_t rank = getRank (data.getType ());
1746+ int64_t dimAxis = dimOp.getAxis ();
1747+ if (dimAxis < 0 )
1748+ dimAxis += rank;
1749+
1750+ // Dim is from a transpose op.
1751+ ONNXTransposeOp transposeOp = data.getDefiningOp <ONNXTransposeOp>();
1752+ if (!transposeOp)
1753+ return rewriter.notifyMatchFailure (dimOp, " Not found transpose op" );
1754+ Value transposeData = transposeOp.getData ();
1755+
1756+ // Transpose axes.
1757+ ArrayAttr permAttr = transposeOp.getPermAttr ();
1758+ int64_t transposeAxis = ArrayAttrIntVal (permAttr, dimAxis);
1759+
1760+ Value replacedDim =
1761+ OnnxBuilder (rewriter, dimOp.getLoc ()).dim (transposeData, transposeAxis);
1762+
1763+ rewriter.replaceOp (dimOp, replacedDim);
1764+ return success ();
1765+ }
1766+ };
1767+
16681768// =============================================================================
16691769// Rewrite pattern LayerNormalization
16701770// =============================================================================
@@ -2268,6 +2368,8 @@ void ONNXDropoutOp::getCanonicalizationPatterns(
22682368void ONNXDimOp::getCanonicalizationPatterns (
22692369 RewritePatternSet &results, MLIRContext *context) {
22702370 results.insert <DimOpToConstantPattern>(context);
2371+ results.insert <DimOpFromReshapeInputPattern>(context);
2372+ results.insert <DimOpFromTransposeInputPattern>(context);
22712373}
22722374
22732375// / on the ONNXEqualOp.
0 commit comments