@@ -59,6 +59,64 @@ func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2
 
 // -----
 
+// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) {
+func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) {
+// CHECK: %[[CST1:.*]] = arith.constant 1 : index
+// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32>
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32>
+// CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32>
+// CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32>
+// CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32):
+// CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+// CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+// CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+// CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index
+// CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index
+// CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index
+// CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32>
+// CHECK-NEXT: linalg.yield %[[EX]] : f32
+// CHECK-NEXT: } -> tensor<16x128x2x2xf32>
+// CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x16x66x66xf32>
+// CHECK-NEXT: %[[SLICE_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SLICE_EMPTY]] : tensor<2x16x66x66xf32>) -> tensor<2x16x66x66xf32>
+// CHECK-NEXT: %[[SLICE:.*]] = tensor.insert_slice %[[T0]] into %[[SLICE_FILLED]][0, 0, 0, 0] [2, 16, 33, 33] [1, 1, 2, 2] : tensor<2x16x33x33xf32> into tensor<2x16x66x66xf32>
+// CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32>
+// CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32>
+// CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d5 * 2 + d2, d6 * 2 + d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[SLICE]], %[[W_REV]] : tensor<2x16x66x66xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+// CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[ACC]] : f32
+// CHECK-NEXT: } -> tensor<2x128x64x64xf32>
+// CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32>
+// CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+// CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+// CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+// CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+// CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+// CHECK-NEXT: } -> tensor<16xf32>
+// CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32>
+// CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32>
+  return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
 // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>,