Skip to content

Commit 617c1c7

Browse files
authored
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not required by the corresponding affine map. Fix the verifier to consider number of arguments for both.
1 parent b1413a6 commit 617c1c7

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
54055405
}
54065406

54075407
LogicalResult BindSymbolicShapeOp::verify() {
5408-
if (getShapeSymbols().empty())
5409-
return emitOpError() << "requires non-empty shapeSymbols";
5408+
if (getShapeSymbols().size() !=
5409+
getShapeExpressions().getValue().getNumSymbols())
5410+
return emitOpError()
5411+
<< "requires equal number of shape symbol args and symbol args to "
5412+
"the attached affine map, since they are 1:1 mapped";
54105413

54115414
for (auto symbol : getShapeSymbols()) {
54125415
Operation *definingOp = symbol.getDefiningOp();

test/Dialect/Torch/invalid.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
381381

382382
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
383383
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
384-
// expected-error @+1 {{op requires non-empty shapeSymbols}}
384+
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
385385
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
386386
return %arg0 : !torch.vtensor<[?],f32>
387387
}
388388

389389
// -----
390390

391+
// Verifier should not fail here since the op does not require shapeSymbols.
392+
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
393+
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
394+
return %arg0 : !torch.vtensor<[?],f32>
395+
}
396+
397+
// -----
398+
391399
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
392400
%int0 = torch.constant.int 0
393401
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}

0 commit comments

Comments
 (0)