@@ -60,7 +60,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
6060
6161// -----
6262
63- // COM: Test while loop / tt.advance before tt.load (TODO)
63+ // COM: Test while loop / nested tt.advance
6464#blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [4 , 4 ], warpsPerCTA = [32 , 1 ], order = [1 , 0 ]}>
6565#mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [8 , 4 ], repCluster = [4 , 2 ], A = [32 , 16 ], B = [16 , 32 ], C = [32 , 32 ]}>
6666// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
@@ -99,8 +99,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
9999 // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
100100 %5 = tt.dot %4 , %cstB , %cst , inputPrecision = tf32 : tensor <256 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> * tensor <32 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <256 x256 xf32 , #mma >
101101 %6 = ttg.convert_layout %5 : tensor <256 x256 xf32 , #mma > -> tensor <256 x256 xf32 , #blocked1 >
102- // COM: TODO: support nested tt.advance
103- // %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
102+ // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
103+ %7 = tt.advance %a_ptr_crt , [%c0_i32 , %c32_i32 ] : <tensor <256 x32 xf16 , #blocked1 >>
104104
105105 // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
106106 scf.yield %a_ptr_crt : !tt.ptr <tensor <256 x32 xf16 , #blocked1 >>
0 commit comments