
Commit c9647c8

Use greedy rewriter instead of partial conversion in recomposition pass and fuse locations when fusing convolutions. Also fix bugs in the parallel conv recomposition (#3207)
* Use greedy pattern rewriter instead of partial conversion in recomposition. The greedy rewriter is generally easier to use, as operations do not need to be marked as illegal.

* Emit useful messages when failing to match in CombineParallelConv2DPattern.

* Fuse locations when fusing convs.

* Check that there are no def-use chains between convs before combining them.

* Check for static shapes in parallel conv recomposition.

* Use rewriter methods instead of 'raw' IR manipulation. Bypassing the rewriter can lead to subtle bugs as listeners do not get notified. Modify the insertion point and use a final topological sort to ensure the IR/graph order is valid, no matter how the inputs are ordered and from which conv the recomposition starts.

* Add missing end newline.

* Use none instead of nullptr for bias and make sure the insertion point is after all weights.

---------

Signed-off-by: Jonas Rickert <[email protected]>
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 471a7c3 commit c9647c8
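For orientation before the full diff: the driver change boils down to dropping the ConversionTarget (with its dynamic-legality callbacks for ONNXMulOp and ONNXQuantizeLinearOp) and handing the recompose patterns to the greedy driver instead. Below is a condensed before/after sketch of RecomposeONNXToONNXPass::runOnOperation, abbreviated from the diff further down; the /* ... */ placeholders stand in for the removed legality callbacks and are not verbatim code.

// Before: partial conversion needs a ConversionTarget that marks every
// recomposable op as dynamically illegal so the driver will rewrite it.
//   ConversionTarget target(getContext());
//   target.addLegalDialect<ONNXDialect, arith::ArithDialect, func::FuncDialect>();
//   target.addDynamicallyLegalOp<ONNXMulOp>(/* LayerNorm/Gelu match check */);
//   target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(/* QLinearMatMul match check */);
//   if (failed(applyPartialConversion(function, target, std::move(patterns))))
//     signalPassFailure();

// After: the greedy driver applies the patterns to a fixpoint with no
// legality bookkeeping.
void RecomposeONNXToONNXPass::runOnOperation() {
  func::FuncOp function = getOperation();
  MLIRContext *context = &getContext();

  RewritePatternSet patterns(context);
  onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);

  if (failed(applyPatternsGreedily(function, std::move(patterns))))
    signalPassFailure();
}

The location fusion, static-shape checks, and the parallelism (def-use reachability) check called out in the bullet list above appear in the CombineParallelConv2DPattern hunks of the diff.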

File tree

3 files changed: +216 −61 lines changed


src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 77 additions & 60 deletions
@@ -22,10 +22,10 @@
 
 #include <numeric>
 
-#include "mlir/IR/Matchers.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
 #include "src/Dialect/ONNX/DialectBuilder.hpp"
@@ -76,7 +76,6 @@ ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
     splitShape[axis] = size;
     resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
   }
-  rewriter.setInsertionPointAfter(input.getDefiningOp());
   // Perform Split Operation
   ValueRange results =
       create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
@@ -657,8 +656,13 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       ONNXConvOp convOp1, PatternRewriter &rewriter) const final {
     Value input = convOp1.getX();
     if (!onnx_mlir::isRankedShapedType(input.getType()) ||
-        mlir::cast<ShapedType>(input.getType()).hasStaticShape() == false)
-      return failure();
+        !mlir::cast<ShapedType>(input.getType()).hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          convOp1, "input must be a ranked tensor with static shape");
+
+    if (!cast<ShapedType>(convOp1.getType()).hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          convOp1, "output type must be a ranked tensor with static shape");
 
     // Collect all ONNXConvOps using this input.
     SmallVector<ONNXConvOp> candidateConvs;
@@ -669,38 +673,88 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     // Must have at least two convs to combine.
     if (candidateConvs.size() < 2)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          convOp1, "not enough conv ops to combine");
 
     // Ensure all candidate convs are compatible (including bias check).
     for (size_t i = 1; i < candidateConvs.size(); ++i) {
       if (!areCompatible(candidateConvs[0], candidateConvs[i]))
-        return failure();
+        return rewriter.notifyMatchFailure(
+            convOp1, "conv ops are not compatible for combining");
     }
 
     auto totalUses = static_cast<size_t>(
         std::distance(input.getUsers().begin(), input.getUsers().end()));
     if (candidateConvs.size() != totalUses)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          convOp1, "number of candidate convs does not match input uses");
 
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
 
+    SmallVector<Value> weightValues;
+    int64_t totalOutputChannels = 0;
+    for (auto conv : parallelConvs) {
+      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
+      if (!weightType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "weight must be a ranked tensor with static shape");
+      if (!cast<ShapedType>(conv.getType()).hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "output type must be a ranked tensor with static shape");
+      weightValues.push_back(conv.getW());
+      totalOutputChannels += weightType.getShape()[0];
+    }
+
+    auto *latestConv =
+        llvm::max_element(parallelConvs, [](ONNXConvOp a, ONNXConvOp b) {
+          return a->isBeforeInBlock(b.getOperation());
+        });
+
+    const auto checkIfOtherConvsReachable = [&](ONNXConvOp conv) {
+      SmallVector<Operation *> worklist;
+      DenseSet<Operation *> visited;
+      worklist.push_back(conv.getOperation());
+      while (!worklist.empty()) {
+        Operation *current = worklist.back();
+        worklist.pop_back();
+
+        for (auto *user : current->getUsers()) {
+          if (auto otherConv = dyn_cast<ONNXConvOp>(user)) {
+            if (llvm::is_contained(parallelConvs, otherConv)) {
+              // Found another conv that is part of the parallel convs.
+              return true;
+            }
+          }
+          if (visited.insert(user).second &&
+              user->isBeforeInBlock(*latestConv)) {
+            worklist.push_back(user);
+          }
+        };
+      }
+      return false;
+    };
+    // Ensure all convolutions are really parallel, none of then can be part of
+    // the input of another convolution
+    if (llvm::any_of(parallelConvs, checkIfOtherConvsReachable)) {
+      return rewriter.notifyMatchFailure(
+          convOp1, "conv ops are not parallel (reachable from each other)");
+    }
+
     bool allHaveBias = !mlir::isa<NoneType>(parallelConvs[0].getB().getType());
+
     Location loc = convOp1.getLoc();
+    for (auto conv : parallelConvs) {
+      loc = rewriter.getFusedLoc({loc, conv.getLoc()});
+    }
     auto inputType = mlir::cast<ShapedType>(input.getType());
     Type elementType = inputType.getElementType();
     onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
         rewriter, loc);
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(*latestConv);
 
     int64_t concatAxis = 1;
 
-    SmallVector<Value> weightValues;
-    int64_t totalOutputChannels = 0;
-    for (auto conv : parallelConvs) {
-      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
-      weightValues.push_back(conv.getW());
-      totalOutputChannels += weightType.getShape()[0];
-    }
-
     auto firstWeightType =
         mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
     SmallVector<int64_t> newWeightShape(
@@ -720,9 +774,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       Type newBiasType = RankedTensorType::get(newBiasShape, elementType);
       newBias = create.onnx.concat(newBiasType, biasValues, 0);
     } else {
-      // Bias is absent for all. Assign a null Value (nullptr) instead of
-      // ONNXNoneOp.
-      newBias = nullptr;
+      newBias = parallelConvs[0].getB();
     }
 
     SmallVector<int64_t> newOutputShape(
@@ -767,8 +819,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     if (allOutputsUsedInCommonConcat && commonConcatOp &&
         commonConcatOp.getAxis() == 1) {
-      commonConcatOp.getResult().replaceAllUsesWith(newConv.getResult());
-      rewriter.eraseOp(commonConcatOp);
+      rewriter.replaceOp(commonConcatOp, newConv);
     } else {
       SmallVector<int64_t> splitSizesVec;
       for (auto conv : parallelConvs) {
@@ -777,15 +828,15 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
         splitSizesVec.push_back(channels);
       }
 
-      rewriter.setInsertionPointAfter(newConv);
       ValueRange splitResults = onnx_mlir::emitSplitByChannels(
           rewriter, loc, newConv.getResult(), splitSizesVec, concatAxis);
-
       for (size_t i = 0; i < parallelConvs.size(); ++i) {
-        parallelConvs[i].getResult().replaceAllUsesWith(splitResults[i]);
+        rewriter.replaceAllOpUsesWith(parallelConvs[i], splitResults[i]);
       }
+      // Sort the block topological, as the operations after the split may be in
+      // the wrong place otherwise
+      mlir::sortTopologically(newConv->getBlock());
     }
-
     for (auto conv : parallelConvs) {
       rewriter.eraseOp(conv);
     }
@@ -851,44 +902,10 @@ void RecomposeONNXToONNXPass::runOnOperation() {
   func::FuncOp function = getOperation();
   MLIRContext *context = &getContext();
 
-  ConversionTarget target(getContext());
-  target.addLegalDialect<ONNXDialect, arith::ArithDialect, func::FuncDialect>();
-
-  // These ops will be Recomposed into other ONNX ops. Hence, they will not be
-  // available after this pass.
-
-  // Recompose LayerNorm, starting from scale/mul op
-  target.addDynamicallyLegalOp<ONNXMulOp>([](ONNXMulOp op) {
-    Value x, scale;
-    FloatAttr epsilon;
-    int64_t axis;
-    bool isRMSLayerNorm;
-    if (RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
-            op, x, scale, axis, epsilon, isRMSLayerNorm))
-      return false;
-
-    bool isExactGelu;
-    if (RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu))
-      return false;
-
-    return true;
-  });
-
-  // Recompose QLinearMatMul, starting from QuantizeLinear.
-  // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
-  target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
-      [](ONNXQuantizeLinearOp op) {
-        Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
-            outZeroPoint;
-        return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
-            matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
-                bZeroPoint, outScale, outZeroPoint);
-      });
-
   RewritePatternSet patterns(context);
   onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);
 
-  if (failed(applyPartialConversion(function, target, std::move(patterns))))
+  if (failed(applyPatternsGreedily(function, std::move(patterns))))
     signalPassFailure();
 }
 