@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4218,3 +4185,89 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
   %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %false, %3, %int1 : !torch.vtensor<[2,2,6,6],si8>, !torch.vtensor<[8,2,3,3],si8>, !torch.vtensor<[8],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,8,4,4],si32>
   return %4 : !torch.vtensor<[2,8,4,4],si32>
 }
+
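+// avg_pool2d with count_include_pad = true: the lowering zero-pads the input
+// with tosa.pad up front and then runs tosa.avg_pool2d with pad = 0, so every
+// window divides by the full kernel size, matching PyTorch's
+// count_include_pad = true semantics.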
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+
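+// avg_pool1d with count_include_pad = true follows the same approach: the
+// 3-d input is reshaped to 4-d with a unit trailing dimension, zero-padded
+// along the pooled dimension, pooled with tosa.avg_pool2d, and reshaped back.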
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}