Commit 9a72c65
[MLIR][ONNX] Add OnnxToTorch support for BatchNormalization and Concat ops.

This commit adds OnnxToTorch support for the BatchNormalization and Concat ops.

Signed-off-by: [email protected]
1 parent 85b86b3
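
At a glance, the new patterns rewrite the opaque torch.operator calls produced by the ONNX importer into their Torch-dialect equivalents. A minimal sketch of the intended BatchNormalization rewrite, mirroring the test cases in this commit (value names and shapes are illustrative):

  // Before: opaque ONNX operator in the torch dialect.
  %y = torch.operator "onnx.BatchNormalization"(%x, %scale, %b, %mean, %var) {torch.onnx.epsilon = 1.000000e-05 : f32} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32>

  // After: attributes materialized as constants, op replaced by
  // torch.aten.batch_norm with training and cudnn_enabled both false.
  %false = torch.constant.bool false
  %momentum = torch.constant.float 0.89999997615814208
  %eps = torch.constant.float 9.9999997473787516E-6
  %y = torch.aten.batch_norm %x, %scale, %b, %mean, %var, %false, %momentum, %eps, %false : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>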

2 files changed: 189 additions, 0 deletions


lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 61 additions & 0 deletions
@@ -165,6 +165,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
+  patterns.onOp("BatchNormalization", 15,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input, weight, bias, runningMean, runningVar;
+                  bool training;
+                  float momentum, eps;
+                  if (binder.s64BoolAttr(training, "training_mode", 0))
+                    return failure();
+                  if (training) {
+                    // TODO: Add support for training = true
+                    return rewriter.notifyMatchFailure(
+                        binder.op, "unsupported conversion: training = true");
+                  }
+
+                  if (binder.tensorOperandAtIndex(input, 0) ||
+                      binder.tensorOperandAtIndex(weight, 1) ||
+                      binder.tensorOperandAtIndex(bias, 2) ||
+                      binder.tensorOperandAtIndex(runningMean, 3) ||
+                      binder.tensorOperandAtIndex(runningVar, 4) ||
+                      binder.f32FloatAttr(momentum, "momentum", 0.9) ||
+                      binder.f32FloatAttr(eps, "epsilon", 1e-05) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+
+                  Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+                      binder.getLoc(), false);
+                  Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getF64FloatAttr(momentum));
+                  Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getF64FloatAttr(eps));
+
+                  rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
+                      binder.op, resultType, input, weight, bias, runningMean,
+                      runningVar, /*training=*/cstFalse, cstMomentum, cstEps,
+                      /*cudnn_enabled=*/cstFalse);
+                  return success();
+                });
   patterns.onOp(
       "AveragePool", 19,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -426,6 +463,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         }
         return failure();
       });
+  patterns.onOp(
+      "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        SmallVector<Value> tensors;
+        int64_t dim;
+        if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
+            binder.s64IntegerAttr(dim, "axis", 0) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Type listElemType =
+            tensors[0]
+                .getType()
+                .cast<Torch::BaseTensorType>()
+                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                      /*optionalDtype=*/nullptr);
+        Type listType = Torch::ListType::get(listElemType);
+        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.op->getLoc(), listType, tensors);
+        Value cstDim = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(dim));
+        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
+                                                      tensorList, cstDim);
+        return success();
+      });
   patterns.onOp(
       "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;
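
For reference, with training_mode = 0 ONNX BatchNormalization computes Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B per channel, which is exactly what torch.aten.batch_norm computes when its training flag is false. That is why the pattern can forward the five tensor operands unchanged and only needs to materialize momentum, epsilon, and the two boolean flags as constants.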

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 128 additions & 0 deletions
@@ -592,3 +592,131 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc
   %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32>
   return %0 : !torch.vtensor<[1,2,7,3],f32>
 }
+
+// CHECK-LABEL: @test_batchnorm_epsilon
+func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
+  // CHECK: %[[EPS:.*]] = torch.constant.float 0.0099999997764825821
+  // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
+  %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 0.00999999977 : f32} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32>
+  return %0 : !torch.vtensor<[2,3,4,5],f32>
+}
+
+// CHECK-LABEL: @test_batchnorm_example
+func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
+  // CHECK: %[[EPS:.*]] = torch.constant.float 9.9999997473787516E-6
+  // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
+  %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32>
+  return %0 : !torch.vtensor<[2,3,4,5],f32>
+}
+
+// CHECK-LABEL: @test_concat_1d_axis_0
+func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// CHECK-LABEL: @test_concat_1d_axis_negative_1
+func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -1
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// CHECK-LABEL: @test_concat_2d_axis_0
+func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32>
+  return %0 : !torch.vtensor<[4,2],f32>
+}
+
+// CHECK-LABEL: @test_concat_2d_axis_1
+func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32>
+  return %0 : !torch.vtensor<[2,4],f32>
+}
+
+// CHECK-LABEL: @test_concat_2d_axis_negative_1
+func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -1
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32>
+  return %0 : !torch.vtensor<[2,4],f32>
+}
+
+// CHECK-LABEL: @test_concat_2d_axis_negative_2
+func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -2
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32>
+  return %0 : !torch.vtensor<[4,2],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_0
+func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32>
+  return %0 : !torch.vtensor<[4,2,2],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_1
+func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32>
+  return %0 : !torch.vtensor<[2,4,2],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_2
+func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 2
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32>
+  return %0 : !torch.vtensor<[2,2,4],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_negative_1
+func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -1
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32>
+  return %0 : !torch.vtensor<[2,2,4],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_negative_2
+func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -2
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32>
+  return %0 : !torch.vtensor<[2,4,2],f32>
+}
+
+// CHECK-LABEL: @test_concat_3d_axis_negative_3
+func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[DIM:.*]] = torch.constant.int -3
+  // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
+  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32>
+  return %0 : !torch.vtensor<[4,2,2],f32>
+}
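
The new test functions follow the file's existing lit/FileCheck conventions and are driven by the RUN line already at the top of this file (not shown in the diff). Assuming the usual setup for TorchOnnxToTorch conversion tests, that RUN line looks along these lines:

  // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s

The exact flags are an assumption based on the standard configuration of this test directory.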
