#include <numeric>

- #include "mlir/IR/Matchers.h"
+ #include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
- #include "mlir/Transforms/DialectConversion.h"
+ #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
@@ -76,7 +76,6 @@ ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
splitShape[axis] = size;
resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
}
- rewriter.setInsertionPointAfter(input.getDefiningOp());
// Perform Split Operation
ValueRange results =
create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
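
With the setInsertionPointAfter call removed from emitSplitByChannels, the helper now emits the split at whatever insertion point the caller has already established (the pattern below sets it after the last of the parallel convs, under an insertion guard). A minimal sketch of that caller-side idiom, assuming a generic PatternRewriter and an anchor operation; emitAtAnchor and anchorOp are illustrative names, not part of the patch:

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/PatternMatch.h"

    // Illustrative helper: scope the insertion point in the caller so a
    // utility such as emitSplitByChannels never has to move it itself.
    static void emitAtAnchor(mlir::PatternRewriter &rewriter,
                             mlir::Operation *anchorOp) {
      // The guard restores the previous insertion point when it goes out
      // of scope, so nothing downstream observes the temporary change.
      mlir::OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(anchorOp);
      // ... build ops here; they are inserted right after anchorOp.
    }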
@@ -657,8 +656,13 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
ONNXConvOp convOp1, PatternRewriter &rewriter) const final {
Value input = convOp1.getX();
if (!onnx_mlir::isRankedShapedType(input.getType()) ||
- mlir::cast<ShapedType>(input.getType()).hasStaticShape() == false)
- return failure();
+ !mlir::cast<ShapedType>(input.getType()).hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp1, "input must be a ranked tensor with static shape");
+
+ if (!cast<ShapedType>(convOp1.getType()).hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp1, "output type must be a ranked tensor with static shape");

// Collect all ONNXConvOps using this input.
SmallVector<ONNXConvOp> candidateConvs;
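
Compared with a bare failure(), notifyMatchFailure keeps the returned LogicalResult but also records a human-readable reason that the pattern driver can surface in debug builds (for example with --debug-only=greedy-rewriter); the call itself still just reports a match failure. A generic sketch of the idiom, not taken from this file; checkStaticShape is an illustrative name:

    #include "mlir/IR/PatternMatch.h"

    // Illustrative check: reject non-static shapes with a message instead
    // of a silent failure(). The reason string is only consumed by
    // debug/diagnostic listeners of the rewrite driver.
    static mlir::LogicalResult checkStaticShape(mlir::PatternRewriter &rewriter,
                                                mlir::Operation *op,
                                                mlir::Value v) {
      auto shaped = mlir::dyn_cast<mlir::ShapedType>(v.getType());
      if (!shaped || !shaped.hasStaticShape())
        return rewriter.notifyMatchFailure(
            op, "operand must be a ranked tensor with static shape");
      return mlir::success();
    }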
@@ -669,38 +673,88 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {

// Must have at least two convs to combine.
if (candidateConvs.size() < 2)
- return failure();
+ return rewriter.notifyMatchFailure(
+ convOp1, "not enough conv ops to combine");

// Ensure all candidate convs are compatible (including bias check).
for (size_t i = 1; i < candidateConvs.size(); ++i) {
if (!areCompatible(candidateConvs[0], candidateConvs[i]))
- return failure();
+ return rewriter.notifyMatchFailure(
+ convOp1, "conv ops are not compatible for combining");
}

auto totalUses = static_cast<size_t>(
std::distance(input.getUsers().begin(), input.getUsers().end()));
if (candidateConvs.size() != totalUses)
- return failure();
+ return rewriter.notifyMatchFailure(
+ convOp1, "number of candidate convs does not match input uses");

SmallVector<ONNXConvOp> parallelConvs = candidateConvs;

+ SmallVector<Value> weightValues;
+ int64_t totalOutputChannels = 0;
+ for (auto conv : parallelConvs) {
+ auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
+ if (!weightType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ conv, "weight must be a ranked tensor with static shape");
+ if (!cast<ShapedType>(conv.getType()).hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ conv, "output type must be a ranked tensor with static shape");
+ weightValues.push_back(conv.getW());
+ totalOutputChannels += weightType.getShape()[0];
+ }
+
+ auto *latestConv =
+ llvm::max_element(parallelConvs, [](ONNXConvOp a, ONNXConvOp b) {
+ return a->isBeforeInBlock(b.getOperation());
+ });
+
+ const auto checkIfOtherConvsReachable = [&](ONNXConvOp conv) {
+ SmallVector<Operation *> worklist;
+ DenseSet<Operation *> visited;
+ worklist.push_back(conv.getOperation());
+ while (!worklist.empty()) {
+ Operation *current = worklist.back();
+ worklist.pop_back();
+
+ for (auto *user : current->getUsers()) {
+ if (auto otherConv = dyn_cast<ONNXConvOp>(user)) {
+ if (llvm::is_contained(parallelConvs, otherConv)) {
+ // Found another conv that is part of the parallel convs.
+ return true;
+ }
+ }
+ if (visited.insert(user).second &&
+ user->isBeforeInBlock(*latestConv)) {
+ worklist.push_back(user);
+ }
+ };
+ }
+ return false;
+ };
+ // Ensure all convolutions are really parallel; none of them can be part of
+ // the input of another convolution.
+ if (llvm::any_of(parallelConvs, checkIfOtherConvsReachable)) {
+ return rewriter.notifyMatchFailure(
+ convOp1, "conv ops are not parallel (reachable from each other)");
+ }
+
bool allHaveBias = !mlir::isa<NoneType>(parallelConvs[0].getB().getType());
+
Location loc = convOp1.getLoc();
+ for (auto conv : parallelConvs) {
+ loc = rewriter.getFusedLoc({loc, conv.getLoc()});
+ }
auto inputType = mlir::cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
rewriter, loc);
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfter(*latestConv);

int64_t concatAxis = 1;

- SmallVector<Value> weightValues;
- int64_t totalOutputChannels = 0;
- for (auto conv : parallelConvs) {
- auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
- weightValues.push_back(conv.getW());
- totalOutputChannels += weightType.getShape()[0];
- }
-
auto firstWeightType =
mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
SmallVector<int64_t> newWeightShape(
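
The new guard rejects candidate sets where one conv (directly or transitively) feeds another: starting from each candidate it walks forward through the users of its results, and it prunes the walk with isBeforeInBlock(*latestConv), since every candidate sits at or before the latest one. A stand-alone sketch of the same reachability test, assuming all operations live in a single block; reachableBefore, from, to, and limit are illustrative names, not helpers from the pass:

    #include "llvm/ADT/DenseSet.h"
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/Operation.h"

    // Illustrative reachability test: is `to` reachable from `from` by
    // following result users, without walking past `limit` in the block?
    static bool reachableBefore(mlir::Operation *from, mlir::Operation *to,
                                mlir::Operation *limit) {
      llvm::SmallVector<mlir::Operation *> worklist{from};
      llvm::DenseSet<mlir::Operation *> visited;
      while (!worklist.empty()) {
        mlir::Operation *current = worklist.pop_back_val();
        for (mlir::Operation *user : current->getUsers()) {
          if (user == to)
            return true;
          // Only keep expanding users that are still ahead of `limit`;
          // given `to` is at or before `limit`, later ops cannot reach it.
          if (visited.insert(user).second && user->isBeforeInBlock(limit))
            worklist.push_back(user);
        }
      }
      return false;
    }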
@@ -720,9 +774,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
Type newBiasType = RankedTensorType::get(newBiasShape, elementType);
newBias = create.onnx.concat(newBiasType, biasValues, 0);
} else {
- // Bias is absent for all. Assign a null Value (nullptr) instead of
- // ONNXNoneOp.
- newBias = nullptr;
+ newBias = parallelConvs[0].getB();
}

SmallVector<int64_t> newOutputShape(
@@ -767,8 +819,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {

if (allOutputsUsedInCommonConcat && commonConcatOp &&
commonConcatOp.getAxis() == 1) {
- commonConcatOp.getResult().replaceAllUsesWith(newConv.getResult());
- rewriter.eraseOp(commonConcatOp);
+ rewriter.replaceOp(commonConcatOp, newConv);
} else {
SmallVector<int64_t> splitSizesVec;
for (auto conv : parallelConvs) {
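
replaceOp folds the previous replaceAllUsesWith + eraseOp pair into a single rewriter call, so the driver's listener sees the use rewiring and the erasure together. A small generic sketch, assuming the two ops produce the same number of matching results; replaceWithOp is an illustrative name:

    #include <cassert>
    #include "mlir/IR/PatternMatch.h"

    // Illustrative wrapper: route the replacement through the rewriter
    // instead of mutating uses and erasing by hand, so any listener (for
    // example the greedy driver) is notified of both effects.
    static void replaceWithOp(mlir::PatternRewriter &rewriter,
                              mlir::Operation *oldOp, mlir::Operation *newOp) {
      assert(oldOp->getNumResults() == newOp->getNumResults() &&
             "replacement must produce the same number of results");
      rewriter.replaceOp(oldOp, newOp);
    }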
@@ -777,15 +828,15 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
splitSizesVec.push_back(channels);
}

- rewriter.setInsertionPointAfter(newConv);
ValueRange splitResults = onnx_mlir::emitSplitByChannels(
rewriter, loc, newConv.getResult(), splitSizesVec, concatAxis);
-
for (size_t i = 0; i < parallelConvs.size(); ++i) {
- parallelConvs[i].getResult().replaceAllUsesWith(splitResults[i]);
+ rewriter.replaceAllOpUsesWith(parallelConvs[i], splitResults[i]);
}
+ // Sort the block topologically, as the operations after the split may
+ // otherwise be in the wrong place.
+ mlir::sortTopologically(newConv->getBlock());
}
-
for (auto conv : parallelConvs) {
rewriter.eraseOp(conv);
}
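
Because the combined conv is created at the position of the latest original conv, the split results can end up textually after some of the operations that now consume them; sortTopologically (from the header added at the top of the file) restores def-before-use order in the block. A minimal sketch combining the rewiring and the re-sort, assuming a single-block region; rewireAndReorder is an illustrative name:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"
    #include "mlir/Analysis/TopologicalSortUtils.h"
    #include "mlir/IR/PatternMatch.h"

    // Illustrative clean-up step: point all uses of each old op at its
    // replacement value through the rewriter, then re-sort the block so
    // every value is defined before its uses again.
    static void rewireAndReorder(mlir::PatternRewriter &rewriter,
                                 llvm::ArrayRef<mlir::Operation *> oldOps,
                                 mlir::ValueRange replacements,
                                 mlir::Block *block) {
      for (auto [oldOp, repl] : llvm::zip_equal(oldOps, replacements))
        rewriter.replaceAllOpUsesWith(oldOp, repl);
      // Returns false only if a cycle prevents a full ordering; ignored here.
      (void)mlir::sortTopologically(block);
    }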
@@ -851,44 +902,10 @@ void RecomposeONNXToONNXPass::runOnOperation() {
func::FuncOp function = getOperation();
MLIRContext *context = &getContext();

- ConversionTarget target(getContext());
- target.addLegalDialect<ONNXDialect, arith::ArithDialect, func::FuncDialect>();
-
- // These ops will be Recomposed into other ONNX ops. Hence, they will not be
- // available after this pass.
-
- // Recompose LayerNorm, starting from scale/mul op
- target.addDynamicallyLegalOp<ONNXMulOp>([](ONNXMulOp op) {
- Value x, scale;
- FloatAttr epsilon;
- int64_t axis;
- bool isRMSLayerNorm;
- if (RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
- op, x, scale, axis, epsilon, isRMSLayerNorm))
- return false;
-
- bool isExactGelu;
- if (RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu))
- return false;
-
- return true;
- });
-
- // Recompose QLinearMatMul, starting from QuantizeLinear.
- // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
- target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
- [](ONNXQuantizeLinearOp op) {
- Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
- outZeroPoint;
- return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
- matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
- bZeroPoint, outScale, outZeroPoint);
- });
-
RewritePatternSet patterns(context);
onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);

- if (failed(applyPartialConversion(function, target, std::move(patterns))))
+ if (failed(applyPatternsGreedily(function, std::move(patterns))))
signalPassFailure();
}
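
Switching from applyPartialConversion to applyPatternsGreedily removes the need for a ConversionTarget and the dynamic-legality callbacks: each recompose pattern already decides in matchAndRewrite whether it applies, and the greedy driver simply iterates the patterns to a fixpoint. A minimal sketch of that driver setup, assuming a generic root operation and a pattern-population callback; runRecomposeLikePass and populate are illustrative names, not the onnx-mlir pass itself:

    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Illustrative pass body: collect rewrite patterns and let the greedy
    // driver apply them until no more matches (or the iteration limit) is
    // reached; the caller signals pass failure if this does not succeed.
    static mlir::LogicalResult
    runRecomposeLikePass(mlir::Operation *root,
                         void (*populate)(mlir::RewritePatternSet &)) {
      mlir::RewritePatternSet patterns(root->getContext());
      populate(patterns);
      return mlir::applyPatternsGreedily(root, std::move(patterns));
    }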