@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
17621762 return %0 : !torch.tensor
17631763}
17641764
// CHECK-LABEL: func.func @torch.aten.to.dtype$fold_splat(
// Each case below feeds a splat `torch.vtensor.literal` into
// `torch.aten.to.dtype` with a constant target dtype and expects the cast to be
// replaced by a single literal of the destination element type. The integer
// passed as the dtype operand is the torch dtype code noted next to each
// constant. Where the converted value is exactly representable in the
// destination type the folded value is pinned; otherwise only the result type
// is matched.
func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>) {
  %false = torch.constant.bool false
  %none = torch.constant.none

  // int32 splat -> float32 (42 is exactly representable, so pin the value)
  %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
  %int6 = torch.constant.int 6 // torch.float32
  // CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
  %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
      : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,3],f32>

  // float32 splat -> int32 (rounds toward zero: 3.14159 becomes 3)
  %float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
  %int3 = torch.constant.int 3 // torch.int32
  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
      : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4,4],si32>

  // int64 splat holding the largest int32 value -> int32 (truncation keeps the value)
  %int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<2147483647> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
      : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[10],si32>

  // float32 splat -> float64 (value left unpinned: the widened f32 does not
  // print as the short decimal written in the source literal)
  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
  %int7 = torch.constant.int 7 // torch.float64
  // CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
      : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[5,5],f64>

  // float64 splat -> float16 (value left unpinned: narrowing rounds the value)
  %float64_splat = torch.vtensor.literal(dense<1.23456> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
  %int5 = torch.constant.int 5 // torch.float16
  // CHECK: %[[R5:.*]] = torch.vtensor.literal({{.*}} : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
      : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3,3],f16>

  // float32 splat -> bfloat16 (-0.5 is exactly representable, so pin the value)
  %float32_bf16 = torch.vtensor.literal(dense<-0.5> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
  %int15 = torch.constant.int 15 // torch.bfloat16
  // CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.000000e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
      : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,2],bf16>

  // int32 splat -> int64 (sign-extension keeps the negative value)
  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
  %int4 = torch.constant.int 4 // torch.int64
  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
      : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4],si64>

  // int32 splat -> int16 (32000 fits in int16, so truncation keeps the value)
  %int32_trunc = torch.vtensor.literal(dense<32000> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
  %int2 = torch.constant.int 2 // torch.int16
  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<32000> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
      : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3],si16>

  // The function must return exactly the eight folded literals, in order; this
  // also confirms every use of a cast result was replaced.
  // CHECK: return %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8
      : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>
}
1836+
17651837// CHECK-LABEL: func.func @torch.aten.to.other$basic(
17661838// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
17671839// CHECK: %[[NONE:.*]] = torch.constant.none
0 commit comments