Skip to content

Commit 7cf1bb9

Browse files
committed
Add test with pytorch fix coming in
1 parent 443b5cd commit 7cf1bb9

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8183,6 +8183,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
81838183
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
81848184
" return %1 : !torch.tuple<list<int>, list<int>>\n"
81858185
" }\n"
8186+
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
8187+
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
8188+
" %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
8189+
" %none = torch.constant.none\n"
8190+
" %str_1 = torch.constant.str \"AssertionError: Input be of rank 3 or 4\"\n"
8191+
" %true = torch.constant.bool true\n"
8192+
" %int4 = torch.constant.int 4\n"
8193+
" %int3 = torch.constant.int 3\n"
8194+
" %int2 = torch.constant.int 2\n"
8195+
" %int0 = torch.constant.int 0\n"
8196+
" %int1 = torch.constant.int 1\n"
8197+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8198+
" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
8199+
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
8200+
" torch.prim.If.yield %true : !torch.bool\n"
8201+
" } else {\n"
8202+
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8203+
" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
8204+
" torch.prim.If.yield %12 : !torch.bool\n"
8205+
" }\n"
8206+
" torch.prim.If %2 -> () {\n"
8207+
" torch.prim.If.yield\n"
8208+
" } else {\n"
8209+
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
8210+
" torch.prim.If.yield\n"
8211+
" }\n"
8212+
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
8213+
" %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
8214+
" torch.prim.If %4 -> () {\n"
8215+
" torch.prim.If.yield\n"
8216+
" } else {\n"
8217+
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
8218+
" torch.prim.If.yield\n"
8219+
" }\n"
8220+
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8221+
" %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
8222+
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
8223+
" torch.prim.If %7 -> () {\n"
8224+
" torch.prim.If.yield\n"
8225+
" } else {\n"
8226+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8227+
" torch.prim.If.yield\n"
8228+
" }\n"
8229+
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8230+
" %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
8231+
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
8232+
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8233+
" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8234+
" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8235+
" %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8236+
" %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
8237+
" torch.prim.If.yield %15 : !torch.list<int>\n"
8238+
" } else {\n"
8239+
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8240+
" %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8241+
" %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8242+
" %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
8243+
" torch.prim.If.yield %14 : !torch.list<int>\n"
8244+
" }\n"
8245+
" return %10 : !torch.list<int>\n"
8246+
" }\n"
81868247
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
81878248
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
81888249
" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
@@ -12133,6 +12194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1213312194
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
1213412195
" return %1 : !torch.tuple<int, int>\n"
1213512196
" }\n"
12197+
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
12198+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12199+
" return %0#1 : !torch.int\n"
12200+
" }\n"
1213612201
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
1213712202
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1213812203
" return %0#1 : !torch.int\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,15 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
10591059
maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
10601060
return maxpool3d, indices
10611061

1062+
def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
1063+
assert (len(self) == 4 or len(self) == 3), "Input be of rank 3 or 4"
1064+
assert (len(output_size) == 2), "output_size must have 2 elements"
1065+
assert (len(self) == len(indices)), "Input and indices must be of the same rank"
1066+
if len(self) == 4:
1067+
return [self[0], self[1], output_size[0], output_size[1]]
1068+
else:
1069+
return [self[0], output_size[0], output_size[1]]
1070+
10621071
def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
10631072
assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
10641073
assert (len(output_size) == 3), "output_size must have 3 elements"
@@ -3205,6 +3214,10 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
32053214
self_rank, self_dtype = self_rank_dtype
32063215
return self_dtype, torch.int64
32073216

3217+
def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
3218+
self_rank, self_dtype = self_rank_dtype
3219+
return self_dtype
3220+
32083221
def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
32093222
self_rank, self_dtype = self_rank_dtype
32103223
return self_dtype

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,36 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
19881988
# ==============================================================================
19891989

19901990

1991+
class MaxUnpool2dModule(torch.nn.Module):
1992+
def __init__(self):
1993+
super().__init__()
1994+
1995+
@export
1996+
@annotate_args(
1997+
[
1998+
None,
1999+
([-1, -1, 2, 2], torch.float32, True),
2000+
([-1, -1, 2, 2], torch.int64, True),
2001+
]
2002+
)
2003+
def forward(self, x, indices):
2004+
return torch.ops.aten.max_unpool2d(x, indices, (4, 4), (2, 2), (0, 0))
2005+
2006+
2007+
@register_test_case(module_factory=lambda: MaxUnpool2dModule())
2008+
def MaxUnpool2dModule_basic(module, tu: TestUtils):
2009+
input = tu.rand(2, 2, 4, 4)
2010+
pool = torch.nn.MaxPool2d(
2011+
kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), return_indices=True
2012+
)
2013+
output, indices = pool(input)
2014+
2015+
module.forward(output, indices)
2016+
2017+
2018+
# ==============================================================================
2019+
2020+
19912021
class MaxUnpool3dModule(torch.nn.Module):
19922022
def __init__(self):
19932023
super().__init__()

0 commit comments

Comments
 (0)