@@ -489,3 +489,42 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vt
489489
490490 return %arg0 : !torch.vtensor <[2 ],f32 >
491491}
492+
493+ // CHECK-LABEL: func.func @unflat_shape_partial_dyn
494+ // CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768
495+ // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
496+ // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
497+ // CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
498+ // CHECK : } shapes {
499+ // CHECK : %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
500+ // CHECK : %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
501+ // CHECK : %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
502+ // CHECK : torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
503+ // CHECK : } : !torch.vtensor<[?,?,4,768],f32>
504+ func.func @unflat_shape_partial_dyn (%arg0: !torch.vtensor <[?,?,3072 ],f32 >) -> !torch.vtensor <[?,?,4 ,?],f32 > {
505+ %int768 = torch.constant.int 768
506+ %int3072 = torch.constant.int 3072
507+ %int0 = torch.constant.int 0
508+ %int3 = torch.constant.int 3
509+ %int1 = torch.constant.int 1
510+ %none = torch.constant.none
511+ %int -1 = torch.constant.int -1
512+ %int2 = torch.constant.int 2
513+ %int4 = torch.constant.int 4
514+ %0 = torch.prim.ListConstruct %int4 , %int -1 : (!torch.int , !torch.int ) -> !torch.list <int >
515+ %1 = torch.shape.calculate {
516+ %2 = torch.aten.unflatten.int %arg0 , %int2 , %0 : !torch.vtensor <[?,?,3072 ],f32 >, !torch.int , !torch.list <int > -> !torch.vtensor <[?,?,4 ,?],f32 >
517+ torch.shape.calculate.yield %2 : !torch.vtensor <[?,?,4 ,?],f32 >
518+ } shapes {
519+ %2 = torch.aten.size.int %arg0 , %int0 : !torch.vtensor <[?,?,3072 ],f32 >, !torch.int -> !torch.int
520+ %3 = torch.aten.size.int %arg0 , %int1 : !torch.vtensor <[?,?,3072 ],f32 >, !torch.int -> !torch.int
521+ %4 = torch.prim.ListConstruct %2 , %3 , %int3072 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
522+ %5 = torch.prim.ListConstruct %int4 , %int768 : (!torch.int , !torch.int ) -> !torch.list <int >
523+ %6 = torch.aten.slice.t %4 , %none , %int2 , %int1 : !torch.list <int >, !torch.none , !torch.int , !torch.int -> !torch.list <int >
524+ %7 = torch.aten.add.t %6 , %5 : !torch.list <int >, !torch.list <int > -> !torch.list <int >
525+ %8 = torch.aten.slice.t %4 , %int3 , %none , %int1 : !torch.list <int >, !torch.int , !torch.none , !torch.int -> !torch.list <int >
526+ %9 = torch.aten.add.t %7 , %8 : !torch.list <int >, !torch.list <int > -> !torch.list <int >
527+ torch.shape.calculate.yield.shapes %9 : !torch.list <int >
528+ } : !torch.vtensor <[?,?,4 ,?],f32 >
529+ return %1 : !torch.vtensor <[?,?,4 ,?],f32 >
530+ }
0 commit comments