Skip to content

Commit cc7e8e7

Browse files
committed
update
1 parent d49eabb commit cc7e8e7

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7368,10 +7368,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
73687368
loc, Torch::ListType::get(Torch::IntType::get(context)),
73697369
ValueRange{constantOne});
73707370

7371-
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
7372-
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
7373-
paddingSizeList, dialationList,
7374-
/*ceil_mode=*/constantFalse);
7371+
if (op.getResult(1).use_empty()) {
7372+
auto maxPool = rewriter.create<AtenMaxPool1dOp>(loc, op.getType(0), input, kernelSizeList,
7373+
strideList, paddingSizeList,
7374+
dialationList,
7375+
/*ceil_mode=*/constantFalse);
7376+
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
7377+
} else {
7378+
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(loc, op.getType(0), op.getType(1), input, kernelSizeList,
7379+
strideList, paddingSizeList,
7380+
dialationList,
7381+
/*ceil_mode=*/constantFalse);
7382+
rewriter.replaceOp(op, {maxPool.getResult(0), maxPool.getResult(1)});
7383+
}
73757384
return success();
73767385
}
73777386
};

0 commit comments

Comments
 (0)