Skip to content

Commit a33d123

Browse files
authored
[onnx] Fix onnx.Shape lowering with scalar input (#3716)
Address nod-ai/SHARK-ModelDev#826
1 parent 9938abf commit a33d123

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,29 +1662,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
16621662
auto shapeType = Torch::ValueTensorType::get(
16631663
binder.op->getContext(), SmallVector<int64_t>{inputRank},
16641664
resultType.getOptionalDtype());
1665-
16661665
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
16671666
binder.getLoc(), shapeType, operand);
16681667

1668+
if (inputRank == 0) {
1669+
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
1670+
binder.op, resultType, shape);
1671+
return success();
1672+
}
1673+
16691674
if (start == 0 && end == -1) {
16701675
rewriter.replaceOp(binder.op, shape);
16711676
return success();
16721677
}
16731678

16741679
Value sv = rewriter.create<Torch::ConstantIntOp>(
16751680
binder.getLoc(), rewriter.getI64IntegerAttr(start));
1676-
16771681
Value ev = rewriter.create<Torch::ConstantIntOp>(
16781682
binder.getLoc(), rewriter.getI64IntegerAttr(end));
1679-
16801683
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
1681-
16821684
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
16831685

1684-
shape = rewriter.create<Torch::AtenSliceTensorOp>(
1685-
binder.getLoc(), resultType, shape, dim, sv, ev, step);
1686-
1687-
rewriter.replaceOp(binder.op, shape);
1686+
rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
1687+
binder.op, resultType, shape, dim, sv, ev, step);
16881688
return success();
16891689
});
16901690

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,6 +2833,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>)
28332833
return %0 : !torch.vtensor<[1],si64>
28342834
}
28352835

2836+
// -----
2837+
2838+
// CHECK-LABEL: func.func @test_shape_scalar
2839+
func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
2840+
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64>
2841+
// CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
2842+
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64>
2843+
return %0: !torch.vtensor<[?],si64>
2844+
}
28362845

28372846
// -----
28382847

0 commit comments

Comments
 (0)