@@ -592,3 +592,131 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc
592592 %0 = torch.operator " onnx.ConvTranspose" (%arg0 , %arg1 ) {torch.onnx.pads = [1 : si64 , 2 : si64 , 1 : si64 , 2 : si64 ], torch.onnx.strides = [3 : si64 , 2 : si64 ]} : (!torch.vtensor <[1 ,1 ,3 ,3 ],f32 >, !torch.vtensor <[1 ,2 ,3 ,3 ],f32 >) -> !torch.vtensor <[1 ,2 ,7 ,3 ],f32 >
593593 return %0 : !torch.vtensor <[1 ,2 ,7 ,3 ],f32 >
594594 }
595+
596+ // CHECK-LABEL: @test_batchnorm_epsilon
597+ func.func @test_batchnorm_epsilon (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >, %arg1: !torch.vtensor <[3 ],f32 >, %arg2: !torch.vtensor <[3 ],f32 >, %arg3: !torch.vtensor <[3 ],f32 >, %arg4: !torch.vtensor <[3 ],f32 >) -> !torch.vtensor <[2 ,3 ,4 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 15 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
598+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
599+ // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
600+ // CHECK: %[[EPS:.*]] = torch.constant.float 0.0099999997764825821
601+ // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
602+ %0 = torch.operator " onnx.BatchNormalization" (%arg0 , %arg1 , %arg2 , %arg3 , %arg4 ) {torch.onnx.epsilon = 0.00999999977 : f32 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >) -> !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >
603+ return %0 : !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >
604+ }
605+
606+ // CHECK-LABEL: @test_batchnorm_example
607+ func.func @test_batchnorm_example (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >, %arg1: !torch.vtensor <[3 ],f32 >, %arg2: !torch.vtensor <[3 ],f32 >, %arg3: !torch.vtensor <[3 ],f32 >, %arg4: !torch.vtensor <[3 ],f32 >) -> !torch.vtensor <[2 ,3 ,4 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 15 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
608+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
609+ // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
610+ // CHECK: %[[EPS:.*]] = torch.constant.float 9.9999997473787516E-6
611+ // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
612+ %0 = torch.operator " onnx.BatchNormalization" (%arg0 , %arg1 , %arg2 , %arg3 , %arg4 ) : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >, !torch.vtensor <[3 ],f32 >) -> !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >
613+ return %0 : !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >
614+ }
615+
616+ // CHECK-LABEL: @test_concat_1d_axis_0
617+ func.func @test_concat_1d_axis_0 (%arg0: !torch.vtensor <[2 ],f32 >, %arg1: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
618+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
619+ // CHECK: %[[DIM:.*]] = torch.constant.int 0
620+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
621+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ],f32 >, !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[4 ],f32 >
622+ return %0 : !torch.vtensor <[4 ],f32 >
623+ }
624+
625+ // CHECK-LABEL: @test_concat_1d_axis_negative_1
626+ func.func @test_concat_1d_axis_negative_1 (%arg0: !torch.vtensor <[2 ],f32 >, %arg1: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
627+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
628+ // CHECK: %[[DIM:.*]] = torch.constant.int -1
629+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
630+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ],f32 >, !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[4 ],f32 >
631+ return %0 : !torch.vtensor <[4 ],f32 >
632+ }
633+
634+ // CHECK-LABEL: @test_concat_2d_axis_0
635+ func.func @test_concat_2d_axis_0 (%arg0: !torch.vtensor <[2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
636+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
637+ // CHECK: %[[DIM:.*]] = torch.constant.int 0
638+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
639+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ],f32 >
640+ return %0 : !torch.vtensor <[4 ,2 ],f32 >
641+ }
642+
643+ // CHECK-LABEL: @test_concat_2d_axis_1
644+ func.func @test_concat_2d_axis_1 (%arg0: !torch.vtensor <[2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
645+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
646+ // CHECK: %[[DIM:.*]] = torch.constant.int 1
647+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
648+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ],f32 >
649+ return %0 : !torch.vtensor <[2 ,4 ],f32 >
650+ }
651+
652+ // CHECK-LABEL: @test_concat_2d_axis_negative_1
653+ func.func @test_concat_2d_axis_negative_1 (%arg0: !torch.vtensor <[2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
654+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
655+ // CHECK: %[[DIM:.*]] = torch.constant.int -1
656+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
657+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ],f32 >
658+ return %0 : !torch.vtensor <[2 ,4 ],f32 >
659+ }
660+
661+ // CHECK-LABEL: @test_concat_2d_axis_negative_2
662+ func.func @test_concat_2d_axis_negative_2 (%arg0: !torch.vtensor <[2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
663+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
664+ // CHECK: %[[DIM:.*]] = torch.constant.int -2
665+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
666+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ],f32 >
667+ return %0 : !torch.vtensor <[4 ,2 ],f32 >
668+ }
669+
670+ // CHECK-LABEL: @test_concat_3d_axis_0
671+ func.func @test_concat_3d_axis_0 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
672+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
673+ // CHECK: %[[DIM:.*]] = torch.constant.int 0
674+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
675+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ,2 ],f32 >
676+ return %0 : !torch.vtensor <[4 ,2 ,2 ],f32 >
677+ }
678+
679+ // CHECK-LABEL: @test_concat_3d_axis_1
680+ func.func @test_concat_3d_axis_1 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
681+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
682+ // CHECK: %[[DIM:.*]] = torch.constant.int 1
683+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
684+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ,2 ],f32 >
685+ return %0 : !torch.vtensor <[2 ,4 ,2 ],f32 >
686+ }
687+
688+ // CHECK-LABEL: @test_concat_3d_axis_2
689+ func.func @test_concat_3d_axis_2 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
690+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
691+ // CHECK: %[[DIM:.*]] = torch.constant.int 2
692+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
693+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = 2 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 >
694+ return %0 : !torch.vtensor <[2 ,2 ,4 ],f32 >
695+ }
696+
697+ // CHECK-LABEL: @test_concat_3d_axis_negative_1
698+ func.func @test_concat_3d_axis_negative_1 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
699+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
700+ // CHECK: %[[DIM:.*]] = torch.constant.int -1
701+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
702+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 >
703+ return %0 : !torch.vtensor <[2 ,2 ,4 ],f32 >
704+ }
705+
706+ // CHECK-LABEL: @test_concat_3d_axis_negative_2
707+ func.func @test_concat_3d_axis_negative_2 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
708+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
709+ // CHECK: %[[DIM:.*]] = torch.constant.int -2
710+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
711+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[2 ,4 ,2 ],f32 >
712+ return %0 : !torch.vtensor <[2 ,4 ,2 ],f32 >
713+ }
714+
715+ // CHECK-LABEL: @test_concat_3d_axis_negative_3
716+ func.func @test_concat_3d_axis_negative_3 (%arg0: !torch.vtensor <[2 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
717+ // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
718+ // CHECK: %[[DIM:.*]] = torch.constant.int -3
719+ // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
720+ %0 = torch.operator " onnx.Concat" (%arg0 , %arg1 ) {torch.onnx.axis = -3 : si64 } : (!torch.vtensor <[2 ,2 ,2 ],f32 >, !torch.vtensor <[2 ,2 ,2 ],f32 >) -> !torch.vtensor <[4 ,2 ,2 ],f32 >
721+ return %0 : !torch.vtensor <[4 ,2 ,2 ],f32 >
722+ }
0 commit comments