From f088a992496f84854e0dbe41c03895b6ec108da7 Mon Sep 17 00:00:00 2001 From: nithinsubbiah Date: Mon, 9 Feb 2026 22:55:36 +0000 Subject: [PATCH 1/3] Fix pre-commit Signed-off-by: nithinsubbiah --- wave_lang/kernel/wave/utils/barriers_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/wave/utils/barriers_utils.py b/wave_lang/kernel/wave/utils/barriers_utils.py index d1079367a8..9e9f4c1afb 100644 --- a/wave_lang/kernel/wave/utils/barriers_utils.py +++ b/wave_lang/kernel/wave/utils/barriers_utils.py @@ -453,11 +453,10 @@ def filter_regions_with_barriers( and region.consumer is not None and not isinstance( get_custom(region.producer), (GatherToLDS, TensorLoadToLDS) - ) - ): - existing_barrier = is_barrier_between(region.producer, region.consumer) - if existing_barrier is not None: - continue + ): + existing_barrier = is_barrier_between(region.producer, region.consumer) + if existing_barrier is not None: + continue filtered_results.append(region) return filtered_results From 38e7382318a97ad51fe31c6bbd6bb203a64b3ba0 Mon Sep 17 00:00:00 2001 From: nithinsubbiah Date: Mon, 9 Feb 2026 23:41:29 +0000 Subject: [PATCH 2/3] Addressed comments Signed-off-by: nithinsubbiah --- wave_lang/kernel/wave/utils/barriers_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/wave_lang/kernel/wave/utils/barriers_utils.py b/wave_lang/kernel/wave/utils/barriers_utils.py index 9e9f4c1afb..d1079367a8 100644 --- a/wave_lang/kernel/wave/utils/barriers_utils.py +++ b/wave_lang/kernel/wave/utils/barriers_utils.py @@ -453,10 +453,11 @@ def filter_regions_with_barriers( and region.consumer is not None and not isinstance( get_custom(region.producer), (GatherToLDS, TensorLoadToLDS) - ): - existing_barrier = is_barrier_between(region.producer, region.consumer) - if existing_barrier is not None: - continue + ) + ): + existing_barrier = is_barrier_between(region.producer, region.consumer) + if existing_barrier is not None: + continue filtered_results.append(region) return filtered_results From dce0c4fb99ee711d6bcb735dc2459cc291616b74 Mon Sep 17 00:00:00 2001 From: nithinsubbiah Date: Mon, 9 Feb 2026 23:02:28 +0000 Subject: [PATCH 3/3] [Compiler] Add 4 PP cluster with improved schedule reordering Signed-off-by: nithinsubbiah --- lit_tests/kernel/wave/gemm.py | 72 ++++-- wave_lang/kernel/wave/gather_to_shared.py | 1 - wave_lang/kernel/wave/schedule_reordering.py | 231 ++++++++++++++++-- .../kernel/wave/templates/reordered_gemm.py | 1 - wave_lang/kernel/wave/utils/graph_utils.py | 4 +- 5 files changed, 261 insertions(+), 48 deletions(-) diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 5a7bdaccd6..a893a3f73b 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -1754,51 +1754,79 @@ def test_gemm_two_async_cluster_pingpong(): # 1st cluster interleaved local and global reads. # 1st Cluster: First slice of Local read lhs and rhs - # CHECK-COUNT-2: vector.load %[[LHS_BUFFER:.+]][{{.*}}, %[[K0:.+]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K1:.+]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-4: vector.load %[[RHS_BUFFER:.+]][{{.*}}, %[[K0]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K1]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK: rocdl.sched.barrier + # CHECK-COUNT-4: vector.load %[[LHS_BUFFER:.+]][{{.*}}] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> + # CHECK-COUNT-8: vector.load %[[RHS_BUFFER:.+]][{{.*}}] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # 1st Cluster: Global load to shared - # CHECK-COUNT-4: amdgpu.gather_to_lds + # 1st Cluster: Global load to shared (1 gather_to_lds) + # CHECK: amdgpu.gather_to_lds # CHECK: rocdl.sched.barrier # First dot slice # CHECK: rocdl.s.setprio 1 - # CHECK-COUNT-16: amdgpu.mfma + # CHECK-COUNT-8: amdgpu.mfma # CHECK: rocdl.s.setprio 0 # CHECK: rocdl.sched.barrier - # CHECK-NEXT: amdgpu.memory_counter_wait load(4) - # CHECK-NEXT: rocdl.s.barrier + # CHECK: amdgpu.memory_counter_wait load(1) + # CHECK: rocdl.s.barrier + + # 2nd cluster: second slice of local read lhs + # CHECK-COUNT-4: vector.load %[[LHS_BUFFER]][{{.*}}] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # 2nd cluster second slice of local read lhs and rhs. - # CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K2:.+]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K3:.+]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K2]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> - # CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K3]]] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> + # 2nd gather_to_lds + # CHECK: amdgpu.gather_to_lds # CHECK: rocdl.sched.barrier - # CHECK-NEXT: amdgpu.memory_counter_wait load(0) - # CHECK-NEXT: rocdl.s.barrier + # CHECK: amdgpu.memory_counter_wait load(2) + # CHECK: rocdl.s.barrier - # Second dot slice: + # Second dot slice # CHECK: rocdl.s.setprio 1 - # CHECK-COUNT-16: amdgpu.mfma + # CHECK-COUNT-8: amdgpu.mfma # CHECK: rocdl.s.setprio 0 # CHECK: rocdl.sched.barrier + # CHECK: rocdl.s.barrier - # Final LDS barrier to synchronize shared writes. - # CHECK: amdgpu.lds_barrier + # 3rd cluster: rhs local reads + # CHECK-COUNT-8: vector.load %[[RHS_BUFFER]][{{.*}}] : memref<128x64xf16, #gpu.address_space>, vector<4xf16> + + # 3rd gather_to_lds + # CHECK: amdgpu.gather_to_lds + # CHECK: rocdl.sched.barrier + # CHECK: amdgpu.memory_counter_wait load(3) + # CHECK: rocdl.s.barrier + + # Third dot slice + # CHECK: rocdl.s.setprio 1 + # CHECK-COUNT-8: amdgpu.mfma + # CHECK: rocdl.s.setprio 0 + # CHECK: rocdl.sched.barrier + # CHECK: rocdl.s.barrier + + # 4th gather_to_lds + # CHECK: amdgpu.gather_to_lds + # CHECK: rocdl.sched.barrier + # CHECK: rocdl.s.barrier + + # Fourth dot slice + # CHECK: rocdl.s.setprio 1 + # CHECK-COUNT-8: amdgpu.mfma + # CHECK: rocdl.s.setprio 0 + # CHECK: rocdl.sched.barrier + # CHECK: amdgpu.memory_counter_wait load(2) + # CHECK: rocdl.s.barrier + # CHECK: rocdl.sched.barrier # CHECK: scf.yield # CHECK: } - # Prologue + # Epilogue # cond_barrier on warp low to even out assymetry between 2 wave in same SIMD and Block. # CHECK: scf.if %[[WARP_LO]] { # CHECK-NEXT: rocdl.s.barrier # CHECK-NEXT: } + # Final LDS barrier before epilogue computations + # CHECK: amdgpu.lds_barrier + # This test that our stack is able to handle MMA layout with interleaved VGPR offsets/chunks diff --git a/wave_lang/kernel/wave/gather_to_shared.py b/wave_lang/kernel/wave/gather_to_shared.py index 4db0ddc8b6..e6070ef591 100644 --- a/wave_lang/kernel/wave/gather_to_shared.py +++ b/wave_lang/kernel/wave/gather_to_shared.py @@ -333,7 +333,6 @@ def emit_global_to_lds( for bound_expr, idx in zip(read.indexing_dims, nd_index): last = bound_expr == read.indexing_dims[-1] dim = infer_dim(bound_expr) - size = elements_per_thread if last else 1 stride = 1 write_index[dim] = IndexSequence(idx, size, stride) diff --git a/wave_lang/kernel/wave/schedule_reordering.py b/wave_lang/kernel/wave/schedule_reordering.py index 4f414ad1b9..6f1ba19610 100644 --- a/wave_lang/kernel/wave/schedule_reordering.py +++ b/wave_lang/kernel/wave/schedule_reordering.py @@ -359,9 +359,28 @@ def slice_mma(mma_nodes, lhs_nodes, rhs_nodes, num_slice): AssertionError: If slicing preconditions are not met """ reduction_dim = get_custom(mma_nodes[0]).reduction_dim + custom_mma = get_custom(mma_nodes[0]) + lhs_dim = set(get_custom(custom_mma.lhs).indexing_dims) + rhs_dim = set(get_custom(custom_mma.rhs).indexing_dims) + + filter_batch_dims = lambda dims: [ + dim for dim in dims if custom_mma.vector_shapes[dim] != 0 + ] + m_dims = filter_batch_dims(list(lhs_dim - rhs_dim)) + n_dims = filter_batch_dims(list(rhs_dim - lhs_dim)) + assert len(m_dims) == 1 + assert len(n_dims) == 1 + m_dim = m_dims[0] + n_dim = n_dims[0] + reduction_dim_ids = set( [get_custom(node).expanded_dims[reduction_dim] for node in mma_nodes] ) + + m_dim_ids = set([get_custom(node).expanded_dims[m_dim] for node in mma_nodes]) + + n_dim_ids = set([get_custom(node).expanded_dims[n_dim] for node in mma_nodes]) + # Checking that MMAs is valid. reduction_expand_size = len(reduction_dim_ids) assert reduction_expand_size >= num_slice and reduction_expand_size % num_slice == 0 @@ -381,6 +400,81 @@ def slice_mma(mma_nodes, lhs_nodes, rhs_nodes, num_slice): return sliced_mma_nodes, sliced_lhs_nodes, sliced_rhs_nodes +def slice_square_mma(mma_nodes, lhs_nodes, rhs_nodes): + """ + Slice MMA operations and their operand load nodes into equal-sized slices + based on the reduction dimension. + + This function groups MMA nodes and their corresponding LHS/RHS load nodes + into slices, where each slice contains operations for a contiguous range + of reduction dimension IDs. + + Args: + mma_nodes: List of MMA operation nodes + lhs_nodes: List of left-hand side load nodes + rhs_nodes: List of right-hand side load nodes + num_slice: Number of slices to create (must evenly divide reduction_expand_size) + + Returns: + Tuple of (sliced_mma_nodes, sliced_lhs_nodes, sliced_rhs_nodes), + where each element is a list of lists, one per slice + + Raises: + AssertionError: If slicing preconditions are not met + """ + custom_mma = get_custom(mma_nodes[0]) + lhs_dim = set(get_custom(custom_mma.lhs).indexing_dims) + rhs_dim = set(get_custom(custom_mma.rhs).indexing_dims) + + filter_batch_dims = lambda dims: [ + dim for dim in dims if custom_mma.vector_shapes[dim] != 0 + ] + m_dims = filter_batch_dims(list(lhs_dim - rhs_dim)) + n_dims = filter_batch_dims(list(rhs_dim - lhs_dim)) + assert len(m_dims) == 1 + assert len(n_dims) == 1 + m_dim = m_dims[0] + n_dim = n_dims[0] + + m_dim_ids = set([get_custom(node).expanded_dims[m_dim] for node in mma_nodes]) + + n_dim_ids = set([get_custom(node).expanded_dims[n_dim] for node in mma_nodes]) + + num_slice = 2 + assert len(m_dim_ids) >= num_slice and len(m_dim_ids) % num_slice == 0 + assert len(n_dim_ids) >= num_slice and len(n_dim_ids) % num_slice == 0 + + assert all(x in m_dim_ids for x in range(len(m_dim_ids))) + assert all(x in n_dim_ids for x in range(len(n_dim_ids))) + + # Slice each node list independently + sliced_mma_node_m0, sliced_mma_node_m1 = _slice_node_list( + mma_nodes, len(m_dim_ids), num_slice, m_dim + ) + sliced_mma_node_m0_n0, sliced_mma_node_m0_n1 = _slice_node_list( + sliced_mma_node_m0, len(m_dim_ids), num_slice, n_dim + ) + sliced_mma_node_m1_n0, sliced_mma_node_m1_n1 = _slice_node_list( + sliced_mma_node_m1, len(m_dim_ids), num_slice, n_dim + ) + + sliced_lhs_node_m0, sliced_lhs_node_m1 = _slice_node_list( + lhs_nodes, len(m_dim_ids), num_slice, m_dim + ) + sliced_rhs_node_n0, sliced_rhs_node_n1 = _slice_node_list( + rhs_nodes, len(m_dim_ids), num_slice, n_dim + ) + + sliced_mma_nodes = [ + [sliced_mma_node_m0_n0, sliced_mma_node_m0_n1], + [sliced_mma_node_m1_n0, sliced_mma_node_m1_n1], + ] + sliced_lhs_nodes = [sliced_lhs_node_m0, sliced_lhs_node_m1] + sliced_rhs_nodes = [sliced_rhs_node_n0, sliced_rhs_node_n1] + + return sliced_mma_nodes, sliced_lhs_nodes, sliced_rhs_nodes + + def slice_scale_mma( mma_nodes, lhs_nodes, rhs_nodes, lhs_scale_nodes, rhs_scale_nodes, num_slice ): @@ -491,8 +585,14 @@ def add_conditional_barriers_to_loop(custom_iterate, trace, hardware_constraint) ) # Generating and inserting cond_barriers to correct place in graph. + # SharedMemoryBarrier to only do lgmkcnt w/os sbarrier + with graph.inserting_before(custom_iterate.fx_node): + SharedMemoryBarrier().add_to_graph(graph, loc=custom_iterate.location) with graph.inserting_before(custom_iterate.fx_node): insert_cond_barrier(is_wave_hi, trace, graph, custom_iterate.location) + + with graph.inserting_after(custom_iterate.fx_node): + SharedMemoryBarrier().add_to_graph(graph, loc=custom_iterate.location) with graph.inserting_after(custom_iterate.fx_node): insert_cond_barrier(is_wave_lo, trace, graph, custom_iterate.location) return @@ -659,6 +759,15 @@ def transform_two_PP_clusters( return clusters +def slice_list(node_list, num_slice): + assert len(node_list) >= num_slice and len(node_list) % num_slice == 0 + stride = len(node_list) // num_slice + sliced_list = [] + for i in range(num_slice): + sliced_list.append(node_list[i * stride : (i + 1) * stride]) + return sliced_list + + def transform_async_two_PP_clusters( mma_nodes, local_load_lhs, @@ -667,9 +776,11 @@ def transform_async_two_PP_clusters( global_to_shared_rhs, ): num_slices = 2 - sliced_mma_nodes, sliced_local_load_lhs, sliced_local_load_rhs = slice_mma( - mma_nodes, local_load_lhs, local_load_rhs, num_slice=num_slices + sliced_mma_nodes, sliced_local_load_lhs, sliced_local_load_rhs = slice_square_mma( + mma_nodes, local_load_lhs, local_load_rhs ) + sliced_glds_lhs = slice_list(global_to_shared_lhs, num_slices) + sliced_glds_rhs = slice_list(global_to_shared_rhs, num_slices) # Check that we have valid slice size for local_loads and mmas. assert len(sliced_mma_nodes) == len(sliced_local_load_rhs) assert len(sliced_mma_nodes) == len(sliced_local_load_lhs) @@ -682,15 +793,11 @@ def transform_async_two_PP_clusters( # 1st cluster interleaved local and global reads. clusters.append(sliced_local_load_lhs[0]) clusters.append(sliced_local_load_rhs[0]) - barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) - barrier_op.location = context_location - clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[0])) - clusters.append(global_to_shared_lhs) - clusters.append(global_to_shared_rhs) + clusters.append(sliced_glds_lhs[0]) barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) barrier_op.location = context_location - clusters.append(insert_op_after(barrier_op, global_to_shared_rhs)) + clusters.append(insert_op_after(barrier_op, sliced_glds_lhs[0])) barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) barrier_op.location = context_location @@ -702,19 +809,16 @@ def transform_async_two_PP_clusters( # 2nd cluster mma_slice[0]. prio_op = SetWavePrio(1).add_to_graph(tmp_graph) prio_op.location = context_location - clusters.append(insert_op_before(prio_op, sliced_mma_nodes[0])) - clusters.append(sliced_mma_nodes[0]) + clusters.append(insert_op_before(prio_op, sliced_mma_nodes[0][0])) + clusters.append(sliced_mma_nodes[0][0]) prio_op = SetWavePrio(0).add_to_graph(tmp_graph) prio_op.location = context_location - clusters.append(insert_op_after(prio_op, sliced_mma_nodes[0])) + clusters.append(insert_op_after(prio_op, sliced_mma_nodes[0][0])) barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) barrier_op.location = context_location clusters.append(insert_op_after(barrier_op, clusters[-1].op)) - independent_global_count = len(global_to_shared_lhs + global_to_shared_rhs) - barrier_op = MemoryCounterWait(load=independent_global_count).add_to_graph( - tmp_graph - ) + barrier_op = MemoryCounterWait(load=len(sliced_glds_lhs[0])).add_to_graph(tmp_graph) barrier_op.location = context_location clusters.append(insert_op_after(barrier_op, clusters[-1].op)) @@ -727,12 +831,15 @@ def transform_async_two_PP_clusters( # 3rd cluster local load 2nd slice. clusters.append(sliced_local_load_lhs[1]) - clusters.append(sliced_local_load_rhs[1]) + clusters.append(sliced_glds_rhs[0]) + clusters.append(insert_op_after(barrier_op, sliced_glds_rhs[0])) barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) barrier_op.location = context_location - clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[1])) + clusters.append(insert_op_after(barrier_op, sliced_glds_rhs[0])) - barrier_op = MemoryCounterWait(load=0).add_to_graph(tmp_graph) + barrier_op = MemoryCounterWait( + load=len(sliced_glds_lhs[0] + sliced_glds_rhs[0]) + ).add_to_graph(tmp_graph) barrier_op.location = context_location clusters.append(insert_op_after(barrier_op, clusters[-1].op)) @@ -746,11 +853,87 @@ def transform_async_two_PP_clusters( # 4th cluster mma_slice[1]. prio_op = SetWavePrio(1).add_to_graph(tmp_graph) prio_op.location = context_location - clusters.append(insert_op_before(prio_op, sliced_mma_nodes[1])) - clusters.append(sliced_mma_nodes[1]) + clusters.append(insert_op_before(prio_op, sliced_mma_nodes[1][0])) + clusters.append(sliced_mma_nodes[1][0]) prio_op = SetWavePrio(0).add_to_graph(tmp_graph) prio_op.location = context_location - clusters.append(insert_op_after(prio_op, sliced_mma_nodes[1])) + clusters.append(insert_op_after(prio_op, sliced_mma_nodes[1][0])) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + + clusters.append(sliced_local_load_rhs[1]) + clusters.append(sliced_glds_lhs[1]) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, sliced_glds_lhs[1])) + + barrier_op = MemoryCounterWait( + load=len(sliced_glds_lhs[0] + sliced_glds_rhs[0] + sliced_glds_lhs[1]) + ).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + + # 4th cluster mma_slice[1]. + prio_op = SetWavePrio(1).add_to_graph(tmp_graph) + prio_op.location = context_location + clusters.append(insert_op_before(prio_op, sliced_mma_nodes[0][1])) + clusters.append(sliced_mma_nodes[0][1]) + prio_op = SetWavePrio(0).add_to_graph(tmp_graph) + prio_op.location = context_location + clusters.append(insert_op_after(prio_op, sliced_mma_nodes[0][1])) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + + clusters.append(sliced_glds_rhs[1]) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, sliced_glds_rhs[1])) + + barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + + prio_op = SetWavePrio(1).add_to_graph(tmp_graph) + prio_op.location = context_location + clusters.append(insert_op_before(prio_op, sliced_mma_nodes[1][1])) + clusters.append(sliced_mma_nodes[1][1]) + prio_op = SetWavePrio(0).add_to_graph(tmp_graph) + prio_op.location = context_location + clusters.append(insert_op_after(prio_op, sliced_mma_nodes[1][1])) + barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = MemoryCounterWait( + load=len(sliced_glds_lhs[1] + sliced_glds_rhs[1]) + ).add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) + barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph) + barrier_op.location = context_location + clusters.append(insert_op_after(barrier_op, clusters[-1].op)) barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph) barrier_op.location = context_location clusters.append(insert_op_after(barrier_op, clusters[-1].op)) @@ -1019,7 +1202,11 @@ def schedule_reordering( global_to_shared_lhs, global_to_shared_rhs, ) - insert_prefetch_loop_barriers(custom_iterate, graph, clusters) + # Inserting prefetch loop manually. + with custom_iterate.graph.inserting_before(custom_iterate.fx_node): + SharedMemoryBarrier().add_to_graph( + custom_iterate.graph, loc=custom_iterate.location + ) elif reorder_strategy == SchedReorderStrategy.MXFP4_PP_CLUSTER: clusters = transform_MXFP4_PP_clusters( mma_nodes, diff --git a/wave_lang/kernel/wave/templates/reordered_gemm.py b/wave_lang/kernel/wave/templates/reordered_gemm.py index 1a67982e21..ee940e7ed0 100644 --- a/wave_lang/kernel/wave/templates/reordered_gemm.py +++ b/wave_lang/kernel/wave/templates/reordered_gemm.py @@ -25,7 +25,6 @@ def get_reordered_matmul( input_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float32, ): - # Initializing dtypes input_wtype = torch_dtype_to_wave(input_dtype) output_wtype = torch_dtype_to_wave(output_dtype) diff --git a/wave_lang/kernel/wave/utils/graph_utils.py b/wave_lang/kernel/wave/utils/graph_utils.py index 71698cb57b..935e2221a2 100644 --- a/wave_lang/kernel/wave/utils/graph_utils.py +++ b/wave_lang/kernel/wave/utils/graph_utils.py @@ -34,14 +34,14 @@ GetResult, IterArg, Iterate, + MemoryCounterWaitBarrier, NestedRegionOp, Output, Placeholder, SharedMemoryBarrier, - TopkOp, SharedMemoryBarrierSignal, SharedMemoryBarrierWait, - MemoryCounterWaitBarrier, + TopkOp, Write, get_custom, )