96 changes: 86 additions & 10 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
@@ -892,26 +893,101 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
// The non_blocking arg must be `False`.
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
nonBlocking)
return nullptr;
return {};
// The copy arg must be `False`.
if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
return nullptr;
return {};
// The memory_format arg must be `none`.
if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
return nullptr;
return {};

auto inputType = cast<BaseTensorType>(getSelf().getType());
auto resType = cast<BaseTensorType>(getType());
// If the types aren't equal, then we can't fold.
if (inputType != resType)
return nullptr;

// Fold when both the input tensor and result are of the same type.
// If the type does not have a statically known dtype, then we cannot fold.
// For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
// since the `unk` could be dynamically different for the operand and result.
if (!inputType.hasDtype())
return nullptr;
// Fold when both the input tensor and result are of the same type.
return getOperand(0);
if (inputType == resType && inputType.hasDtype())
return getOperand(0);

// Fold conversion of splat values.
auto elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
if (!elems || !elems.isSplat())
return {};

auto outVTy = dyn_cast<ValueTensorType>(getType());
if (!outVTy)
return {};

auto outShaped = outVTy.toBuiltinTensor();
if (!outShaped.hasStaticShape())
return {};

Type srcEltTy = inputType.getDtype();
Type dstEltTy = outVTy.getDtype();

// Handle integer destination.
if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
// any -> bool(i1).
if (dstI.isSignlessInteger(1)) {
bool truthy = false;
if (isa<mlir::FloatType>(srcEltTy)) {
const APFloat &floatVal = elems.getSplatValue<APFloat>();
truthy = !floatVal.isZero();
} else {
const APInt &intVal = elems.getSplatValue<APInt>();
truthy = !intVal.isZero();
}
return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy));
}
// float -> intN
if (auto srcF = dyn_cast<mlir::FloatType>(srcEltTy)) {
APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
bool isExact = false;
APFloat f = elems.getSplatValue<APFloat>();
APFloat::opStatus st =
f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, APInt(result));
return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
}
// intM -> intN
const APInt &v = elems.getSplatValue<APInt>();
auto isUnsigned = cast<IntegerType>(srcEltTy).isUnsignedInteger();
auto isSignless = cast<IntegerType>(srcEltTy).isSignlessInteger();
APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth())
: v.sextOrTrunc(dstI.getWidth());
return DenseElementsAttr::get(outShaped, casted);
}

// Handle float destination.
if (auto dstF = dyn_cast<mlir::FloatType>(dstEltTy)) {
const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();

// int -> float
if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
APFloat f(dstSem);
APFloat::opStatus st = f.convertFromAPInt(
elems.getSplatValue<APInt>(),
/*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(),
APFloat::rmNearestTiesToEven);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, f);
return {};
}

// floatX -> floatY
APFloat f = elems.getSplatValue<APFloat>();
bool losesInfo = false;
APFloat::opStatus st =
f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, f);
return {};
}

return {};
}

//===----------------------------------------------------------------------===//
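For context, the rounding choices in the new splat fold (APFloat::rmTowardZero for float-to-int, rmNearestTiesToEven for the float conversions, and "non-zero means true" for i1) are intended to match eager PyTorch's casting behavior. A small sanity-check sketch, independent of torch-mlir and assuming only a stock PyTorch install:

import torch

# float -> int truncates toward zero (the rmTowardZero case above).
assert torch.tensor([3.9]).to(torch.int32).item() == 3
assert torch.tensor([-3.9]).to(torch.int32).item() == -3

# float -> narrower float rounds to nearest-even (the rmNearestTiesToEven case).
assert torch.tensor([1.2], dtype=torch.float64).to(torch.float16).item() == 1.2001953125

# any -> bool: zero is False, everything else is True (the i1 case).
assert torch.tensor([40]).to(torch.bool).item() is True
assert torch.tensor([0.0]).to(torch.bool).item() is False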
176 changes: 176 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -370,3 +370,179 @@ def forward(self, x):
@register_test_case(module_factory=lambda: PrimsConvertElementTypeModule())
def PrimsConvertElementTypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))


# ==============================================================================


class ToDtypeConstIntFromDoubleModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([1.1], dtype=torch.float64)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.int64,
)


@register_test_case(module_factory=lambda: ToDtypeConstIntFromDoubleModule())
def ToDtypeConstIntFromDoubleModule_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstInt32FromInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([2147483648], dtype=torch.int64)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.int32,
)


@register_test_case(module_factory=lambda: ToDtypeConstInt32FromInt64Module())
def ToDtypeConstInt32FromInt64Module_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstFloat16FromFloat64Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([1.2345], dtype=torch.float64)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.float16,
)


@register_test_case(module_factory=lambda: ToDtypeConstFloat16FromFloat64Module())
def ToDtypeConstFloat16FromFloat64Module_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstBFloat16FromFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([-0.5101], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.bfloat16,
)


@register_test_case(module_factory=lambda: ToDtypeConstBFloat16FromFloat32Module())
def ToDtypeConstBFloat16FromFloat32Module_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstBoolFromInt32ZeroModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([0], dtype=torch.int32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.bool,
)


@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32ZeroModule())
def ToDtypeConstBoolFromInt32ZeroModule_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstBoolFromInt32NonZeroIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([32], dtype=torch.int32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.bool,
)


@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32NonZeroIntModule())
def ToDtypeConstBoolFromInt32NonZeroIntModule_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstBoolFromFloat32NonZeroNanModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([float("nan")], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.bool,
)


@register_test_case(
module_factory=lambda: ToDtypeConstBoolFromFloat32NonZeroNanModule()
)
def ToDtypeConstBoolFromFloat32NonZeroNanModule_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstFloat32FromBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([True], dtype=torch.bool)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.float32,
)


@register_test_case(module_factory=lambda: ToDtypeConstFloat32FromBoolModule())
def ToDtypeConstFloat32FromBoolModule_basic(module, tu: TestUtils):
module.forward()


class ToDtypeConstInt32FromBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([True], dtype=torch.bool)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.to(
self.const,
dtype=torch.int32,
)


@register_test_case(module_factory=lambda: ToDtypeConstInt32FromBoolModule())
def ToDtypeConstInt32FromBoolModule_basic(module, tu: TestUtils):
module.forward()
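For reference when reviewing, the expected eager-mode results of the edge-case modules above, derived from PyTorch's casting rules rather than from running the e2e suite, are roughly:

import torch

# ToDtypeConstIntFromDoubleModule: 1.1 truncates toward zero.
print(torch.tensor([1.1], dtype=torch.float64).to(torch.int64))  # tensor([1])

# ToDtypeConstInt32FromInt64Module: 2**31 wraps to INT32_MIN.
print(torch.tensor([2147483648], dtype=torch.int64).to(torch.int32))  # tensor([-2147483648], dtype=torch.int32)

# ToDtypeConstBoolFromFloat32NonZeroNanModule: NaN is non-zero, hence True.
print(torch.tensor([float("nan")], dtype=torch.float32).to(torch.bool))  # tensor([True])

# ToDtypeConstInt32FromBoolModule: True becomes 1.
print(torch.tensor([True]).to(torch.int32))  # tensor([1], dtype=torch.int32)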
88 changes: 88 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1762,6 +1762,94 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
return %0 : !torch.tensor
}

// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>) {
// CHECK-NOT: torch.aten.to.dtype
%false = torch.constant.bool false
%none = torch.constant.none

// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
%result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
: !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[2,3],f32>

// float32 splat → int32 (rmTowardZero)
%float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
%int3 = torch.constant.int 3 // torch.int32
// CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
%result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
: !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[4,4],si32>

// int64 splat (max int32 + 1) → int32 (trunc)
%int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
// CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
%result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
: !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[10],si32>

// float32 splat → float64
%float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
%int7 = torch.constant.int 7 // torch.float64
// CHECK: %[[R4:.*]] = torch.vtensor.literal(dense<2.7182800769805908> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
%result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
: !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[5,5],f64>

// float64 splat → float16
%float64_splat = torch.vtensor.literal(dense<1.2> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
%int5 = torch.constant.int 5 // torch.float16
// CHECK: %[[R5:.*]] = torch.vtensor.literal(dense<1.200200e+00> : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
%result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
: !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[3,3],f16>

// float32 splat → bfloat16
%float32_bf16 = torch.vtensor.literal(dense<-0.51> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
%int15 = torch.constant.int 15 // torch.bfloat16
// CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.117190e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
%result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
: !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[2,2],bf16>

// int32 splat → int64 (sign-extend)
%int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
%int4 = torch.constant.int 4 // torch.int64
// CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
: !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[4],si64>

// int32 splat → int16 (trunc)
%int32_trunc = torch.vtensor.literal(dense<32768> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
%int2 = torch.constant.int 2 // torch.int16
// CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<-32768> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
%result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
: !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[3],si16>

// int32 splat → bool (i1), non-zero
%int40_splat = torch.vtensor.literal(dense<40> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
%int11 = torch.constant.int 11 // torch.bool
// CHECK: %[[R9:.*]] = torch.vtensor.literal(dense<true> : tensor<2xi1>) : !torch.vtensor<[2],i1>
%result9 = torch.aten.to.dtype %int40_splat, %int11, %false, %false, %none
: !torch.vtensor<[2],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[2],i1>

// float32 splat → bool (i1), zero
%float_zero = torch.vtensor.literal(dense<0.0> : tensor<2xf32>) : !torch.vtensor<[2],f32>
// CHECK: %[[R10:.*]] = torch.vtensor.literal(dense<false> : tensor<2xi1>) : !torch.vtensor<[2],i1>
%result10 = torch.aten.to.dtype %float_zero, %int11, %false, %false, %none
: !torch.vtensor<[2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-> !torch.vtensor<[2],i1>

return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8, %result9, %result10
: !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>
}

// CHECK-LABEL: func.func @torch.aten.to.other$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
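The rounded splat constants in the @torch.aten.to.dtype$fold_splat CHECK lines (1.200200e+00, -5.117190e-01, 2.7182800769805908) are simply the inputs re-rounded to the destination element type. A sketch for reproducing them outside FileCheck, assuming a stock PyTorch install:

import torch

# f64 1.2 -> f16: nearest half-precision value, printed by MLIR as 1.200200e+00.
print(torch.tensor(1.2, dtype=torch.float64).to(torch.float16).item())  # 1.2001953125

# f32 -0.51 -> bf16: nearest bfloat16 value, printed by MLIR as -5.117190e-01.
print(torch.tensor(-0.51, dtype=torch.float32).to(torch.bfloat16).item())  # -0.51171875

# f32 2.71828 -> f64: the exact widened float32 value.
print(torch.tensor(2.71828, dtype=torch.float32).to(torch.float64).item())  # 2.7182800769805908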