4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -12640,6 +12642,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -15333,6 +15336,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [
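Side note (illustration only, not part of this diff): roughly speaking, let hasFolder = 1; asks ODS to declare a fold(FoldAdaptor) hook on the generated op class; the bodies added in lib/Dialect/Torch/IR/TorchOps.cpp below supply it, handing elementwise floating-point and llvm::APInt lambdas to the shared naryFolderHelper. A standalone sketch of the APInt arithmetic those integer lambdas rely on (plain LLVM, no torch-mlir dependency; values match the new tests):

#include <llvm/ADT/APInt.h>
#include <llvm/Support/raw_ostream.h>

// Sketch only -- not torch-mlir code. Mirrors the APInt math used by the
// integer fold lambdas below, on 64-bit operands of matching width.
int main() {
  llvm::APInt self(/*numBits=*/64, 3), scalar(64, 2), alpha(64, 2);
  // aten.rsub.Scalar fold: other - self * alpha  =>  2 - 3 * 2 = -4.
  llvm::APInt rsub = scalar - self * alpha;
  // aten.remainder.Scalar fold: signed remainder  =>  3 srem 2 = 1.
  llvm::APInt rem = self.srem(scalar);
  llvm::outs() << rsub.getSExtValue() << " " << rem.getSExtValue() << "\n";
  return 0;
}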
134 changes: 134 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//
@@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}
std::function<double(ArrayRef<double>)> fpFold;
std::function<APInt(ArrayRef<APInt>)> intFold;

auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}

fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
if (!roundMode) {
return (double)inputs[0] / inputs[1];
} else if (roundMode.getValue().str() == "floor") {
return std::floor((double)inputs[0] / inputs[1]);
} else {
return std::trunc((double)inputs[0] / inputs[1]);
}
};

intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
int64_t res;
if (unsign) {
// For unsigned operands truncation and floor agree.
res = inputs[0].getZExtValue() / inputs[1].getZExtValue();
} else {
int64_t lhs = inputs[0].getSExtValue();
int64_t rhs = inputs[1].getSExtValue();
res = lhs / rhs;
// C++ integer division truncates toward zero; step down once for "floor"
// when the operands have opposite signs and the division is inexact.
if (roundMode.getValue().str() == "floor" && (lhs % rhs) != 0 &&
(lhs < 0) != (rhs < 0))
res -= 1;
}
return APInt(bits, res);
};

if (!roundMode) {
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, std::nullopt);
}

return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//
@@ -3584,6 +3654,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
}

//===----------------------------------------------------------------------===//
// AtenRemainderScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}

auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
return std::fmod(inputs[0], inputs[1]);
};

auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
return ret;
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//
@@ -4216,6 +4314,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
auto value = adaptor.getA();
auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
if (!dense || !dense.isSplat()) {
return nullptr;
}

auto splat = dense.getSplatValue<Attribute>();
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
// Use the splat's element type here: the op's own result type is
// !torch.int, not a builtin integer type.
auto type = intAttr.getType();
if (!isa<mlir::IntegerType>(type)) {
return nullptr;
}

if (type.isSignedInteger()) {
return getI64IntegerAttr(getContext(), intAttr.getSInt());
} else if (type.isUnsignedInteger() || type.isInteger(1)) {
// Booleans (i1) and unsigned values zero-extend, so dense<true> folds to
// 1 rather than -1.
return getI64IntegerAttr(getContext(), intAttr.getValue().getZExtValue());
} else {
return getI64IntegerAttr(getContext(), intAttr.getInt());
}
}

if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
return getI64IntegerAttr(
getContext(),
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//
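Side note (illustration only, not part of this diff): the rounding-mode handling above follows PyTorch's div semantics -- "trunc" rounds toward zero, "floor" toward negative infinity, and an absent mode means true division -- while remainder.Scalar uses fmod for floats and Int.Tensor truncates a float splat toward zero. A minimal plain-C++ sketch of the same scalar math, using the values from the new tests:

#include <cmath>
#include <cstdint>
#include <iostream>

// Reference formulas only -- not torch-mlir code.
int main() {
  // aten.div.Tensor_mode with "trunc", "floor", and no rounding mode.
  std::cout << std::trunc(8.0 / 2.0) << "\n";          // 4
  std::cout << std::floor(8.0 / 2.1) << "\n";          // 3
  std::cout << (8.0 / 3.0) << "\n";                    // 2.66667 (true division)
  // aten.remainder.Scalar: fmod for floats, srem/urem for integers.
  std::cout << std::fmod(3.0, 2.0) << "\n";            // 1
  std::cout << (3 % 2) << "\n";                        // 1
  // aten.Int.Tensor on a float splat truncates toward zero: dense<3.1> -> 3.
  std::cout << static_cast<std::int64_t>(3.1) << "\n"; // 3
  return 0;
}

The ODS registry entries that follow flip has_folder=True for the same four ops, which is what emits the let hasFolder = 1; lines shown in GeneratedTorchOps.td above.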
@@ -379,6 +379,7 @@ def emit_with_mutating_variants(key, **kwargs):
# variants.
emit_with_mutating_variants(
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
has_folder=True,
has_canonicalizer=True,
)
emit_with_mutating_variants(
@@ -481,6 +482,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::gelu : (Tensor, str) -> (Tensor)")
@@ -928,7 +930,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
)
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
emit(
"aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
@@ -1080,7 +1084,7 @@ def emit_with_mutating_variants(key, **kwargs):
has_canonicalizer=True,
)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
110 changes: 110 additions & 0 deletions test/Dialect/Torch/torch-nary-canonicalize.mlir
@@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
%0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_rsub_scalar_int
func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_rsub_scalar_float
func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_remainder_scalar_int
func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_remainder_scalar_float
func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_int
func.func @fold_aten_int_tensor_int() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_bool
func.func @fold_aten_int_tensor_bool() -> !torch.int {
// CHECK: %int1 = torch.constant.int 1
%cst_true = torch.vtensor.literal(dense<true> : tensor<i1>) : !torch.vtensor<[],i1>
%0 = torch.aten.Int.Tensor %cst_true : !torch.vtensor<[],i1> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_float
func.func @fold_aten_int_tensor_float() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3.1> : tensor<f32>) : !torch.vtensor<[],f32>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_int
func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%trunc = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_float
func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%floor = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_none
func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%none = torch.constant.none
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
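Side note (illustration only, not part of this diff) on the @fold_aten_div_tensor_mode_none expectation: with no rounding mode the fold takes the floating-point path, and 8/3 rounded to the nearest f32 is the value MLIR prints as 2.66666675. The same value from plain C++:

#include <cstdio>

// Shows why the CHECK line expects dense<2.66666675>: that is 8/3 correctly
// rounded to the nearest 32-bit float, printed with nine significant digits.
int main() {
  float r = 8.0f / 3.0f;
  std::printf("%.9g\n", r); // 2.66666675
  return 0;
}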