@@ -7810,6 +7810,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
7810
7810
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
7811
7811
" return %0 : !torch.list<int>\n"
7812
7812
" }\n"
7813
+ " func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
7814
+ " %true = torch.constant.bool true\n"
7815
+ " %int0 = torch.constant.int 0\n"
7816
+ " %int1 = torch.constant.int 1\n"
7817
+ " %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7818
+ " %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7819
+ " %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
7820
+ " %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
7821
+ " torch.prim.If.yield %3 : !torch.list<list<int>>\n"
7822
+ " } else {\n"
7823
+ " %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
7824
+ " %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7825
+ " %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
7826
+ " %6 = torch.prim.Loop %5, %true, init(%3) {\n"
7827
+ " ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
7828
+ " %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
7829
+ " %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
7830
+ " %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
7831
+ " torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
7832
+ " } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
7833
+ " %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
7834
+ " %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7835
+ " torch.prim.Loop %8, %true, init() {\n"
7836
+ " ^bb0(%arg1: !torch.int):\n"
7837
+ " %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
7838
+ " torch.prim.Loop.condition %true, iter()\n"
7839
+ " } : (!torch.int, !torch.bool) -> ()\n"
7840
+ " torch.prim.If.yield %7 : !torch.list<list<int>>\n"
7841
+ " }\n"
7842
+ " return %2 : !torch.list<list<int>>\n"
7843
+ " }\n"
7813
7844
" func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
7814
7845
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
7815
7846
" return %0 : !torch.list<int>\n"
@@ -12556,6 +12587,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
12556
12587
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12557
12588
" return %0#1 : !torch.int\n"
12558
12589
" }\n"
12590
+ " func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
12591
+ " %true = torch.constant.bool true\n"
12592
+ " %int0 = torch.constant.int 0\n"
12593
+ " %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
12594
+ " %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
12595
+ " ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
12596
+ " %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
12597
+ " %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
12598
+ " %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
12599
+ " %7 = torch.prim.If %6 -> (!torch.int) {\n"
12600
+ " %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
12601
+ " torch.prim.If.yield %8 : !torch.int\n"
12602
+ " } else {\n"
12603
+ " torch.prim.If.yield %arg2 : !torch.int\n"
12604
+ " }\n"
12605
+ " torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
12606
+ " } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
12607
+ " %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
12608
+ " %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
12609
+ " torch.prim.Loop %3, %true, init() {\n"
12610
+ " ^bb0(%arg1: !torch.int):\n"
12611
+ " %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
12612
+ " %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12613
+ " %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
12614
+ " %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
12615
+ " torch.prim.Loop.condition %true, iter()\n"
12616
+ " } : (!torch.int, !torch.bool) -> ()\n"
12617
+ " return %2 : !torch.list<tuple<int, int>>\n"
12618
+ " }\n"
12559
12619
" func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
12560
12620
" %int7 = torch.constant.int 7\n"
12561
12621
" %int6 = torch.constant.int 6\n"
0 commit comments