@@ -7368,10 +7368,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
73687368 loc, Torch::ListType::get (Torch::IntType::get (context)),
73697369 ValueRange{constantOne});
73707370
7371- rewriter.replaceOpWithNewOp <AtenMaxPool1dWithIndicesOp>(
7372- op, op.getType (0 ), op.getType (1 ), input, kernelSizeList, strideList,
7373- paddingSizeList, dialationList,
7374- /* ceil_mode=*/ constantFalse);
7371+ if (op.getResult (1 ).use_empty ()) {
7372+ auto maxPool = rewriter.create <AtenMaxPool1dOp>(loc, op.getType (0 ), input, kernelSizeList,
7373+ strideList, paddingSizeList,
7374+ dialationList,
7375+ /* ceil_mode=*/ constantFalse);
7376+ rewriter.replaceOp (op, {maxPool.getResult (), Value ()});
7377+ } else {
7378+ auto maxPool = rewriter.create <AtenMaxPool1dWithIndicesOp>(loc, op.getType (0 ), op.getType (1 ), input, kernelSizeList,
7379+ strideList, paddingSizeList,
7380+ dialationList,
7381+ /* ceil_mode=*/ constantFalse);
7382+ rewriter.replaceOp (op, {maxPool.getResult (0 ), maxPool.getResult (1 )});
7383+ }
73757384 return success ();
73767385 }
73777386};
0 commit comments