@@ -40,6 +40,82 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
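+// ONNX Gemm computes Y = alpha * A' * B' + beta * C (A'/B' optionally
+// transposed, alpha = beta = 1 by default), so with all attributes at their
+// defaults the lowering should reduce to a plain matmul plus a tensor add.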
+// CHECK-LABEL: func.func @test_gemm_default
+func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
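+// With transA = 1 the lowering is expected to transpose %arg0 ([5,3] -> [3,5])
+// via torch.aten.transpose.int before the matmul; alpha and beta stay at 1.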
+// CHECK-LABEL: func.func @test_gemm_transposeA
+func.func @test_gemm_transposeA(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I1]] : !torch.vtensor<[5,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,5],f32>
+  // CHECK: %[[MM:.+]] = torch.aten.mm %[[TRANS]], %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transA = 1 : si64} : (!torch.vtensor<[5,3],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
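+// Same pattern for transB = 1: %arg1 is transposed ([4,5] -> [5,4]) so the
+// matmul shapes line up.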
+// CHECK-LABEL: func.func @test_gemm_transposeB
+func.func @test_gemm_transposeB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[4,5],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg1, %[[I0]], %[[I1]] : !torch.vtensor<[4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,4],f32>
+  // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %[[TRANS]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transB = 1 : si64} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
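+// With only alpha set, no separate multiply is needed: torch.aten.add.Tensor
+// computes self + alpha * other, so the lowering swaps the operands and
+// passes alpha as the add's scalar, giving C + alpha * (A x B).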
+// CHECK-LABEL: func.func @test_gemm_alpha
+func.func @test_gemm_alpha(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.add.Tensor %arg2, %[[MM]], %[[ALPHA]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
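+// Likewise, a lone beta folds into the add's scalar on the bias operand:
+// (A x B) + beta * C, with the operands in their usual order.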
+// CHECK-LABEL: func.func @test_gemm_beta
+func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.beta = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
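+// When both scalars are non-default only one can ride on the add, so the
+// lowering materializes an explicit torch.aten.mul.Scalar for alpha and keeps
+// beta as the add's scalar: alpha * (A x B) + beta * C.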
+// CHECK-LABEL: func.func @test_gemm_alpha_beta
+func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[MUL:.+]] = torch.aten.mul.Scalar %[[MM]], %[[ALPHA]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MUL]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 2.500000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_leaky_relu
 func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} {
   // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2