Skip to content

Commit 2b14305

Browse files
committed
Fold no-op reshape
1 parent 9938abf commit 2b14305

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11404,6 +11404,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
1140411404
printDefaultTorchOp(printer, *this, 2, 1);
1140511405
}
1140611406
}];
11407+
let hasFolder = 1;
1140711408
}
1140811409

1140911410
def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,16 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
22612261
});
22622262
}
22632263

2264+
//===----------------------------------------------------------------------===//
2265+
// AtenReshapeOp
2266+
//===----------------------------------------------------------------------===//
2267+
2268+
OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
2269+
if (getSelf().getType() == getType())
2270+
return getSelf();
2271+
return nullptr;
2272+
}
2273+
22642274
//===----------------------------------------------------------------------===//
22652275
// AtenSelectIntOp
22662276
//===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def emit_with_mutating_variants(key, **kwargs):
852852
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
853853
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
854854
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
855-
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
855+
emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True)
856856
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
857857
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
858858
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")

0 commit comments

Comments
 (0)