Skip to content

Commit 99115dc

Browse files
authored
[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)
Addresses <iree-org/iree#19262 (comment)>
1 parent 0913b96 commit 99115dc

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2593,16 +2593,22 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
25932593
// first the input tensor is flattened to 1d tensor and then the reduction
25942594
// happens on the 0th dimension.
25952595
if (isa<Torch::NoneType>(dim.getType())) {
2596-
BaseTensorType flattenType =
2597-
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
2598-
{kUnknownSize}, inputType.getOptionalDtype()));
2599-
Value zero =
2600-
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
2601-
Value end = rewriter.create<ConstantIntOp>(
2602-
loc, rewriter.getI64IntegerAttr(inputRank - 1));
2596+
Value zero = rewriter.create<ConstantIntOp>(loc, 0);
26032597
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
2604-
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
2605-
zero, end);
2598+
if (inputType.getSizes().size() > 1) {
2599+
int64_t flattenSize = Torch::kUnknownSize;
2600+
if (inputType.areAllSizesKnown()) {
2601+
flattenSize = 1;
2602+
for (int64_t sze : inputType.getSizes())
2603+
flattenSize *= sze;
2604+
}
2605+
auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
2606+
{flattenSize}, inputType.getOptionalDtype()));
2607+
Value end = rewriter.create<ConstantIntOp>(
2608+
loc, rewriter.getI64IntegerAttr(inputRank - 1));
2609+
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
2610+
zero, end);
2611+
}
26062612
Value resultIndices =
26072613
rewriter
26082614
.create<DecompOpTy>(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,10 +545,6 @@
545545

546546
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
547547
"AddFloatIntModule_basic",
548-
"ArgmaxIntModule_basic",
549-
"ArgmaxIntModule_multiple_maxs",
550-
"ArgmaxKeepdimModule_basic",
551-
"ArgmaxModule_basic",
552548
"AtenKthvalueDynamicDimsModule_basic",
553549
"AtenKthvalueFloat64DynamicDimsModule_basic",
554550
"AtenKthvalueFloat64Module_basic",
@@ -618,9 +614,6 @@
618614
"AnyBoolFalseModule_basic",
619615
"AnyBoolTrueModule_basic",
620616
"ArangeStartOutViewModule_basic",
621-
"ArgminIntModule_basic",
622-
"ArgminIntModule_multiple_mins",
623-
"ArgminModule_basic",
624617
"AtenComplexImagModule_basic",
625618
"AtenComplexRealModule_basic",
626619
"AtenComplexViewModule_basic",

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
2525
return %0 : !torch.tensor
2626
}
2727

28+
// -----
29+
// CHECK-LABEL: func.func @argmax_rank_1
30+
// CHECK: %[[I0:.*]] = torch.constant.int 0
31+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
32+
// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
33+
// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
34+
func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
35+
%none = torch.constant.none
36+
%false = torch.constant.bool false
37+
%7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
38+
return %7 : !torch.vtensor<[],si64>
39+
}
40+
2841
// -----
2942
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
3043
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {

0 commit comments

Comments (0)