@@ -3891,6 +3891,52 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
38913891
38923892// -----
38933893
3894+ // CHECK-LABEL: func.func @torch.aten.convolution$valid_padding(
3895+ // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
3896+ // CHECK: %[[INPUT_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[INPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32> -> tensor<1x1x5x5xf32>
3897+ // CHECK: %[[WEIGHT_CONST:.*]] = "tosa.const"() <{values = dense<-7.486820e-03> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
3898+ // CHECK: %[[BIAS_CONST:.*]] = "tosa.const"() <{values = dense<0.536443591> : tensor<1xf32>}> : () -> tensor<1xf32>
3899+ // CHECK: %[[STRIDE_H:.*]] = torch.constant.int 1
3900+ // CHECK: %[[STRIDE_W:.*]] = torch.constant.int 1
3901+ // CHECK: %[[STRIDES_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE_H]], %[[STRIDE_W]] : (!torch.int, !torch.int) -> !torch.list<int>
3902+ // CHECK: %[[PADDING_VAL:.*]] = torch.constant.int 0
3903+ // CHECK: %[[PADDING_LIST:.*]] = torch.prim.ListConstruct %[[PADDING_VAL]] : (!torch.int) -> !torch.list<int>
3904+ // CHECK: %[[DILATION_H:.*]] = torch.constant.int 1
3905+ // CHECK: %[[DILATION_W:.*]] = torch.constant.int 1
3906+ // CHECK: %[[DILATIONS_LIST:.*]] = torch.prim.ListConstruct %[[DILATION_H]], %[[DILATION_W]] : (!torch.int, !torch.int) -> !torch.list<int>
3907+ // CHECK: %[[GROUPS_VAL:.*]] = torch.constant.bool false
3908+ // CHECK: %[[OUTPUT_PADDING_VAL:.*]] = torch.constant.int 0
3909+ // CHECK: %[[OUTPUT_PADDING_LIST:.*]] = torch.prim.ListConstruct %[[OUTPUT_PADDING_VAL]] : (!torch.int) -> !torch.list<int>
3910+ // CHECK: %[[CONV_DIMENSIONS:.*]] = torch.constant.int 1
3911+ // CHECK: %[[WEIGHT_TRANSPOSED:.*]] = tosa.transpose %[[WEIGHT_CONST]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
3912+ // CHECK: %[[INPUT_TRANSPOSED:.*]] = tosa.transpose %[[INPUT_BUILTIN]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x5x5xf32>) -> tensor<1x5x5x1xf32>
3913+ // CHECK: %[[ZERO_BIAS_OP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3914+ // CHECK: %[[ZERO_BIAS_OP_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3915+ // CHECK: %[[CONV_RESULT_TOSA:.*]] = tosa.conv2d %[[INPUT_TRANSPOSED]], %[[WEIGHT_TRANSPOSED]], %[[BIAS_CONST]], %[[ZERO_BIAS_OP]], %[[ZERO_BIAS_OP_2]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x5x5x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32>
3916+ // CHECK: %[[OUTPUT_TRANSPOSED:.*]] = tosa.transpose %[[CONV_RESULT_TOSA]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x5x5x1xf32>) -> tensor<1x1x5x5xf32>
3917+ // CHECK: %[[OUTPUT_TENSOR:.*]] = torch_c.from_builtin_tensor %[[OUTPUT_TRANSPOSED]] : tensor<1x1x5x5xf32> -> !torch.vtensor<[1,1,5,5],f32>
3918+ // CHECK: return %[[OUTPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32>
3919+ func.func @torch.aten.convolution$valid_padding (%arg0: !torch.vtensor <[1 ,1 ,5 ,5 ],f32 >) -> !torch.vtensor <[1 ,1 ,5 ,5 ],f32 > {
3920+ %0 = torch.vtensor.literal (dense <-7.486820e-03 > : tensor <1 x1 x1 x1 xf32 >) : !torch.vtensor <[1 ,1 ,1 ,1 ],f32 >
3921+ %1 = torch.vtensor.literal (dense <0.536443591 > : tensor <1 xf32 >) : !torch.vtensor <[1 ],f32 >
3922+ %int1 = torch.constant.int 1
3923+ %int1_0 = torch.constant.int 1
3924+ %2 = torch.prim.ListConstruct %int1 , %int1_0 : (!torch.int , !torch.int ) -> !torch.list <int >
3925+ %int0 = torch.constant.int 0
3926+ %3 = torch.prim.ListConstruct %int0 : (!torch.int ) -> !torch.list <int >
3927+ %int1_1 = torch.constant.int 1
3928+ %int1_2 = torch.constant.int 1
3929+ %4 = torch.prim.ListConstruct %int1_1 , %int1_2 : (!torch.int , !torch.int ) -> !torch.list <int >
3930+ %false = torch.constant.bool false
3931+ %int0_3 = torch.constant.int 0
3932+ %5 = torch.prim.ListConstruct %int0_3 : (!torch.int ) -> !torch.list <int >
3933+ %int1_4 = torch.constant.int 1
3934+ %6 = torch.aten.convolution %arg0 , %0 , %1 , %2 , %3 , %4 , %false , %5 , %int1_4 : !torch.vtensor <[1 ,1 ,5 ,5 ],f32 >, !torch.vtensor <[1 ,1 ,1 ,1 ],f32 >, !torch.vtensor <[1 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[1 ,1 ,5 ,5 ],f32 >
3935+ return %6 : !torch.vtensor <[1 ,1 ,5 ,5 ],f32 >
3936+ }
3937+
3938+ // -----
3939+
38943940// CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
38953941// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1,56,56],f32>) -> !torch.vtensor<[1,1,27,27],f32> {
38963942// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,56,56],f32> -> tensor<1x1x56x56xf32>
0 commit comments