@@ -8183,6 +8183,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
81838183" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
81848184" return %1 : !torch.tuple<list<int>, list<int>>\n"
81858185" }\n"
8186+ " func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
8187+ " %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
8188+ " %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
8189+ " %none = torch.constant.none\n"
8190+ " %str_1 = torch.constant.str \"AssertionError: Input be of rank 3 or 4\"\n"
8191+ " %true = torch.constant.bool true\n"
8192+ " %int4 = torch.constant.int 4\n"
8193+ " %int3 = torch.constant.int 3\n"
8194+ " %int2 = torch.constant.int 2\n"
8195+ " %int0 = torch.constant.int 0\n"
8196+ " %int1 = torch.constant.int 1\n"
8197+ " %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8198+ " %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
8199+ " %2 = torch.prim.If %1 -> (!torch.bool) {\n"
8200+ " torch.prim.If.yield %true : !torch.bool\n"
8201+ " } else {\n"
8202+ " %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8203+ " %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
8204+ " torch.prim.If.yield %12 : !torch.bool\n"
8205+ " }\n"
8206+ " torch.prim.If %2 -> () {\n"
8207+ " torch.prim.If.yield\n"
8208+ " } else {\n"
8209+ " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
8210+ " torch.prim.If.yield\n"
8211+ " }\n"
8212+ " %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
8213+ " %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
8214+ " torch.prim.If %4 -> () {\n"
8215+ " torch.prim.If.yield\n"
8216+ " } else {\n"
8217+ " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
8218+ " torch.prim.If.yield\n"
8219+ " }\n"
8220+ " %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8221+ " %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
8222+ " %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
8223+ " torch.prim.If %7 -> () {\n"
8224+ " torch.prim.If.yield\n"
8225+ " } else {\n"
8226+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8227+ " torch.prim.If.yield\n"
8228+ " }\n"
8229+ " %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8230+ " %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
8231+ " %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
8232+ " %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8233+ " %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8234+ " %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8235+ " %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8236+ " %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
8237+ " torch.prim.If.yield %15 : !torch.list<int>\n"
8238+ " } else {\n"
8239+ " %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8240+ " %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
8241+ " %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
8242+ " %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
8243+ " torch.prim.If.yield %14 : !torch.list<int>\n"
8244+ " }\n"
8245+ " return %10 : !torch.list<int>\n"
8246+ " }\n"
81868247" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
81878248" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
81888249" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
@@ -12133,6 +12194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1213312194" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
1213412195" return %1 : !torch.tuple<int, int>\n"
1213512196" }\n"
12197+ " func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
12198+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12199+ " return %0#1 : !torch.int\n"
12200+ " }\n"
1213612201" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
1213712202" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1213812203" return %0#1 : !torch.int\n"
0 commit comments