@@ -4354,8 +4354,8 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
43544354
43554355// -----
43564356// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
4357- // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
4358- // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
4357+ // CHECK-SAME: %[[ARG_INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
4358+ // CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG_INPUT]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
43594359// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
43604360// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
43614361// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
@@ -4365,17 +4365,17 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
43654365// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
43664366// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
43674367// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
4368- // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
4369- // CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4370- // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4371- // CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
4372- // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4373- // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4374- // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
4375- // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
4376- // CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
4377- // CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
4378- // CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
4368+ // CHECK: %[[NHWC_TRANSPOSE:.*]] = tosa.transpose %[[INPUT_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
4369+ // CHECK: %[[PADDING_SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4370+ // CHECK: %[[PAD_FILL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4371+ // CHECK: %[[PADDED_NHWC:.*]] = tosa.pad %[[NHWC_TRANSPOSE]], %[[PADDING_SHAPE]], %[[PAD_FILL]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
4372+ // CHECK: %[[AVG_POOL_LHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4373+ // CHECK: %[[AVG_POOL_RHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4374+ // CHECK: %[[AVG_POOL_RESULT:.*]] = tosa.avg_pool2d %[[PADDED_NHWC]], %[[AVG_POOL_LHS_ZP]], %[[AVG_POOL_RHS_ZP]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
4375+ // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[AVG_POOL_RESULT]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
4376+ // CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[RESULT_NCHW]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
4377+ // CHECK: %[[TORCH_RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
4378+ // CHECK: return %[[TORCH_RESULT]] : !torch.vtensor<[1,192,35,35],f32>
43794379// CHECK: }
43804380func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
43814381 %int0 = torch.constant.int 0
@@ -4394,30 +4394,30 @@ func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,
43944394
43954395// -----
43964396// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
4397- // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
4398- // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
4397+ // CHECK-SAME: %[[ARG_INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
4398+ // CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG_INPUT]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
43994399// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
44004400// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
44014401// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
44024402// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
44034403// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
44044404// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
44054405// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
4406- // CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
4407- // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
4408- // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
4409- // CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4410- // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4411- // CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
4412- // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4413- // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4414- // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
4415- // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
4416- // CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
4417- // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
4418- // CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
4419- // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
4420- // CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
4406+ // CHECK: %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
4407+ // CHECK: %[[RESHAPED_INPUT:.*]] = tosa.reshape %[[INPUT_TENSOR]], %[[RESHAPE_SHAPE]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
4408+ // CHECK: %[[TRANSPOSED_NHWC:.*]] = tosa.transpose %[[RESHAPED_INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
4409+ // CHECK: %[[PADDING_SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4410+ // CHECK: %[[PAD_FILL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4411+ // CHECK: %[[PADDED_NHWC:.*]] = tosa.pad %[[TRANSPOSED_NHWC]], %[[PADDING_SHAPE]], %[[PAD_FILL]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
4412+ // CHECK: %[[AVG_POOL_LHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4413+ // CHECK: %[[AVG_POOL_RHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4414+ // CHECK: %[[AVG_POOL_RESULT:.*]] = tosa.avg_pool2d %[[PADDED_NHWC]], %[[AVG_POOL_LHS_ZP]], %[[AVG_POOL_RHS_ZP]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
4415+ // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[AVG_POOL_RESULT]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
4416+ // CHECK: %[[RESHAPE_BACK_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
4417+ // CHECK: %[[RESHAPED_BACK:.*]] = tosa.reshape %[[RESULT_NCHW]], %[[RESHAPE_BACK_SHAPE]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
4418+ // CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[RESHAPED_BACK]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
4419+ // CHECK: %[[TORCH_RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
4420+ // CHECK: return %[[TORCH_RESULT]] : !torch.vtensor<[1,512,10],f32>
44214421// CHECK: }
44224422func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
44234423 %int1 = torch.constant.int 1
0 commit comments