99
1010#include " PassDetail.h"
1111
12+ #include " ReifyAbstractInterpCalculationsUtils.h"
1213#include " mlir/Transforms/DialectConversion.h"
1314#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
1415#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15- #include " ReifyAbstractInterpCalculationsUtils.h"
1616#include " llvm/ADT/StringExtras.h"
1717
1818using namespace mlir ;
@@ -72,8 +72,8 @@ namespace {
7272// immutable tensors.
7373class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
7474public:
75- ConvertHasValueSemanticsOpsToValueTensors (MLIRContext *context,
76- const std::optional<SymbolTable>& extraLibrary)
75+ ConvertHasValueSemanticsOpsToValueTensors (
76+ MLIRContext *context, const std::optional<SymbolTable> & extraLibrary)
7777 : RewritePattern(MatchAnyOpTypeTag(), /* benefit=*/ 1 , context) {
7878 this ->extraLibrary = extraLibrary;
7979 }
@@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
8787 return rewriter.notifyMatchFailure (op, " does not have value semantics" );
8888 }
8989
90- rewriter.startRootUpdate (op);
90+ rewriter.startOpModification (op);
9191 // Convert all operands.
9292 SmallVector<Value> newOperands;
9393 for (OpOperand &opOperand : op->getOpOperands ()) {
@@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
105105 auto listConstruct =
106106 opOperand.get ().getDefiningOp <PrimListConstructOp>();
107107 if (!listConstruct) {
108- rewriter.cancelRootUpdate (op);
108+ rewriter.cancelOpModification (op);
109109 return rewriter.notifyMatchFailure (
110110 op, " unimplemented: list of non vtensor type not constructed "
111111 " from list construct" );
@@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
120120 if (!llvm::all_of (listConstruct.getElements (), [](Value val) {
121121 return val.getType ().isa <NonValueTensorType, Torch::NoneType>();
122122 })) {
123- rewriter.cancelRootUpdate (op);
123+ rewriter.cancelOpModification (op);
124124 return rewriter.notifyMatchFailure (
125125 op, " unimplemented: list containing optional type is not "
126126 " handled." );
@@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
138138
139139 Type newListType = getContainerOrTensorTypeWithValueSemantics (listType);
140140 if (!newListType) {
141- rewriter.cancelRootUpdate (op);
141+ rewriter.cancelOpModification (op);
142142 return rewriter.notifyMatchFailure (
143143 op, " Unable to convert list type to value semantics." );
144144 }
@@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
154154 // from the non value tensor of the original optional value.
155155 auto derefine = opOperand.get ().getDefiningOp <DerefineOp>();
156156 if (!derefine) {
157- rewriter.cancelRootUpdate (op);
157+ rewriter.cancelOpModification (op);
158158 return rewriter.notifyMatchFailure (
159159 op, " unimplemented: optional of non vtensor type not from "
160160 " derefine" );
@@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
180180 rewriter.create <CopyToNonValueTensorOp>(op->getLoc (), result);
181181 result.replaceAllUsesExcept (nonValueTensor, nonValueTensor);
182182 }
183- rewriter.finalizeRootUpdate (op);
183+ rewriter.finalizeOpModification (op);
184184 return success ();
185185 }
186+
186187private:
187188 std::optional<SymbolTable> extraLibrary;
188189};
@@ -290,17 +291,18 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
290291 Operation *newOp = rewriter.create (state);
291292 // Note: need to convert result to first input's dtype because mix precision
292293 // compute would result in different behaviors.
293- // For example:
294- // a = torch.randn(3, 3).half() # float16
295- // b = torch.randn(3, 3) # float32
294+ // For example:
295+ // a = torch.randn(3, 3).half() # float16
296+ // b = torch.randn(3, 3) # float32
296297 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16
297298 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
298299 Value none = rewriter.create <ConstantNoneOp>(op->getLoc ());
299300 Value cstFalse = rewriter.create <ConstantBoolOp>(op->getLoc (), false );
300301 auto aDtype = rewriter.create <PrimDtypeOp>(op->getLoc (), op->getOperand (0 ));
301302 auto toDtype = rewriter.create <AtenToDtypeOp>(
302303 op->getLoc (), newOp->getResult (0 ).getType (), newOp->getResult (0 ),
303- aDtype, /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse, /* memory_format=*/ none);
304+ aDtype, /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse,
305+ /* memory_format=*/ none);
304306 auto tensor = rewriter.create <CopyToValueTensorOp>(op->getLoc (), toDtype);
305307 createOverwriteTensorContents (rewriter, op->getLoc (), tensor,
306308 op->getOperand (0 ));
0 commit comments