@@ -85,6 +85,32 @@ func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor
8585
8686// -----
8787
88+ // CHECK-LABEL: @cast_int_float_static
89+ func.func @cast_int_float_static (%arg0 : !torch.vtensor <[5 ,?,?],f32 >) -> !torch.vtensor <[3 ],f32 > {
90+ // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
91+ // CHECK: %[[FLOAT2:.*]] = torch.constant.float 2.000000e+00
92+ // CHECK: %[[FLOAT3:.*]] = torch.constant.float 3.000000e+00
93+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[FLOAT1:.*]], %[[FLOAT2:.*]], %[[FLOAT3:.*]] : (!torch.float, !torch.float, !torch.float) -> !torch.list<float>
94+ // CHECK: %[[NONE:.*]] = torch.constant.none
95+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
96+ // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<float>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],f32>
97+ // CHECK: return %[[TENSOR]] : !torch.vtensor<[3],f32>
98+ %int6 = torch.constant.int 6
99+ %false = torch.constant.bool false
100+ %none = torch.constant.none
101+ %shape = torch.vtensor.literal (dense <[1 ,2 ,3 ]> : tensor <3 xsi64 >) : !torch.vtensor <[3 ],si64 >
102+ %cast_shape = torch.aten.to.dtype %shape , %int6 , %false , %false , %none : !torch.vtensor <[3 ],si64 >, !torch.int , !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[3 ],f32 >
103+ %dim = torch.constant.int 0
104+ %idx0 = torch.vtensor.literal (dense <0 > : tensor <si64 >) : !torch.vtensor <[],si64 >
105+ %select0 = torch.aten.index_select %cast_shape , %dim , %idx0 : !torch.vtensor <[3 ],f32 >, !torch.int , !torch.vtensor <[],si64 > -> !torch.vtensor <[],f32 >
106+ %item0 = torch.aten.item %select0 : !torch.vtensor <[],f32 > -> !torch.float
107+ %item_int0 = torch.aten.Int.Scalar %item0 : !torch.float -> !torch.int
108+ %list = torch.prim.ListConstruct %item_int0 : (!torch.int ) -> !torch.list <int >
109+ return %cast_shape : !torch.vtensor <[3 ],f32 >
110+ }
111+
112+ // -----
113+
88114// CHECK-LABEL: @shape_as_tensor_dim_item
89115func.func @shape_as_tensor_dim_item (%arg0 : !torch.vtensor <[5 ,?,?],f32 >) -> !torch.int {
90116 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
0 commit comments