Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
78d3927
pass stub
Hardcode84 Dec 6, 2025
ad20fc1
del
Hardcode84 Jan 9, 2026
ea23eb6
del ssa dep
Hardcode84 Jan 9, 2026
e8a2394
tensor lod
Hardcode84 Jan 9, 2026
48215d9
del
Hardcode84 Jan 9, 2026
0eb692a
del test
Hardcode84 Jan 9, 2026
5520e64
test
Hardcode84 Jan 9, 2026
0204e98
test
Hardcode84 Jan 9, 2026
08febd1
test
Hardcode84 Jan 9, 2026
9e20f10
loop test
Hardcode84 Jan 9, 2026
58c716f
double buffer
Hardcode84 Jan 9, 2026
fcd3b4f
triple buffering
Hardcode84 Jan 9, 2026
c8ec468
only track tensor ops
Hardcode84 Jan 9, 2026
8e4a1b9
comments
Hardcode84 Jan 10, 2026
efd0a97
list of values
Hardcode84 Jan 10, 2026
65e2987
select
Hardcode84 Jan 10, 2026
adc37c5
cleanup
Hardcode84 Jan 10, 2026
345cf04
select tests
Hardcode84 Jan 10, 2026
66f9850
desc
Hardcode84 Jan 10, 2026
64358d3
fix
Hardcode84 Jan 11, 2026
9ef94ce
barrier wait
Hardcode84 Jan 11, 2026
c21e208
integrate
Hardcode84 Jan 12, 2026
c830748
update desc
Hardcode84 Jan 12, 2026
51de1ba
fix deduplication
Hardcode84 Jan 13, 2026
8d013b1
nicer print
Hardcode84 Jan 13, 2026
f017374
nicer print
Hardcode84 Jan 13, 2026
3616365
better barrier handling
Hardcode84 Jan 13, 2026
598edc4
remove redundant requirements
Hardcode84 Jan 13, 2026
bca2456
adapt to api changes
Hardcode84 Jan 29, 2026
ee1f0c3
skip tensor_cnt
Hardcode84 Jan 29, 2026
659bbdf
fix tensor desc propagation
Hardcode84 Jan 29, 2026
95e665f
update tensor count lowering
Hardcode84 Jan 29, 2026
a3b891a
disable manual tensor count in schedule
Hardcode84 Jan 29, 2026
6fe7081
update test
Hardcode84 Jan 29, 2026
6ab5f81
test water backend
Hardcode84 Jan 29, 2026
41d011e
update lit tests
Hardcode84 Jan 29, 2026
2e700cb
optional tensor waitcount
Hardcode84 Jan 29, 2026
8807ee4
update test
Hardcode84 Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def schedule_ops(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
print(schedule_ops.asm)

# CHECK-LABEL: func.func @schedule_ops
# CHECK: rocdl.s.wait.tensorcnt 0
# CHECK: amdgpu.memory_counter_wait tensor(0)
# CHECK: rocdl.s.wait.dscnt 0
# CHECK: rocdl.s.barrier.signal id = -1
# CHECK: rocdl.s.barrier.wait id = -1
Expand Down
6 changes: 3 additions & 3 deletions lit_tests/kernel/wave/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ def test_gemm_four_stage_global_to_lds():
# Verify prologue stores to shared memory
# CHECK: amdgpu.tensor_load_to_lds

# CHECK: rocdl.s.wait.tensorcnt 0
# CHECK: amdgpu.memory_counter_wait tensor(0)
# CHECK: rocdl.s.wait.dscnt 0
# CHECK: rocdl.s.barrier.signal id = -1
# CHECK: rocdl.s.barrier.wait id = -1
Expand All @@ -1440,7 +1440,7 @@ def test_gemm_four_stage_global_to_lds():
# Verify WMMA exists
# CHECK: rocdl.wmma.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}

# CHECK: rocdl.s.wait.tensorcnt 0
# CHECK: amdgpu.memory_counter_wait tensor(0)
# CHECK: rocdl.s.wait.dscnt 0
# CHECK: rocdl.s.barrier.signal id = -1
# CHECK: rocdl.s.barrier.wait id = -1
Expand All @@ -1459,7 +1459,7 @@ def test_gemm_four_stage_global_to_lds():
# Epilogue:
# CHECK: rocdl.wmma.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}

# CHECK: rocdl.s.wait.tensorcnt 0
# CHECK: amdgpu.memory_counter_wait tensor(0)
# CHECK: rocdl.s.wait.dscnt 0
# CHECK: rocdl.s.barrier.signal id = -1
# CHECK: rocdl.s.barrier.wait id = -1
Expand Down
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def mma(

### resource provider
# CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED:.*]]
# CHECK: rocdl.s.wait.tensorcnt 0
# CHECK: amdgpu.memory_counter_wait tensor(0)
# CHECK: rocdl.s.wait.dscnt 0
# CHECK: rocdl.s.barrier.signal id = -1

Expand Down
17 changes: 12 additions & 5 deletions tests/kernel/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
asm = gemm.asm

assert (
"wait.tensorcnt" in asm
"memory_counter_wait tensor" in asm
), "tensor waitcnts are not found in asm: required for tensor load instructions."

validate_gemm_result(a, b, c, options)
Expand Down Expand Up @@ -3385,7 +3385,10 @@ def testSpecializeGemm(
@require_gfx1250
@pytest.mark.parametrize("shape", [(1024, 1024, 1024)])
@pytest.mark.parametrize("mfma_variant", [MMAType.GFX1250_F32_16x16x32_F16])
def test_gfx1250_tbuf_gemm(shape: tuple[int], mfma_variant: MMAType):
@use_water_backend_bool("use_water_backend")
def test_gfx1250_tbuf_gemm(
shape: tuple[int, int, int], mfma_variant: MMAType, use_water_backend: bool
):
gemm, options = get_tagged_BxA_T_gemm(
shape=shape,
block_shape=(256, 256, 64),
Expand All @@ -3395,8 +3398,11 @@ def test_gfx1250_tbuf_gemm(shape: tuple[int], mfma_variant: MMAType):
compile_to_mlir=False,
)

schedule = get_gfx1250_tbuf_gemm_schedule()
schedule = get_gfx1250_tbuf_gemm_schedule(
insert_tensor_waitcount=not use_water_backend
)
options = set_default_run_config(options)
options.use_water_backend = use_water_backend
gemm = wave_compile(options, gemm, schedule)

a = device_randn(shape[0], shape[2], dtype=torch.float16)
Expand All @@ -3420,7 +3426,9 @@ def test_gfx1250_tbuf_gemm_codegen(use_water_backend: bool, tmp_path: Path):
compile_to_mlir=False,
)

schedule = get_gfx1250_tbuf_gemm_schedule()
schedule = get_gfx1250_tbuf_gemm_schedule(
insert_tensor_waitcount=not use_water_backend
)
options.target = "gfx1250"
options.dump_intermediates = tmp_path
options.use_water_backend = use_water_backend
Expand All @@ -3441,7 +3449,6 @@ def test_gfx1250_tbuf_gemm_codegen(use_water_backend: bool, tmp_path: Path):
"s_wait_xcnt 0x0",
"s_wait_kmcnt 0x0",
"s_wait_tensorcnt 0x1",
"s_wait_tensorcnt 0x1",
"s_wait_dscnt 0x0",
"s_wait_tensorcnt 0x1",
"s_wait_dscnt 0xe",
Expand Down
17 changes: 17 additions & 0 deletions water/include/water/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,21 @@ def WaterMemrefDecompositionPass : Pass<"water-memref-decomposition"> {
];
}

def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> {
let summary = "Insert wait instructions for asynchronous memory operations";
let description = [{
This pass analyzes asynchronous memory operations and inserts appropriate
wait/synchronization instructions to ensure memory operations complete
before their results are used.

The pass tracks dependencies between async memory operations,
maintaining scoreboards to determine when waits are necessary. It handles:
- Read-after-write (RAW) dependencies
- Write-after-read (WAR) dependencies
}];
let dependentDialects = [
"::mlir::amdgpu::AMDGPUDialect",
];
}

#endif // WATER_PASSES
1 change: 1 addition & 0 deletions water/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRWaterTransforms
GPUToGPURuntime.cpp
MemrefDecomposition.cpp
SLPVectorizer.cpp
WaterInsertWaitcnt.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/water
Expand Down
Loading