Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions lit_tests/kernel/wave/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,51 +1754,79 @@ def test_gemm_two_async_cluster_pingpong():
# 1st cluster interleaved local and global reads.

# 1st Cluster: First slice of Local read lhs and rhs
# CHECK-COUNT-2: vector.load %[[LHS_BUFFER:.+]][{{.*}}, %[[K0:.+]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K1:.+]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-4: vector.load %[[RHS_BUFFER:.+]][{{.*}}, %[[K0]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K1]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: rocdl.sched.barrier
# CHECK-COUNT-4: vector.load %[[LHS_BUFFER:.+]][{{.*}}] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-8: vector.load %[[RHS_BUFFER:.+]][{{.*}}] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>

# 1st Cluster: Global load to shared
# CHECK-COUNT-4: amdgpu.gather_to_lds
# 1st Cluster: Global load to shared (1 gather_to_lds)
# CHECK: amdgpu.gather_to_lds
# CHECK: rocdl.sched.barrier

# First dot slice
# CHECK: rocdl.s.setprio 1
# CHECK-COUNT-16: amdgpu.mfma
# CHECK-COUNT-8: amdgpu.mfma
# CHECK: rocdl.s.setprio 0
# CHECK: rocdl.sched.barrier
# CHECK-NEXT: amdgpu.memory_counter_wait load(4)
# CHECK-NEXT: rocdl.s.barrier
# CHECK: amdgpu.memory_counter_wait load(1)
# CHECK: rocdl.s.barrier

# 2nd cluster: second slice of local read lhs
# CHECK-COUNT-4: vector.load %[[LHS_BUFFER]][{{.*}}] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>

# 2nd cluster second slice of local read lhs and rhs.
# CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K2:.+]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-2: vector.load %[[LHS_BUFFER]][{{.*}}, %[[K3:.+]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K2]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-4: vector.load %[[RHS_BUFFER]][{{.*}}, %[[K3]]] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# 2nd gather_to_lds
# CHECK: amdgpu.gather_to_lds
# CHECK: rocdl.sched.barrier
# CHECK-NEXT: amdgpu.memory_counter_wait load(0)
# CHECK-NEXT: rocdl.s.barrier
# CHECK: amdgpu.memory_counter_wait load(2)
# CHECK: rocdl.s.barrier

# Second dot slice:
# Second dot slice
# CHECK: rocdl.s.setprio 1
# CHECK-COUNT-16: amdgpu.mfma
# CHECK-COUNT-8: amdgpu.mfma
# CHECK: rocdl.s.setprio 0
# CHECK: rocdl.sched.barrier
# CHECK: rocdl.s.barrier

# Final LDS barrier to synchronize shared writes.
# CHECK: amdgpu.lds_barrier
# 3rd cluster: rhs local reads
# CHECK-COUNT-8: vector.load %[[RHS_BUFFER]][{{.*}}] : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<4xf16>

# 3rd gather_to_lds
# CHECK: amdgpu.gather_to_lds
# CHECK: rocdl.sched.barrier
# CHECK: amdgpu.memory_counter_wait load(3)
# CHECK: rocdl.s.barrier

# Third dot slice
# CHECK: rocdl.s.setprio 1
# CHECK-COUNT-8: amdgpu.mfma
# CHECK: rocdl.s.setprio 0
# CHECK: rocdl.sched.barrier
# CHECK: rocdl.s.barrier

# 4th gather_to_lds
# CHECK: amdgpu.gather_to_lds
# CHECK: rocdl.sched.barrier
# CHECK: rocdl.s.barrier

# Fourth dot slice
# CHECK: rocdl.s.setprio 1
# CHECK-COUNT-8: amdgpu.mfma
# CHECK: rocdl.s.setprio 0
# CHECK: rocdl.sched.barrier
# CHECK: amdgpu.memory_counter_wait load(2)
# CHECK: rocdl.s.barrier
# CHECK: rocdl.sched.barrier
# CHECK: scf.yield
# CHECK: }

# Prologue
# Epilogue

# cond_barrier on warp low to even out asymmetry between 2 waves in the same SIMD and block.
# CHECK: scf.if %[[WARP_LO]] {
# CHECK-NEXT: rocdl.s.barrier
# CHECK-NEXT: }

# Final LDS barrier before epilogue computations
# CHECK: amdgpu.lds_barrier


# This tests that our stack is able to handle an MMA layout with interleaved VGPR offsets/chunks.

Expand Down
1 change: 0 additions & 1 deletion wave_lang/kernel/wave/gather_to_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ def emit_global_to_lds(
for bound_expr, idx in zip(read.indexing_dims, nd_index):
last = bound_expr == read.indexing_dims[-1]
dim = infer_dim(bound_expr)

size = elements_per_thread if last else 1
stride = 1
write_index[dim] = IndexSequence(idx, size, stride)
Expand Down
Loading
Loading