@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
6363 return success ();
6464}
6565
66+ LogicalResult constructListFromLiteral (PatternRewriter &rewriter,
67+ ValueTensorLiteralOp literalOp,
68+ SmallVector<Value> &vals) {
69+ // only supports splat ValueTensorLiterals for now. TODO: add support for
70+ // small non-splat valuetensorliterals.
71+ auto ty = dyn_cast<ValueTensorType>(literalOp.getType ());
72+ if (!ty || !ty.hasSizes ())
73+ return failure ();
74+ auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue ());
75+ if (!attr)
76+ return failure ();
77+ auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue <Attribute>());
78+ if (!attrInt)
79+ return failure ();
80+ IntegerType intty = cast<IntegerType>(attrInt.getType ());
81+ if (!intty.isSignedInteger ())
82+ return failure ();
83+ Value materializedVal = rewriter.create <Torch::ConstantIntOp>(
84+ literalOp.getLoc (), attrInt.getSInt ());
85+ vals.resize (vals.size () + ty.getSizes ()[0 ], materializedVal);
86+ return success ();
87+ }
88+
6689LogicalResult getListFromTensor (Value value, SmallVector<Value> &vals) {
6790 constexpr int64_t kMaxFold = 16 ;
6891 if (auto tensor = value.getDefiningOp <Torch::AtenTensorOp>())
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
351374};
352375} // namespace
353376
377+ namespace {
378+ class PropagateAtenWhereSelfPattern : public OpRewritePattern <AtenWhereSelfOp> {
379+ public:
380+ using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
381+ LogicalResult matchAndRewrite (AtenWhereSelfOp op,
382+ PatternRewriter &rewriter) const override {
383+ Value condition = op.getCondition ();
384+ Value self = op.getSelf ();
385+ Value other = op.getOther ();
386+ auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType ());
387+ if (!conditionTy || !conditionTy.hasSizes () ||
388+ conditionTy.getSizes ().size () != 1 )
389+ return rewriter.notifyMatchFailure (op, " bad condition type" );
390+ auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType ());
391+ if (!selfTy || !selfTy.hasSizes () || selfTy.getSizes ().size () != 1 )
392+ return rewriter.notifyMatchFailure (op, " bad self type" );
393+ auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType ());
394+ if (!otherTy || !otherTy.hasSizes () || otherTy.getSizes ().size () != 1 )
395+ return rewriter.notifyMatchFailure (op, " bad other type" );
396+ int64_t conditionSize = selfTy.getSizes ()[0 ];
397+ int64_t selfSize = selfTy.getSizes ()[0 ];
398+ int64_t otherSize = otherTy.getSizes ()[0 ];
399+
400+ if (selfSize != otherSize || selfSize != conditionSize)
401+ return rewriter.notifyMatchFailure (
402+ op,
403+ " unimplemented: support for propogating with implicit broadcasting." );
404+
405+ constexpr int64_t kMaxFold = 16 ;
406+ if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold )
407+ return rewriter.notifyMatchFailure (op,
408+ " arguments are dynamic or too big" );
409+
410+ SmallVector<Value> conditionList, selfList, otherList;
411+ if (failed (getListFromTensor (condition, conditionList)) ||
412+ (int64_t )conditionList.size () != conditionSize)
413+ return failure ();
414+
415+ // If one of these tensors is a value tensor literal op, we will need to
416+ // create constant ints in the IR to form a list. Before calling
417+ // constructListFromLiteral, we must be certain that the conversion can no
418+ // longer fail, otherwise we will cause an infinite loop of creating a
419+ // constant and removing it.
420+ LogicalResult selfFromList = getListFromTensor (self, selfList);
421+ LogicalResult otherFromList = getListFromTensor (other, otherList);
422+
423+ if (failed (selfFromList) && failed (otherFromList))
424+ return rewriter.notifyMatchFailure (
425+ op, " At least one operand must succeed at constructing a list" );
426+
427+ auto selfLiteral = self.getDefiningOp <Torch::ValueTensorLiteralOp>();
428+ auto otherLiteral = other.getDefiningOp <Torch::ValueTensorLiteralOp>();
429+ if (succeeded (selfFromList) && otherLiteral &&
430+ failed (constructListFromLiteral (rewriter, otherLiteral, otherList)))
431+ return failure ();
432+ if (succeeded (otherFromList) && selfLiteral &&
433+ failed (constructListFromLiteral (rewriter, selfLiteral, selfList)))
434+ return failure ();
435+ if ((int64_t )selfList.size () != selfSize ||
436+ (int64_t )otherList.size () != otherSize)
437+ // this should only occur if we did not generate IR with
438+ // constructListFromLiteral
439+ return failure ();
440+
441+ Location loc = op.getLoc ();
442+ SmallVector<Value> whereVals;
443+ auto rank0IntTy = rewriter.getType <Torch::ValueTensorType>(
444+ ArrayRef<int64_t >({}), selfTy.getDtype ());
445+ auto rank0BoolTy = rewriter.getType <Torch::ValueTensorType>(
446+ ArrayRef<int64_t >({}), conditionTy.getDtype ());
447+ for (uint64_t i = 0 ; i < selfList.size (); i++) {
448+ Value rank0Cond = rewriter.create <Torch::PrimNumToTensorScalarOp>(
449+ loc, rank0BoolTy, conditionList[i]);
450+ Value rank0Self = rewriter.create <Torch::PrimNumToTensorScalarOp>(
451+ loc, rank0IntTy, selfList[i]);
452+ Value rank0Other = rewriter.create <Torch::PrimNumToTensorScalarOp>(
453+ loc, rank0IntTy, otherList[i]);
454+ Value rank0Where = rewriter.create <AtenWhereSelfOp>(
455+ loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
456+ whereVals.push_back (rewriter.create <AtenItemOp>(
457+ loc, rewriter.getType <Torch::IntType>(), rank0Where));
458+ }
459+ Value list = rewriter.create <Torch::PrimListConstructOp>(
460+ op.getLoc (), Torch::ListType::get (whereVals[0 ].getType ()), whereVals);
461+ Value cstNone = rewriter.create <Torch::ConstantNoneOp>(op.getLoc ());
462+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(
463+ op.getLoc (), rewriter.getBoolAttr (false ));
464+ rewriter.replaceOpWithNewOp <Torch::AtenTensorOp>(
465+ op, op.getType (), list, cstNone, cstNone, cstFalse);
466+ return success ();
467+ }
468+ };
469+ } // namespace
470+
471+ namespace {
472+ class PropagateAtenEqTensorPattern : public OpRewritePattern <AtenEqTensorOp> {
473+ public:
474+ using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
475+ LogicalResult matchAndRewrite (AtenEqTensorOp op,
476+ PatternRewriter &rewriter) const override {
477+ Value self = op.getSelf ();
478+ Value other = op.getOther ();
479+ auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType ());
480+ if (!selfTy || !selfTy.hasSizes () || selfTy.getSizes ().size () != 1 )
481+ return rewriter.notifyMatchFailure (op, " bad self type" );
482+ auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType ());
483+ if (!otherTy || !otherTy.hasSizes () || otherTy.getSizes ().size () != 1 )
484+ return rewriter.notifyMatchFailure (op, " bad other type" );
485+ int64_t selfSize = selfTy.getSizes ()[0 ];
486+ int64_t otherSize = otherTy.getSizes ()[0 ];
487+
488+ if (selfSize != otherSize)
489+ return rewriter.notifyMatchFailure (
490+ op,
491+ " unimplemented: support for propogating with implicit broadcasting." );
492+
493+ constexpr int64_t kMaxFold = 16 ;
494+ if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
495+ otherSize == Torch::kUnknownSize || otherSize > kMaxFold )
496+ return rewriter.notifyMatchFailure (op,
497+ " self or other is dynamic or too big" );
498+
499+ SmallVector<Value> selfList, otherList;
500+ // If one of these tensors is a value tensor literal op, we will need to
501+ // create constant ints in the IR to form a list. Before calling
502+ // constructListFromLiteral, we must be certain that the conversion can no
503+ // longer fail, otherwise we will cause an infinite loop of creating a
504+ // constant and removing it.
505+ LogicalResult selfFromList = getListFromTensor (self, selfList);
506+ LogicalResult otherFromList = getListFromTensor (other, otherList);
507+
508+ if (failed (selfFromList) && failed (otherFromList))
509+ return rewriter.notifyMatchFailure (
510+ op, " At least one operand must succeed at constructing a list" );
511+
512+ auto selfLiteral = self.getDefiningOp <Torch::ValueTensorLiteralOp>();
513+ auto otherLiteral = other.getDefiningOp <Torch::ValueTensorLiteralOp>();
514+ if (succeeded (selfFromList) && otherLiteral &&
515+ failed (constructListFromLiteral (rewriter, otherLiteral, otherList)))
516+ return failure ();
517+ if (succeeded (otherFromList) && selfLiteral &&
518+ failed (constructListFromLiteral (rewriter, selfLiteral, selfList)))
519+ return failure ();
520+ if ((int64_t )selfList.size () != selfSize ||
521+ (int64_t )otherList.size () != otherSize)
522+ // this should only occur if we did not generate IR with
523+ // constructListFromLiteral
524+ return failure ();
525+
526+ SmallVector<Value> eqVals;
527+ for (uint64_t i = 0 ; i < selfList.size (); i++) {
528+ eqVals.push_back (
529+ rewriter.create <AtenEqIntOp>(op.getLoc (), selfList[i], otherList[i]));
530+ }
531+ Value list = rewriter.create <Torch::PrimListConstructOp>(
532+ op.getLoc (), Torch::ListType::get (eqVals[0 ].getType ()), eqVals);
533+ Value cstNone = rewriter.create <Torch::ConstantNoneOp>(op.getLoc ());
534+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(
535+ op.getLoc (), rewriter.getBoolAttr (false ));
536+ rewriter.replaceOpWithNewOp <Torch::AtenTensorOp>(
537+ op, op.getType (), list, cstNone, cstNone, cstFalse);
538+ return success ();
539+ }
540+ };
541+ } // namespace
542+
354543namespace {
355544class PropagateAtenItemPattern : public OpRewritePattern <AtenItemOp> {
356545public:
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
454643};
455644} // namespace
456645
646+ namespace {
647+ class FoldAtenSqueezeDimPattern : public OpRewritePattern <AtenSqueezeDimOp> {
648+ public:
649+ using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
650+ LogicalResult matchAndRewrite (AtenSqueezeDimOp op,
651+ PatternRewriter &rewriter) const override {
652+ auto resultTy = cast<ValueTensorType>(op.getType ());
653+ if (!resultTy.hasSizes () || resultTy.getSizes ().size () != 0 )
654+ return rewriter.notifyMatchFailure (op, " Unknown result shape" );
655+
656+ if (auto atenFull = op.getSelf ().getDefiningOp <AtenFullOp>()) {
657+ rewriter.replaceOpWithNewOp <PrimNumToTensorScalarOp>(
658+ op, resultTy, atenFull.getFillValue ());
659+ return success ();
660+ }
661+ return failure ();
662+ }
663+ };
664+ } // namespace
665+
457666namespace {
458667class FoldAtenWhereSelf : public OpRewritePattern <AtenWhereSelfOp> {
459668public:
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
694903 PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
695904 FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
696905 FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
906+ PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
907+ FoldAtenSqueezeDimPattern,
697908 RemoveUnusedPattern<Torch::AtenIntBoolOp>,
698909 RemoveUnusedPattern<Torch::AtenEqIntOp>,
699910 RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
0 commit comments