diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bf2a605c950b..29b55969227a 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Casting.h"
@@ -892,26 +893,101 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
   // The non_blocking arg must be `False`.
   if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
       nonBlocking)
-    return nullptr;
+    return {};
   // The copy arg must be `False`.
   if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-    return nullptr;
+    return {};
   // The memory_format arg must be `none`.
   if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
-    return nullptr;
+    return {};
 
   auto inputType = cast<BaseTensorType>(getSelf().getType());
   auto resType = cast<BaseTensorType>(getType());
-  // If the types aren't equal, then we can't fold.
-  if (inputType != resType)
-    return nullptr;
+
+  // Fold when both the input tensor and result are of the same type.
   // If the type does not have a statically known dtype, then we cannot fold.
   // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
   // since the `unk` could be dynamically different for the operand and result.
-  if (!inputType.hasDtype())
-    return nullptr;
-  // Fold when both the input tensor and result are of the same type.
-  return getOperand(0);
+  if (inputType == resType && inputType.hasDtype())
+    return getOperand(0);
+
+  // Fold conversion of splat values.
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (!elems || !elems.isSplat())
+    return {};
+
+  auto outVTy = dyn_cast<ValueTensorType>(getType());
+  if (!outVTy)
+    return {};
+
+  auto outShaped = outVTy.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
+    return {};
+
+  Type srcEltTy = inputType.getDtype();
+  Type dstEltTy = outVTy.getDtype();
+
+  // Handle integer destination.
+  if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
+    // any -> bool(i1).
+    if (dstI.isSignlessInteger(1)) {
+      bool truthy = false;
+      if (isa<FloatType>(srcEltTy)) {
+        const APFloat &floatVal = elems.getSplatValue<APFloat>();
+        truthy = !floatVal.isZero();
+      } else {
+        const APInt &intVal = elems.getSplatValue<APInt>();
+        truthy = !intVal.isZero();
+      }
+      return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy));
+    }
+    // float -> intN
+    if (auto srcF = dyn_cast<FloatType>(srcEltTy)) {
+      APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
+      bool isExact = false;
+      APFloat f = elems.getSplatValue<APFloat>();
+      APFloat::opStatus st =
+          f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, APInt(result));
+      return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
+    }
+    // intM -> intN
+    const APInt &v = elems.getSplatValue<APInt>();
+    auto isUnsigned = cast<IntegerType>(srcEltTy).isUnsignedInteger();
+    auto isSignless = cast<IntegerType>(srcEltTy).isSignlessInteger();
+    APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth())
+                                            : v.sextOrTrunc(dstI.getWidth());
+    return DenseElementsAttr::get(outShaped, casted);
+  }
+
+  // Handle float destination.
+  if (auto dstF = dyn_cast<FloatType>(dstEltTy)) {
+    const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();
+
+    // int -> float
+    if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
+      APFloat f(dstSem);
+      APFloat::opStatus st = f.convertFromAPInt(
+          elems.getSplatValue<APInt>(),
+          /*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(),
+          APFloat::rmNearestTiesToEven);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, f);
+      return {};
+    }
+
+    // floatX -> floatY
+    APFloat f = elems.getSplatValue<APFloat>();
+    bool losesInfo = false;
+    APFloat::opStatus st =
+        f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
+    if (st == APFloat::opOK || st == APFloat::opInexact)
+      return DenseElementsAttr::get(outShaped, f);
+    return {};
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
index 156ed7959351..43a7584a7eb3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -370,3 +370,179 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: PrimsConvertElementTypeModule())
 def PrimsConvertElementTypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+
+# ==============================================================================
+
+
+class ToDtypeConstIntFromDoubleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.1], dtype=torch.float64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.int64,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstIntFromDoubleModule())
+def ToDtypeConstIntFromDoubleModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstInt32FromInt64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([2147483648], dtype=torch.int64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.int32,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstInt32FromInt64Module())
+def ToDtypeConstInt32FromInt64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstFloat16FromFloat64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.2345], dtype=torch.float64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.float16,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstFloat16FromFloat64Module())
+def ToDtypeConstFloat16FromFloat64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBFloat16FromFloat32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([-0.5101], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.bfloat16,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstBFloat16FromFloat32Module())
+def ToDtypeConstBFloat16FromFloat32Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBoolFromInt32ZeroModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([0], dtype=torch.int32)
+
+    
@export + @annotate_args([None]) + def forward(self): + return torch.ops.aten.to( + self.const, + dtype=torch.bool, + ) + + +@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32ZeroModule()) +def ToDtypeConstBoolFromInt32ZeroModule_basic(module, tu: TestUtils): + module.forward() + + +class ToDtypeConstBoolFromInt32NonZeroIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.tensor([32], dtype=torch.int32) + + @export + @annotate_args([None]) + def forward(self): + return torch.ops.aten.to( + self.const, + dtype=torch.bool, + ) + + +@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32NonZeroIntModule()) +def ToDtypeConstBoolFromInt32NonZeroIntModule_basic(module, tu: TestUtils): + module.forward() + + +class ToDtypeConstBoolFromFloat32NonZeroNanModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.tensor([float("nan")], dtype=torch.float32) + + @export + @annotate_args([None]) + def forward(self): + return torch.ops.aten.to( + self.const, + dtype=torch.bool, + ) + + +@register_test_case( + module_factory=lambda: ToDtypeConstBoolFromFloat32NonZeroNanModule() +) +def ToDtypeConstBoolFromFloat32NonZeroNanModule_basic(module, tu: TestUtils): + module.forward() + + +class ToDtypeConstFloat32FromBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.tensor([True], dtype=torch.bool) + + @export + @annotate_args([None]) + def forward(self): + return torch.ops.aten.to( + self.const, + dtype=torch.float32, + ) + + +@register_test_case(module_factory=lambda: ToDtypeConstFloat32FromBoolModule()) +def ToDtypeConstFloat32FromBoolModule_basic(module, tu: TestUtils): + module.forward() + + +class ToDtypeConstInt32FromBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.tensor([True], dtype=torch.bool) + + @export + @annotate_args([None]) + def forward(self): + return torch.ops.aten.to( + self.const, + dtype=torch.int32, + ) + + +@register_test_case(module_factory=lambda: ToDtypeConstInt32FromBoolModule()) +def ToDtypeConstInt32FromBoolModule_basic(module, tu: TestUtils): + module.forward() diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 34fd1b886e26..e410d52d5576 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1762,6 +1762,94 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch return %0 : !torch.tensor } +// CHECK-LABEL: @torch.aten.to.dtype$fold_splat( +func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>) { + // CHECK-NOT: torch.aten.to.dtype + %false = torch.constant.bool false + %none = torch.constant.none + + // int32 splat → float32 + %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32> + %int6 = torch.constant.int 6 // torch.float32 + // CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32> + %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none + : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none + -> !torch.vtensor<[2,3],f32> + + // float32 splat → int32 (rmTowardZero) + %float_splat = 
torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
+  %int3 = torch.constant.int 3 // torch.int32
+  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
+  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
+    : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[4,4],si32>
+
+  // int64 splat (max int32 + 1) → int32 (trunc)
+  %int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
+  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
+  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
+    : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[10],si32>
+
+  // float32 splat → float64
+  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
+  %int7 = torch.constant.int 7 // torch.float64
+  // CHECK: %[[R4:.*]] = torch.vtensor.literal(dense<2.7182800769805908> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
+  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
+    : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[5,5],f64>
+
+  // float64 splat → float16
+  %float64_splat = torch.vtensor.literal(dense<1.2> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
+  %int5 = torch.constant.int 5 // torch.float16
+  // CHECK: %[[R5:.*]] = torch.vtensor.literal(dense<1.200200e+00> : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
+  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
+    : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[3,3],f16>
+
+  // float32 splat → bfloat16
+  %float32_bf16 = torch.vtensor.literal(dense<-0.51> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
+  %int15 = torch.constant.int 15 // torch.bfloat16
+  // CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.117190e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
+  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
+    : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[2,2],bf16>
+
+  // int32 splat → int64 (sign-extend)
+  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  %int4 = torch.constant.int 4 // torch.int64
+  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
+    : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[4],si64>
+
+  // int32 splat → int16 (trunc)
+  %int32_trunc = torch.vtensor.literal(dense<32768> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
+  %int2 = torch.constant.int 2 // torch.int16
+  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<-32768> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
+  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
+    : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[3],si16>
+
+  // int32 splat → bool (i1), non-zero
+  %int40_splat = torch.vtensor.literal(dense<40> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
+  %int11 = torch.constant.int 11 // torch.bool
+  // CHECK: %[[R9:.*]] = torch.vtensor.literal(dense<true> : tensor<2xi1>) : !torch.vtensor<[2],i1>
+  %result9 = torch.aten.to.dtype %int40_splat, %int11, %false, %false, %none
+    : !torch.vtensor<[2],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[2],i1>
+
+  // float32 splat → bool (i1), zero
+  %float_zero = torch.vtensor.literal(dense<0.0> : tensor<2xf32>) : !torch.vtensor<[2],f32>
+  // CHECK: %[[R11:.*]] = torch.vtensor.literal(dense<false> : tensor<2xi1>) : !torch.vtensor<[2],i1>
+  %result10 = torch.aten.to.dtype %float_zero, %int11, %false, %false, %none
+    : !torch.vtensor<[2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+    -> !torch.vtensor<[2],i1>
+
+  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8, %result9, %result10
+    : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>
+}
+
 // CHECK-LABEL: func.func @torch.aten.to.other$basic(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 4bde19ac15a6..8fe502a7d686 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -159,21 +159,15 @@ func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vt
 
 // CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
 // CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
-// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
-// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
-// CHECK: %[[NONE:.+]] = torch.constant.none
-// CHECK: %[[FALSE:.+]] = torch.constant.bool false
-// CHECK: %[[INT5:.+]] = torch.constant.int 5
-// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
+// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
+// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
 // CHECK: %[[INT0:.+]] = torch.constant.int 0
 // CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
 // CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
 // CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
-// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
-// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
-// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
-// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], 
%[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16> -// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V2]], %[[V1]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V0]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> // CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> // CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> // CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
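
For reference (not part of the patch): a minimal PyTorch sketch of the eager-mode casting semantics the new splat folder is expected to reproduce — truncation toward zero for float→int, two's-complement wraparound for narrowing integer casts, and any non-zero value (including NaN) mapping to True for bool. The values mirror the test cases above; the snippet is illustrative only.

```python
import torch

# float -> int truncates toward zero (the folder uses APFloat::rmTowardZero).
assert torch.tensor([3.14159]).to(torch.int32).item() == 3

# Narrowing int casts wrap like zextOrTrunc/sextOrTrunc (two's complement).
assert torch.tensor([2147483648], dtype=torch.int64).to(torch.int32).item() == -2147483648

# Any non-zero value, including NaN, becomes True when cast to bool.
assert torch.tensor([float("nan")]).to(torch.bool).item() is True

# bool -> float produces 0.0/1.0.
assert torch.tensor([True]).to(torch.float32).item() == 1.0
```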