From 78d392756ce17112d2cb8b3f9c793c3fffd27c63 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 6 Dec 2025 13:27:19 +0100 Subject: [PATCH 01/38] pass stub Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 76 ++ water/lib/Transforms/CMakeLists.txt | 4 + water/lib/Transforms/WaterInsertWaitcnt.cpp | 870 +++++++++++++ water/lib/Transforms/WaterLowerMemoryOps.cpp | 1121 +++++++++++++++++ .../Transforms/WaterMaterializeRegCopy.cpp | 268 ++++ water/lib/Transforms/WaterNumberRegisters.cpp | 109 ++ water/test/Transforms/insert-waitcnt.mlir | 624 +++++++++ water/test/Transforms/lower-memory-ops.mlir | 557 ++++++++ .../test/Transforms/materialize-reg-copy.mlir | 165 +++ .../Transforms/number-registers-error.mlir | 7 + water/test/Transforms/number-registers.mlir | 80 ++ 11 files changed, 3881 insertions(+) create mode 100644 water/lib/Transforms/WaterInsertWaitcnt.cpp create mode 100644 water/lib/Transforms/WaterLowerMemoryOps.cpp create mode 100644 water/lib/Transforms/WaterMaterializeRegCopy.cpp create mode 100644 water/lib/Transforms/WaterNumberRegisters.cpp create mode 100644 water/test/Transforms/insert-waitcnt.mlir create mode 100644 water/test/Transforms/lower-memory-ops.mlir create mode 100644 water/test/Transforms/materialize-reg-copy.mlir create mode 100644 water/test/Transforms/number-registers-error.mlir create mode 100644 water/test/Transforms/number-registers.mlir diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index b5d9f704ab..832771320f 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -169,4 +169,80 @@ def WaterMemrefDecompositionPass : Pass<"water-memref-decomposition"> { ]; } +def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> { + let summary = "Insert wait instructions for asynchronous memory operations"; + let description = [{ + This pass analyzes asynchronous memory operations and inserts appropriate + wait/synchronization instructions to ensure memory operations complete + before their results are used. + + The pass tracks dependencies between memory operations and register uses, + maintaining scoreboards to determine when waits are necessary. It handles: + - Read-after-write (RAW) dependencies + - Write-after-write (WAW) dependencies + - Write-after-read (WAR) dependencies + + This is analogous to LLVM's SIInsertWaitcnts pass but operates at the + MLIR level for AMDGPU dialect operations. + }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + ]; +} + +def WaterLowerMemoryOps : InterfacePass<"water-lower-memory-ops", "::mlir::FunctionOpInterface"> { + let summary = "Lower high-level memory operations to AMDGPU dialect"; + let description = [{ + This pass lowers high-level memory operations (vector.load, vector.store, + memref operations) to AMDGPU-specific memory operations (buffer loads/stores, + LDS operations, etc.). + + This lowering prepares the IR for subsequent waitcnt insertion and + final code generation. 
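+
+    As an illustrative sketch (not an exhaustive list of patterns), a
+    vector.load from a memref in the fat_raw_buffer address space is
+    rewritten into a buffer_load_* inline-assembly sequence, loads from
+    workgroup memory become ds_read_*, and plain global accesses become
+    global_load_*/global_store_*.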
+ }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + "::mlir::gpu::GPUDialect", + "::mlir::LLVM::LLVMDialect", + "::mlir::memref::MemRefDialect", + "::mlir::ROCDL::ROCDLDialect", + "::mlir::vector::VectorDialect", + ]; + let options = [ + Option<"chipset", "chipset", "std::string", [{""}], + "Target chipset (e.g., gfx942, gfx1100)"> + ]; +} + +def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> { + let summary = "Materialize register copies for loads"; + let description = [{ + This pass materializes explicit register copies by transforming load + operations to route through a temporary buffer in the virtual register + memory space (memspace 128). For each load: + 1. Creates a subview of the source memref at the load indices + 2. Allocates a temporary buffer in memory space 128 (virtual register space) + 3. Copies from the subview to the temporary register buffer + 4. Loads from the temporary register buffer + + This transformation makes register traffic explicit in the IR, enabling + better analysis and optimization of register usage patterns. + }]; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::memref::MemRefDialect", + ]; +} + +def WaterNumberRegisters : InterfacePass<"water-number-registers", "::mlir::FunctionOpInterface"> { + let summary = "Assign physical registers to register space allocas"; + let description = [{ + This pass performs register allocation by assigning physical register numbers + to memref.alloca operations in memory space 128 (virtual register space). + }]; + let dependentDialects = [ + "::mlir::memref::MemRefDialect", + ]; +} + #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index a26f9117cf..4c14676fb7 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -8,6 +8,10 @@ add_mlir_dialect_library(MLIRWaterTransforms GPUToGPURuntime.cpp MemrefDecomposition.cpp SLPVectorizer.cpp + WaterInsertWaitcnt.cpp + WaterLowerMemoryOps.cpp + WaterMaterializeRegCopy.cpp + WaterNumberRegisters.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/water diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp new file mode 100644 index 0000000000..8e9eb5fbb9 --- /dev/null +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -0,0 +1,870 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DenseAnalysis.h" +#include "mlir/Analysis/DataFlow/Utils.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "water-insert-waitcnt" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERINSERTWAITCNT +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { +static bool isBarrier(Operation *op) { + return isa(op) || isa(op); +} + +static bool isRegisterAddressSpace(MemRefType type) { + auto attr = dyn_cast_or_null(type.getMemorySpace()); + return attr && attr.getInt() == 128; +} + +static bool isWorkgroupAddressSpace(MemRefType type) { + auto attr = dyn_cast_or_null(type.getMemorySpace()); + return attr && attr.getValue() == gpu::AddressSpace::Workgroup; +} + +static bool isWorkgroupAddressSpace(std::optional value) { + if (!value) + return false; + + auto memrefType = cast(value->getType()); + return isWorkgroupAddressSpace(memrefType); +} + +static bool isGlobalAddressSpace(std::optional value) { + if (!value) + return false; + + auto memrefType = cast(value->getType()); + return !isWorkgroupAddressSpace(memrefType) && + !isRegisterAddressSpace(memrefType); +} + +/// Try to propagate view operations to the base memref. +static std::optional propagateViewOps(Value value) { + while (auto view = value.getDefiningOp()) + value = view.getViewSource(); + + return value; +} + +/// Check if the operation is a load operation and return the base memref. +static std::optional isLoadOp(Operation *op) { + // TODO: replace with the interface when available. + if (auto load = dyn_cast(op)) + return propagateViewOps(load.getBase()); + if (auto load = dyn_cast(op)) + return propagateViewOps(load.getMemRef()); + if (auto copy = dyn_cast(op)) + return propagateViewOps(copy.getSource()); + if (auto gather = dyn_cast(op)) + return propagateViewOps(gather.getSrc()); + + return std::nullopt; +} + +/// Check if the operation is a store operation and return the base memref. +static std::optional isStoreOp(Operation *op) { + // TODO: replace with the interface when available. + if (auto store = dyn_cast(op)) + return propagateViewOps(store.getBase()); + if (auto store = dyn_cast(op)) + return propagateViewOps(store.getMemRef()); + if (auto copy = dyn_cast(op)) + return propagateViewOps(copy.getTarget()); + if (auto gather = dyn_cast(op)) + return propagateViewOps(gather.getDst()); + + return std::nullopt; +} + +template +static raw_ostream &print_range(raw_ostream &os, T &&range) { + llvm::interleaveComma(range, os, [&](const auto &item) { os << item; }); + return os; +} + +/// Shared pending operations list for structural sharing +struct PendingOperations { + using TokenContainer = SmallVector; + + PendingOperations() = default; + PendingOperations(SmallVector &&ops, + SmallVector &&opsTokens) + : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {} + + TokenContainer &addOp(Operation *op) { + // Failsafe to prevent infinite list growth. 
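+    // The 256 cap is a heuristic failsafe rather than a hardware limit;
+    // consecutive barriers are merged below and setRequirement prunes entries
+    // once a wait guarantees their completion, so the list normally stays
+    // small.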
+ if (size() >= 256) + llvm::report_fatal_error("Pending operations list is too long"); + + if (!ops.empty() && isBarrier(op) && isBarrier(ops.back())) + return opsTokens.back(); + + ops.push_back(op); + auto &back = opsTokens.emplace_back(); + if (auto memref = isStoreOp(op)) + back.push_back(*memref); + + if (auto memref = isLoadOp(op)) + back.push_back(*memref); + + return back; + } + + size_t size() const { return ops.size(); } + bool empty() const { return ops.empty(); } + + auto opsAndTokens() const { + assert(ops.size() == opsTokens.size() && + "ops and opsTokens must have the same size"); + return llvm::zip(ops, opsTokens); + } + + auto opsAndTokensReverse() const { + assert(ops.size() == opsTokens.size() && + "ops and opsTokens must have the same size"); + return llvm::zip(llvm::reverse(ops), llvm::reverse(opsTokens)); + } + + bool hasSameTail(const PendingOperations &other) const { + for (const auto &[op1, op2, tok1, tok2] : + llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops), + llvm::reverse(opsTokens), llvm::reverse(other.opsTokens))) { + if (op1 != op2) + return false; + if (tok1 != tok2) + return false; + } + return true; + } + + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (TokenContainer &tokens : opsTokens) { + TokenContainer newTok; + for (Value tok : tokens) + updateFunc(tok, newTok); + + tokens = std::move(newTok); + } + } + + void print(raw_ostream &os) const { + os << "PendingOperations: ops=["; + llvm::interleaveComma(opsAndTokens(), os, [&](const auto &opAndTok) { + os << *std::get<0>(opAndTok) << "|"; + print_range(os, std::get<1>(opAndTok)); + }); + os << "]"; + } + + bool operator==(const PendingOperations &other) const { + return ops == other.ops && opsTokens == other.opsTokens; + } + + bool operator!=(const PendingOperations &other) const { + return !(*this == other); + } + + SmallVector ops; + SmallVector opsTokens; +}; + +/// Waitcnt requirement for synchronization +struct WaitcntRequirement { + std::optional load_cnt; + std::optional ds_cnt; + + WaitcntRequirement() = default; + + WaitcntRequirement(amdgpu::MemoryCounterWaitOp waitOp) { + if (auto loadCnt = waitOp.getLoadAttr()) + load_cnt = loadCnt.getInt(); + if (auto dsCnt = waitOp.getDsAttr()) + ds_cnt = dsCnt.getInt(); + } + + bool hasRequirement() const { + return load_cnt.has_value() || ds_cnt.has_value(); + } + + /// Merge with another requirement (take minimum for conservative join) + /// Returns true if this requirement changed + bool merge(const WaitcntRequirement &other) { + bool changed = false; + + // Take minimum of each counter (lower value = more restrictive) + if (other.load_cnt.has_value()) { + if (!load_cnt.has_value() || *other.load_cnt < *load_cnt) { + load_cnt = other.load_cnt; + changed = true; + } + } + if (other.ds_cnt.has_value()) { + if (!ds_cnt.has_value() || *other.ds_cnt < *ds_cnt) { + ds_cnt = other.ds_cnt; + changed = true; + } + } + + return changed; + } + + std::optional getLoadCnt() const { return load_cnt; } + std::optional getStoreCnt() const { return std::nullopt; } + std::optional getDsCnt() const { return ds_cnt; } + + bool isSameCounterType(const WaitcntRequirement &other) const { + return load_cnt.has_value() == other.load_cnt.has_value() || + ds_cnt.has_value() == other.ds_cnt.has_value(); + } + + static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) { + WaitcntRequirement req; + std::optional loadBase = isLoadOp(op); + std::optional storeBase = isStoreOp(op); + if (isWorkgroupAddressSpace(loadBase) || + 
isWorkgroupAddressSpace(storeBase)) { + req.ds_cnt = zero ? 0 : 1; + } else if (isGlobalAddressSpace(loadBase) || + isGlobalAddressSpace(storeBase)) { + req.load_cnt = zero ? 0 : 1; + } + return req; + } + + WaitcntRequirement operator+(const WaitcntRequirement &other) const { + WaitcntRequirement result; + if (load_cnt || other.load_cnt) + result.load_cnt = load_cnt.value_or(0) + other.load_cnt.value_or(0); + if (ds_cnt || other.ds_cnt) + result.ds_cnt = ds_cnt.value_or(0) + other.ds_cnt.value_or(0); + return result; + } + + bool operator>(const WaitcntRequirement &other) const { + if (load_cnt && other.load_cnt && *load_cnt > *other.load_cnt) + return true; + if (ds_cnt && other.ds_cnt && *ds_cnt > *other.ds_cnt) + return true; + return false; + } + operator bool() const { return hasRequirement(); } + + void print(raw_ostream &os) const { + os << "WaitcntRequirement: load_cnt=" << load_cnt << " ds_cnt=" << ds_cnt; + } +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const WaitcntRequirement &result) { + result.print(os); + return os; +} + +static bool mayAlias(Value lhs, Value rhs, ArrayRef tokens) { + if (isWorkgroupAddressSpace(cast(lhs.getType())) != + isWorkgroupAddressSpace(cast(rhs.getType()))) + return false; + + return llvm::is_contained(tokens, lhs); +} + +/// Lattice state tracking pending asynchronous operations +class WaitcntState : public AbstractDenseLattice { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WaitcntState) + + using AbstractDenseLattice::AbstractDenseLattice; + + ChangeResult join(const AbstractDenseLattice &rhs) override { + const auto &rhsState = static_cast(rhs); + bool changed = false; + + SmallVector, 4> toAppend; + // Check if any pending operations has the same subset of operations as the + // rhs and take the longer one. + for (auto &rhsPendingOps : rhsState.pendingOpsLists) { + bool found = false; + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->hasSameTail(*rhsPendingOps)) { + if (rhsPendingOps->size() > pendingOps->size()) { + pendingOps = rhsPendingOps; + changed = true; + } + found = true; + break; + } + } + if (!found) + toAppend.push_back(rhsPendingOps); + } + + // If there are any pending operations that don't have the same subset of + // operations as the rhs, append them to the pending operations lists. + if (!toAppend.empty()) { + pendingOpsLists.append(toAppend); + changed = true; + } + + if (changed) + resetPendingOpsSet(); + + // Merge requirements (take minimum for conservative join) + if (requirement.merge(rhsState.requirement)) + changed = true; + + return changed ? ChangeResult::Change : ChangeResult::NoChange; + } + + ChangeResult merge(const WaitcntState &rhs) { + bool changed = false; + + if (pendingOpsLists.size() != rhs.pendingOpsLists.size()) { + changed = true; + } else { + for (auto [listSrc, listDst] : + llvm::zip(pendingOpsLists, rhs.pendingOpsLists)) { + if (*listSrc != *listDst) { + changed = true; + break; + } + } + } + + if (changed) { + pendingOpsLists = rhs.pendingOpsLists; + resetPendingOpsSet(); + } + + if (requirement.merge(rhs.requirement)) + changed = true; + return changed ? 
ChangeResult::Change : ChangeResult::NoChange; + } + + void print(raw_ostream &os) const override { + os << "WaitcntState: pending ops ["; + for (auto &pendingOps : pendingOpsLists) { + os << "\n ["; + pendingOps->print(os); + os << "]"; + } + os << "\n ], requirement: " << requirement; + } + + void addPendingOp(Operation *op) { + if (pendingOpsLists.empty()) { + pendingOpsLists.push_back(std::make_shared()); + } else { + cow(); + } + for (auto &pendingOps : pendingOpsLists) { + auto &tokens = pendingOps->addOp(op); + for (Value token : tokens) + pendingOpsTokens.insert(token); + } + + pendingOpsSet.insert(op); + } + + /// Initialize to empty state + ChangeResult reset() { + if (pendingOpsLists.empty() && !requirement.hasRequirement()) + return ChangeResult::NoChange; + + pendingOpsLists.clear(); + requirement = {}; + resetPendingOpsSet(); + return ChangeResult::Change; + } + + /// Set the required waitcnt values + void setRequirement(const WaitcntRequirement &req) { + requirement = req; + for (auto &pendingOps : pendingOpsLists) { + SmallVector newPending; + SmallVector newPendingTokens; + WaitcntRequirement runningRequirement; + for (const auto &[op, tok] : llvm::reverse(pendingOps->opsAndTokens())) { + WaitcntRequirement opReq = + WaitcntRequirement::getOperationRequirement(op, false); + runningRequirement = runningRequirement + opReq; + if (runningRequirement > requirement) + continue; + + newPending.push_back(op); + newPendingTokens.push_back(tok); + } + if (newPending.size() == pendingOps->size()) + continue; + + std::reverse(newPending.begin(), newPending.end()); + std::reverse(newPendingTokens.begin(), newPendingTokens.end()); + pendingOps = std::make_shared( + std::move(newPending), std::move(newPendingTokens)); + } + + // Remove empty lists + pendingOpsLists.erase(std::remove_if(pendingOpsLists.begin(), + pendingOpsLists.end(), + [](const auto &pendingOps) { + return pendingOps->empty(); + }), + pendingOpsLists.end()); + + // Merge lists with the same tail (keep the longer one) + for (size_t i = 0; i < pendingOpsLists.size(); ++i) { + for (size_t j = i + 1; j < pendingOpsLists.size();) { + if (pendingOpsLists[i]->hasSameTail(*pendingOpsLists[j])) { + if (pendingOpsLists[j]->size() > pendingOpsLists[i]->size()) { + pendingOpsLists[i] = pendingOpsLists[j]; + } + pendingOpsLists.erase(pendingOpsLists.begin() + j); + } else { + ++j; + } + } + } + + resetPendingOpsSet(); + } + + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (auto &pendingOps : pendingOpsLists) + pendingOps->updateTokens(updateFunc); + } + + void resetRequirement() { requirement = {}; } + + /// Get the required waitcnt values + const WaitcntRequirement &getRequirement() const { return requirement; } + + /// Check if there's a waitcnt requirement + bool hasRequirement() const { return requirement.hasRequirement(); } + + /// Check if a value depends on pending operations and compute required wait + WaitcntRequirement + checkSSADependency(Value val, + llvm::SmallSetVector &barriers) const { + // Check if val is produced by any pending operation + Operation *defOp = val.getDefiningOp(); + if (!defOp) + return {}; + + if (!isPendingOp(defOp)) + return {}; + + WaitcntRequirement result; + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->empty()) + continue; + + Operation *barrier = nullptr; + + // Search from the back to find the most recent dependency + bool found = false; + auto req = WaitcntRequirement::getOperationRequirement(defOp, true); + for (Operation *op : 
llvm::reverse(pendingOps->ops)) { + if (op == defOp) { + found = true; + break; + } + + if (!barrier && isBarrier(op)) + barrier = op; + + auto opReq = WaitcntRequirement::getOperationRequirement(op, false); + if (!req.isSameCounterType(opReq)) + continue; + + req = req + opReq; + } + + if (found) { + result.merge(req); + if (barrier) + barriers.insert(barrier); + } + } + + return result; + } + + /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait + WaitcntRequirement + checkMemoryDependency(Operation *op, + llvm::SmallSetVector &barriers) const { + auto checkMemref = [&](Value memref, bool isCurrentLoad, + bool isCurrentStore) -> WaitcntRequirement { + WaitcntRequirement result; + if (!isPendingOp(memref)) + return result; + + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->empty()) + continue; + + Operation *barrier = nullptr; + + // Search from the back to find the most recent dependency + for (const auto &[pendingOpVar, pendingTokensVar] : + pendingOps->opsAndTokensReverse()) { + + if (!barrier && isBarrier(pendingOpVar)) + barrier = pendingOpVar; + + // We canot capture structured bindings into lambda, thanks C++. + auto &pendingTokens = pendingTokensVar; + auto &pendingOp = pendingOpVar; + auto checkPendingMemref = + [&](Value pendingMemref, bool isPendingLoad, + bool isPendingStore) -> WaitcntRequirement { + WaitcntRequirement pendingResult; + if (!mayAlias(memref, pendingMemref, pendingTokens)) + return pendingResult; + + // Check for dependencies: + // RAW: current load after pending store + // WAR: current store after pending load + // WAW: current store after pending store + bool hasRAW = isCurrentLoad && isPendingStore; + bool hasWAR = isCurrentStore && isPendingLoad; + bool hasWAW = isCurrentStore && isPendingStore; + + if (hasRAW || hasWAR || hasWAW) { + // Found dependency - compute requirement by counting forward from + // here + auto it = llvm::find(pendingOps->ops, pendingOp); + auto req = + WaitcntRequirement::getOperationRequirement(pendingOp, true); + for (Operation *countOp : + llvm::make_range(std::next(it), pendingOps->ops.end())) { + auto opReq = + WaitcntRequirement::getOperationRequirement(countOp, false); + if (!req.isSameCounterType(opReq)) + continue; + req = req + opReq; + } + pendingResult.merge(req); + } + if (pendingResult.hasRequirement() && barrier) + barriers.insert(barrier); + + return pendingResult; + }; + if (auto loadBase = isLoadOp(pendingOp)) + result.merge(checkPendingMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(pendingOp)) + result.merge(checkPendingMemref(*storeBase, false, true)); + } + } + + return result; + }; + // TODO: atomics will have both load and store flags set + WaitcntRequirement result; + if (auto loadBase = isLoadOp(op)) + result.merge(checkMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(op)) + result.merge(checkMemref(*storeBase, false, true)); + return result; + } + +private: + /// Pending asynchronous operations + SmallVector, 4> pendingOpsLists; + + /// Required waitcnt after this state + WaitcntRequirement requirement; + + mutable llvm::SmallDenseSet pendingOpsSet; + mutable llvm::SmallDenseSet pendingOpsTokens; + + void cow() { + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps.use_count() > 1) { + auto newPending = std::make_shared(); + if (pendingOps) + *newPending = *pendingOps; + pendingOps = std::move(newPending); + } + } + } + + bool isPendingOp(llvm::PointerUnion opOrVal) const { + if (pendingOpsLists.empty()) + return false; + + // 
Build the set of pending operations lazily + bool found = false; + if (pendingOpsSet.empty()) { + assert(pendingOpsTokens.empty() && "pendingOpsTokens must be empty"); + Operation *op = dyn_cast(opOrVal); + Value val = dyn_cast(opOrVal); + for (const auto &pendingOps : pendingOpsLists) { + for (const auto &[pendingOp, pendingTokens] : + pendingOps->opsAndTokens()) { + if (pendingOp == op) + found = true; + + pendingOpsSet.insert(pendingOp); + for (Value token : pendingTokens) { + if (token == val) + found = true; + + pendingOpsTokens.insert(token); + } + } + } + } + + if (found) + return true; + + return isa(opOrVal) + ? pendingOpsSet.contains(cast(opOrVal)) + : pendingOpsTokens.contains(cast(opOrVal)); + } + + void resetPendingOpsSet() { + pendingOpsSet.clear(); + pendingOpsTokens.clear(); + } +}; + +static RegionSuccessor getRegionResults(ArrayRef successors, + Region *region) { + for (const auto &successor : successors) { + if (successor.getSuccessor() == region) + return successor; + } + llvm_unreachable("Region not found, malformed SCF op?"); +} + +/// Dense forward dataflow analysis for waitcnt insertion +class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { +public: + explicit WaitcntAnalysis(DataFlowSolver &solver) + : DenseForwardDataFlowAnalysis(solver) {} + + void setToEntryState(WaitcntState *lattice) override { + propagateIfChanged(lattice, lattice->reset()); + } + + LogicalResult visitOperation(Operation *op, const WaitcntState &before, + WaitcntState *after) override { + LDBG() << "Visiting: " << *op; + LDBG() << " Before: " << before; + + // Start with the state before this operation + WaitcntState newState = before; + + if (isBarrier(op)) { + LDBG() << " Barrier: " << *op; + newState.addPendingOp(op); + LDBG() << " New state: " << newState; + propagateIfChanged(after, after->join(newState)); + return success(); + } + + llvm::SmallSetVector barriers; + + // Check if any operands depend on pending operations (value dependency) + WaitcntRequirement opRequirement = after->getRequirement(); + for (Value operand : op->getOperands()) { + if (auto req = before.checkSSADependency(operand, barriers)) { + // Merge this requirement (take minimum for conservative wait) + opRequirement.merge(req); + } + } + + // Check for memory dependencies (RAW, WAR, WAW) + if (auto memReq = before.checkMemoryDependency(op, barriers)) { + LDBG() << " Memory dependency: " << memReq; + opRequirement.merge(memReq); + } else { + LDBG() << " No memory dependency"; + } + + if (opRequirement.hasRequirement() && !barriers.empty()) { + // newState.setRequirement(opRequirement); + LDBG() << " Barriers found, requirement: " << opRequirement; + for (Operation *barrier : barriers) { + LDBG() << " " << *barrier; + WaitcntState *beforeState = + getOrCreate(getProgramPointBefore(barrier)); + WaitcntState *afterState = + getOrCreate(getProgramPointAfter(barrier)); + WaitcntState newBarrierState = *beforeState; + newBarrierState.setRequirement(opRequirement); + propagateIfChanged(afterState, afterState->merge(newBarrierState)); + } + return success(); + } + + // Check if this is an existing memory_counter_wait operation + if (auto waitOp = dyn_cast(op)) { + LDBG() << " Existing waitcnt operation: " << *waitOp; + opRequirement.merge(WaitcntRequirement(waitOp)); + } + + // Set the requirement for this operation + if (opRequirement.hasRequirement()) { + newState.setRequirement(opRequirement); + LDBG() << " Operation requirement: " << opRequirement; + } else { + newState.resetRequirement(); + LDBG() << " No 
operation requirement"; + } + + // Check if this is an async memory operation (vector load/store) + if (WaitcntRequirement::getOperationRequirement(op, false) + .hasRequirement()) { + // Add this operation to the pending list + newState.addPendingOp(op); + } + + auto changed = after->merge(newState); + if (changed == ChangeResult::Change) { + LDBG() << " New state: " << newState; + } else { + LDBG() << " No change"; + } + propagateIfChanged(after, changed); + return success(); + } + + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + std::optional regionFrom, + std::optional regionTo, + const WaitcntState &before, + WaitcntState *after) override { + LDBG() << "Visiting region branch control flow transfer: " << *branch; + LDBG() << " Region from: " << regionFrom; + LDBG() << " Region to: " << regionTo; + LDBG() << " Before: " << before; + LDBG() << " After: " << *after; + + SmallVector successors; + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + + auto destSuccessor = [&]() -> RegionSuccessor { + if (regionTo) { + Region ®ion = branch->getRegions()[*regionTo]; + return getRegionResults(successors, ®ion); + } else { + return getRegionResults(successors, nullptr); + } + }(); + // Dest values are either nested block args or branch op results. + ValueRange destValues = destSuccessor.getSuccessorInputs(); + + // Map from input values to dest values. + llvm::SmallDenseMap valuesMapping; + if (regionFrom) { + Region ®ion = branch->getRegions()[*regionFrom]; + for (Block &block : region) { + auto term = + dyn_cast(block.getTerminator()); + if (!term) + continue; + + ValueRange source = + term.getMutableSuccessorOperands(destSuccessor).getAsOperandRange(); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; + } + } else { + ValueRange source = branch.getEntrySuccessorOperands(destSuccessor); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; + } + + DominanceInfo dom; + + WaitcntState newState = before; + auto tokenUpdateFunc = [&](Value value, SmallVectorImpl &newTokens) { + // Keep the token if it dominates current op as user can use it directly. + if (dom.properlyDominates(value, branch)) + newTokens.push_back(value); + + // Add token propagated through region control flow. + if (Value mappedValue = valuesMapping.lookup(value)) + if (newTokens.empty() || newTokens.back() != mappedValue) + newTokens.push_back(mappedValue); + }; + newState.updateTokens(tokenUpdateFunc); + + LDBG() << " New state: " << newState; + + propagateIfChanged(after, after->join(newState)); + } +}; + +/// Pass that inserts wait/synchronization instructions for asynchronous +/// memory operations. This is analogous to LLVM's SIInsertWaitcnts pass. 
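+///
+/// Illustrative effect (hypothetical IR): if a value produced by a global
+/// vector.load is consumed later, the pass inserts an
+/// amdgpu.memory_counter_wait with a zero load counter immediately before the
+/// consuming operation; dependencies through LDS use the ds counter instead.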
+class WaterInsertWaitcntPass + : public water::impl::WaterInsertWaitcntBase { +public: + void runOnOperation() override { + LDBG() << "Running WaterInsertWaitcntPass"; + Operation *op = getOperation(); + + DataFlowSolver solver; + loadBaselineAnalyses(solver); + solver.load(); + + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // Insert waitcnt operations based on analysis results + IRRewriter rewriter(&getContext()); + op->walk([&](Operation *operation) { + const WaitcntState *state = solver.lookupState( + solver.getProgramPointAfter(operation)); + if (!state || !state->hasRequirement()) + return; + + const WaitcntRequirement &req = state->getRequirement(); + + auto getAttr = [&](std::optional cnt) -> IntegerAttr { + if (!cnt.has_value()) + return nullptr; + return rewriter.getI32IntegerAttr(*cnt); + }; + + // Insert wait operation before the current operation. + // If the current operation is already a memory_counter_wait operation + // they will be merged later. + rewriter.setInsertionPoint(operation); + amdgpu::MemoryCounterWaitOp::create( + rewriter, operation->getLoc(), getAttr(req.getLoadCnt()), + getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr, + nullptr); + }); + } +}; + +} // namespace diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp new file mode 100644 index 0000000000..3fd4e6ad61 --- /dev/null +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -0,0 +1,1121 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERLOWERMEMORYOPS +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { + +static unsigned getBitwidth(ShapedType type) { + assert(type.hasStaticShape() && "Shaped type must have static shape"); + return type.getNumElements() * type.getElementTypeBitWidth(); +} + +static unsigned getBitwidth(Type type) { + if (auto shaped = dyn_cast(type)) + return getBitwidth(shaped); + + return type.getIntOrFloatBitWidth(); +} + +static std::string getVGPRRange(unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount) { + assert(vgprCount > 0 && "VGPR count must be greater than 0"); + unsigned start = vgprOffset + vgprNum; + if (vgprCount == 1) { + return ("v" + llvm::Twine(start)).str(); + } else { + unsigned end = start + vgprCount - 1; + return ("v[" + llvm::Twine(start) + ":" + llvm::Twine(end) + "]").str(); + } +} + +static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, bool isOutput) { + return (llvm::Twine(isOutput ? 
"=" : "") + "{" + + getVGPRRange(vgprOffset, vgprNum, vgprCount) + "}") + .str(); +} + +static FailureOr getLoadSizeSuffixRDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("u8"); + case 16: + return StringRef("u16"); + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } +} + +static FailureOr getStoreSizeSuffixRDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } +} + +static FailureOr getLoadSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("ubyte"); + case 16: + return StringRef("ushort"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr getStoreSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("byte"); + case 16: + return StringRef("short"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr getBufferLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getBufferStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getGlobalLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getGlobalStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getDSLoadSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getLoadSizeSuffixRDNA(bitWidth); +} + +static FailureOr getDSStoreSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getStoreSizeSuffixRDNA(bitWidth); +} + +/// Create an LLVM inline assembly operation with standard attributes +static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, + TypeRange resultTypes, + ValueRange operands, StringRef asmStr, + StringRef constraints, + bool hasSideEffects) { + return LLVM::InlineAsmOp::create( + rewriter, loc, resultTypes, operands, asmStr, constraints, hasSideEffects, + /*is_align_stack=*/false, + /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, + /*asm_dialect=*/LLVM::AsmDialectAttr{}, + /*operand_attrs=*/ArrayAttr{}); +} + +/// Detect if chipset is RDNA vs CDNA architecture +static bool isRDNA(const amdgpu::Chipset &chipset) { + return chipset.majorVersion != 9; +} + +static Operation *propagateExtract(Operation *op) { + if (auto extract = dyn_cast(op)) + return extract.getSource().getDefiningOp(); + if (auto extract = dyn_cast(op)) + return extract.getSource().getDefiningOp(); + return nullptr; +} + +static unsigned 
checkHazards(Operation *currentOp, Value value) { + Operation *op = value.getDefiningOp(); + if (!op) + return 0; + + while (auto nextOp = propagateExtract(op)) + op = nextOp; + + if (op->getBlock() != currentOp->getBlock()) + return 0; + + if (!isa(op)) + return 0; + + while (op != currentOp) { + if (isa(op) && + cast(op).getIntrin() == "llvm.amdgcn.s.nop") + return 0; + op = op->getNextNode(); + } + + return 5; // HACK for now +} + +static void handleHazards(IRRewriter &rewriter, Location loc, Operation *op, + Value value) { + unsigned hazard = checkHazards(op, value); + if (hazard > 0) { + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + Value nopCount = + arith::ConstantIntOp::create(rewriter, loc, hazard - 1, 16); + StringAttr intrin = rewriter.getStringAttr("llvm.amdgcn.s.nop"); + LLVM::CallIntrinsicOp::create(rewriter, loc, {}, intrin, nopCount); + } +} + +/// Compute byte offset as iX for a memref access with indices +template +static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + // Extract strided metadata to get offset and strides + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value offset = metadataOp.getOffset(); + + // Compute linear index from multidimensional indices + Value linearIndex = offset; + for (auto i : llvm::seq(0, indices.size())) { + Value stride = metadataOp.getStrides()[i]; + Value indexTimesStride = arith::MulIOp::create( + rewriter, loc, indices[i], stride, arith::IntegerOverflowFlags::nsw); + linearIndex = + arith::AddIOp::create(rewriter, loc, linearIndex, indexTimesStride, + arith::IntegerOverflowFlags::nsw); + } + + // Convert linear index to byte offset + unsigned elementBytes = elementBitWidth / 8; + Value elementSize = + arith::ConstantIndexOp::create(rewriter, loc, elementBytes); + Value byteOffset = + arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, + arith::IntegerOverflowFlags::nsw); + + Type indexType = IntegerType::get(rewriter.getContext(), Bits); + return arith::IndexCastOp::create(rewriter, loc, indexType, byteOffset); +} + +/// Compute the final address for a memref access with indices (for global +/// operations) +template +static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), MemSpace); + auto intType = rewriter.getIntegerType(Bits); + + // Extract base pointer + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value basePtr = metadataOp.getBaseBuffer(); + + // Convert base pointer to i64 + Value basePtrInt = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); + basePtrInt = arith::IndexCastOp::create(rewriter, loc, intType, basePtrInt); + + // Compute byte offset + Value byteOffsetI64 = computeMemrefByteOffset(rewriter, loc, memref, + indices, elementBitWidth); + + // Add byte offset to base pointer + Value finalAddr = + arith::AddIOp::create(rewriter, loc, basePtrInt, byteOffsetI64, + arith::IntegerOverflowFlags::nsw); + return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr); +} + +/// Extract buffer descriptor and base offset from a fat_raw_buffer memref +/// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) +/// Returns: {resource descriptor (i128), base offset (i32)} +static std::pair +extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value 
memref) { + // Create proper memref descriptor struct type: {ptr, ptr, offset, + // sizes[rank], strides[rank]} + auto memrefType = cast(memref.getType()); + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto arrayType = LLVM::LLVMArrayType::get(i64Type, memrefType.getRank()); + Type descriptorFields[] = {ptrType, ptrType, i64Type, arrayType, arrayType}; + + auto memrefDescType = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), descriptorFields); + + Value memrefDescVal = + UnrealizedConversionCastOp::create(rewriter, loc, memrefDescType, memref) + .getResult(0); + + MemRefDescriptor memrefDesc(memrefDescVal); + Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc); + + // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32 + // offset} + auto i160Type = IntegerType::get(rewriter.getContext(), 160); + Value fullDesc = LLVM::PtrToIntOp::create(rewriter, loc, i160Type, bufferPtr); + + // Extract lower 32 bits for base offset + Value baseOffset = arith::TruncIOp::create(rewriter, loc, i32Type, fullDesc); + + // Extract upper 128 bits for resource descriptor + auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32); + Value rsrcBits160 = arith::ShRUIOp::create(rewriter, loc, fullDesc, c32); + auto i128Type = IntegerType::get(rewriter.getContext(), 128); + Value rsrcBits = + arith::TruncIOp::create(rewriter, loc, i128Type, rsrcBits160); + + return {rsrcBits, baseOffset}; +} + +/// Helper to get memref, result type, and bit width from load operation +template +static std::tuple getLoadOpInfo(LoadOpTy loadOp) { + if constexpr (std::is_same_v) { + auto vectorType = loadOp.getVectorType(); + unsigned bitWidth = getBitwidth(vectorType); + return {loadOp.getBase(), vectorType, bitWidth}; + } else { + auto elementType = loadOp.getResult().getType(); + unsigned bitWidth = getBitwidth(elementType); + return {loadOp.getMemRef(), elementType, bitWidth}; + } +} + +/// Helper to get memref, value type, and bit width from store operation +template +static std::tuple getStoreOpInfo(StoreOpTy storeOp) { + if constexpr (std::is_same_v) { + auto vectorType = cast(storeOp.getValueToStore().getType()); + unsigned bitWidth = getBitwidth(vectorType); + return {storeOp.getBase(), vectorType, bitWidth}; + } else { + auto elementType = storeOp.getValueToStore().getType(); + unsigned bitWidth = getBitwidth(elementType); + return {storeOp.getMemRef(), elementType, bitWidth}; + } +} + +/// Lower vector/scalar load to AMDGPU buffer load inline assembly +template +static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + // TODO: for bitwidths less than 32, we will need to truncate the value to 32 + // immediately after the load, breaking the calculated dependencies. 
+ // For now, just let llvm handle the loading + if (bitWidth < 32) + return success(); + + FailureOr suffix = getBufferLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); + + // Constraints: "=v" for output (VGPR), "v" for offset (VGPR), "s" for + // descriptor (SGPR[4]) + StringRef constraints = "=v,v,s"; + + // Compute byte offset from indices + unsigned elementBitWidth = + std::is_same_v + ? cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefByteOffset<32>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + // Create inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, + ValueRange{finalOffset, bufferDesc}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +/// Lower vector/scalar load to LLVM inline assembly (global_load_*) +template +static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + if (bitWidth < 32) + return success(); + + FailureOr suffix = getGlobalLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + + // Build the inline assembly string: "global_load_b64 $0, $1, off" + std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); + + // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) + StringRef constraints = "=v,v"; + + rewriter.setInsertionPoint(loadOp); + + // Compute the final address + unsigned elementBitWidth = + std::is_same_v + ? 
cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Create the inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, + asmStr, constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +/// Lower vector/scalar load to AMDGPU DS load inline assembly +template +static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + if (bitWidth < 32) + return success(); + + FailureOr suffix = getDSLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build inline assembly: "ds_read_b32 $0, $1" + std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); + + // Constraints: "=v" for output (VGPR), "v" for address (VGPR) + StringRef constraints = "=v,v"; + + // Compute byte offset as i64 + unsigned elementBitWidth = + std::is_same_v + ? cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefAddress<32, 3>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Create inline assembly operation (DS operations use 32-bit addresses) + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, + asmStr, constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +static Value extendToReg(Value value, IRRewriter &rewriter, Location loc) { + unsigned bitWidth = getBitwidth(value.getType()); + if (bitWidth >= 32) { + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); + return value; + } + + // Sched barrier to prevent moving the expansion before the waitcnt. + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); + + return arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), value); +} + +/// Lower vector/scalar store to AMDGPU buffer store inline assembly +template +static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getBufferStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported buffer store bit width: ") + << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build inline assembly: "buffer_store_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_store_" + *suffix + " $0, $1, $2, 0 offen").str(); + + // Constraints: "v" for data (VGPR), "v" for offset (VGPR), "s" for descriptor + // (SGPR[4]) + StringRef constraints = "v,v,s"; + + // Compute byte offset from indices + unsigned elementBitWidth = + std::is_same_v + ? 
cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefByteOffset<32>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create inline assembly operation (no result for store) + createInlineAsm(rewriter, loc, TypeRange{}, + {valueToStore, finalOffset, bufferDesc}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Lower vector/scalar store to LLVM inline assembly (global_store_*) +template +static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getGlobalStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported store bit width: ") << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build the inline assembly string: "global_store_b64 $0, $1, off" + std::string asmStr = ("global_store_" + *suffix + " $0, $1, off").str(); + + // Constraints: "v" for address (VGPR), "v" for data (VGPR) + StringRef constraints = "v,v"; + + // Compute the final address + unsigned elementBitWidth = + std::is_same_v + ? cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create the inline assembly operation (no result for store) + createInlineAsm(rewriter, loc, {}, {addr, valueToStore}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Lower vector/scalar store to AMDGPU DS store inline assembly +template +static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getDSStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build inline assembly: "ds_write_b32 $0, $1" + std::string asmStr = ("ds_write_" + *suffix + " $0, $1").str(); + + // Constraints: "v" for address (VGPR), "v" for data (VGPR) + StringRef constraints = "v,v"; + + // Compute byte offset as i64 + unsigned elementBitWidth = + std::is_same_v + ? 
cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefAddress<32, 3>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create inline assembly operation (no result for store, DS uses 32-bit + // addresses) + createInlineAsm(rewriter, loc, {}, {offset, valueToStore}, asmStr, + constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Check if a memref uses AMDGPU fat_raw_buffer address space +static bool usesBufferAddressSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (!memorySpace) + return false; + + // Check for #amdgpu.address_space attribute + if (auto enumAttr = dyn_cast(memorySpace)) + return enumAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer; + + return false; +} + +/// Check if a memref uses workgroup (LDS) address space +static bool usesWorkgroupAddressSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (!memorySpace) + return false; + + // Check for #gpu.address_space attribute + if (auto enumAttr = dyn_cast(memorySpace)) + return enumAttr.getValue() == gpu::AddressSpace::Workgroup; + + return false; +} + +/// Check if a memref uses register space (memspace 128) +static bool usesRegisterSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return intAttr.getInt() == 128; + + return false; +} + +/// Lower memref.copy when destination is in register space - buffer variant +static LogicalResult lowerCopyToRegBuffer(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { + Value src = copyOp.getSource(); + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getBufferLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute byte offset (no indices for full copy) + Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Extract buffer descriptor and base offset + auto [bufferDesc, baseOffset] = extractBufferDescriptor(rewriter, loc, src); + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v,s"; + + // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); + + createInlineAsm(rewriter, loc, resultType, + ValueRange{finalOffset, bufferDesc}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space - DS variant +static LogicalResult lowerCopyToRegDS(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, unsigned totalBits, + Type resultType) { + Value src = copyOp.getSource(); + auto srcType = 
cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getDSLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported DS copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute byte offset + Value offset = computeMemrefAddress<32, 3>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; + + // Build inline assembly: "ds_read_b32 $0, $1" + std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); + + createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space - global variant +static LogicalResult lowerCopyToRegGlobal(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { + Value src = copyOp.getSource(); + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getGlobalLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute source address + Value addr = computeMemrefAddress<64, 0>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; + + // Build inline assembly: "global_load_b128 $0, $1, off" + std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); + + createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space +static LogicalResult lowerCopyToReg(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + // Get destination alloca to find VGPR assignment + auto dstAlloca = dst.getDefiningOp(); + if (!dstAlloca) + return copyOp.emitError("destination must be a memref.alloca"); + + // Get VGPR number from destination alloca + auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + dstAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return copyOp.emitError("destination alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + // Get source type info + auto srcType = cast(src.getType()); + if (!srcType.hasStaticShape()) + return copyOp.emitError("source must have static shape"); + + unsigned totalBits = getBitwidth(srcType); + + // Get result type from destination + auto dstType = cast(dst.getType()); + if (!dstType.hasStaticShape()) + return copyOp.emitError("destination must have static shape"); + + unsigned resultBitWidth = getBitwidth(dstType); + unsigned resultNumElements = (resultBitWidth + 31) / 32; + Type resultType = + VectorType::get(resultNumElements, rewriter.getIntegerType(32)); + + // Dispatch based on source memory space + if 
(usesBufferAddressSpace(src)) + return lowerCopyToRegBuffer(copyOp, rewriter, isRDNAArch, vgprOffset, + vgprNum, vgprCount, totalBits, resultType); + if (usesWorkgroupAddressSpace(src)) + return lowerCopyToRegDS(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); + return lowerCopyToRegGlobal(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); +} + +/// Lower load from register space to inline assembly +template +static LogicalResult lowerLoadFromReg(LoadOpTy loadOp, IRRewriter &rewriter, + unsigned vgprOffset) { + Value memref; + if constexpr (std::is_same_v) + memref = loadOp.getBase(); + else + memref = loadOp.getMemRef(); + + // Get source alloca to find VGPR assignment + auto srcAlloca = memref.getDefiningOp(); + if (!srcAlloca) + return loadOp.emitError("source must be a memref.alloca"); + + // Get VGPR number from source alloca + auto vgprNumAttr = srcAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + srcAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return loadOp.emitError("source alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build constraint for reading from specific VGPR(s) + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true); + + // Simple v_mov to read from VGPR (compiler will optimize this away) + std::string asmStr = + "; reg_load " + getVGPRRange(vgprOffset, vgprNum, vgprCount); + + Type resultType = loadOp.getResult().getType(); + Type asmType = resultType; + unsigned bitWidth = getBitwidth(resultType); + if (bitWidth < 32) + asmType = rewriter.getIntegerType(32); + + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + + Value asmResult = createInlineAsm(rewriter, loc, asmType, {}, asmStr, + constraints, /*hasSideEffects=*/false) + .getResult(0); + + if (bitWidth < 32) { + auto narrowType = rewriter.getIntegerType(bitWidth); + asmResult = arith::TruncIOp::create(rewriter, loc, narrowType, asmResult); + asmResult = LLVM::BitcastOp::create(rewriter, loc, resultType, asmResult); + } + + rewriter.replaceOp(loadOp, asmResult); + return success(); +} + +/// Lower store to register space to inline assembly +template +static LogicalResult lowerStoreToReg(StoreOpTy storeOp, IRRewriter &rewriter, + unsigned vgprOffset) { + Value memref; + if constexpr (std::is_same_v) + memref = storeOp.getBase(); + else + memref = storeOp.getMemRef(); + + // Get destination alloca to find VGPR assignment + auto dstAlloca = memref.getDefiningOp(); + if (!dstAlloca) + return storeOp.emitError("destination must be a memref.alloca"); + + // Get VGPR number from destination alloca + auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + dstAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return storeOp.emitError("destination alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + + // Build constraint for writing to specific VGPR(s) + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",0"; + + // v_mov to write to VGPR (input constraint 0 ties to output) + std::string asmStr = + "; reg_store " + getVGPRRange(vgprOffset, vgprNum, vgprCount); 
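+  // VGPRs are 32 bits wide, so a sub-32-bit store value is first bitcast to
+  // an integer of the same width and zero-extended to i32 below before it is
+  // tied to the destination VGPR range by the inline-asm constraint.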
+ + Value valueToStore = storeOp.getValueToStore(); + unsigned bitWidth = getBitwidth(valueToStore.getType()); + if (bitWidth < 32) { + auto intType = rewriter.getIntegerType(bitWidth); + valueToStore = + LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); + auto i32Type = rewriter.getIntegerType(32); + valueToStore = arith::ExtUIOp::create(rewriter, loc, i32Type, valueToStore); + } + + createInlineAsm(rewriter, loc, valueToStore.getType(), valueToStore, asmStr, + constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +class WaterLowerMemoryOpsPass + : public water::impl::WaterLowerMemoryOpsBase { +public: + using Base::Base; + + void runOnOperation() override { + auto func = getOperation(); + auto chip = amdgpu::Chipset::parse(chipset); + if (failed(chip)) { + func->emitError("invalid chipset: ") << chipset; + return signalPassFailure(); + } + + MLIRContext *ctx = &getContext(); + + unsigned totalVGPRs = + chip->majorVersion >= 12 && chip->minorVersion >= 5 ? 1024 : 256; + + // Check if function has VGPR allocation and insert inline asm directive. + auto vgprAttr = func->getAttrOfType("water.total_vgprs"); + unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0; + unsigned vgprStart = totalVGPRs - vgprCount; + + if (vgprCount > 0) { + // Add amdgpu-num-vgpr to passthrough attribute list + auto vgprStartAttr = StringAttr::get(ctx, std::to_string(vgprStart)); + auto nameAttr = StringAttr::get(ctx, "amdgpu-num-vgpr"); + + Attribute passthroughAttr; + // Get existing passthrough or create new one + if (auto existingPassthrough = + func->getAttrOfType("passthrough")) { + SmallVector attrs(existingPassthrough.begin(), + existingPassthrough.end()); + attrs.push_back(ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})); + passthroughAttr = ArrayAttr::get(ctx, attrs); + } else { + passthroughAttr = ArrayAttr::get( + ctx, {ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})}); + } + func->setAttr("passthrough", passthroughAttr); + } + + // Insert inline assembly at the beginning of the function. + Block &entryBlock = func.getFunctionBody().front(); + IRRewriter rewriter(ctx); + rewriter.setInsertionPointToStart(&entryBlock); + + if (vgprCount > 0) { + std::string asmStr = "; vgprCount = " + std::to_string(vgprCount) + + " vgprStart = " + std::to_string(vgprStart); + + createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{}, + /*operands=*/{}, asmStr, /*constraints=*/"", + /*hasSideEffects=*/true); + } + + // Determine if we're targeting RDNA vs CDNA architecture, CDNA has + // different buffer ops format. 
+ bool isRDNAArch = isRDNA(*chip); + + // Helper to dispatch to the appropriate lowering function based on address + // space + auto lowerMemoryOp = [&](Value base, auto lowerRegister, auto lowerBuffer, + auto lowerWorkgroup, + auto lowerGlobal) -> LogicalResult { + if (usesRegisterSpace(base)) + return lowerRegister(); + if (usesBufferAddressSpace(base)) + return lowerBuffer(); + if (usesWorkgroupAddressSpace(base)) + return lowerWorkgroup(); + return lowerGlobal(); + }; + + auto walkFn = [&](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + loadOp.getBase(), + [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto storeOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + storeOp.getBase(), + [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto loadOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + loadOp.getMemRef(), + [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto storeOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + storeOp.getMemRef(), + [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto copyOp = dyn_cast(op)) { + // Only lower copy if destination is in register space + if (usesRegisterSpace(copyOp.getTarget())) { + if (failed(lowerCopyToReg(copyOp, rewriter, isRDNAArch, vgprStart))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + } + return WalkResult::advance(); + }; + + if (func.walk(walkFn).wasInterrupted()) + signalPassFailure(); + + // Clean up register space allocas - they should all be lowered by now + WalkResult cleanupResult = func.walk([&](memref::AllocaOp allocaOp) { + if (usesRegisterSpace(allocaOp.getMemref())) { + if (!allocaOp->use_empty()) { + allocaOp->emitError("register space alloca still has uses after " + "lowering - not all operations were lowered"); + return WalkResult::interrupt(); + } + rewriter.eraseOp(allocaOp); + } + return WalkResult::advance(); + }); + + if (cleanupResult.wasInterrupted()) + signalPassFailure(); + } +}; + +} // namespace diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp new file mode 100644 index 0000000000..505fd31a05 --- /dev/null +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -0,0 +1,268 @@ +// Copyright 2025 The 
Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERMATERIALIZEREGCOPY
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Check if a memref type is in virtual register space (memspace 128).
+static bool isInRegisterSpace(MemRefType memrefType) {
+  if (auto memSpace =
+          dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace()))
+    return memSpace.getInt() == 128;
+  return false;
+}
+
+static SmallVector<Value> getZeroIndices(IRRewriter &rewriter, Location loc,
+                                         unsigned rank) {
+  return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)};
+}
+
+static void createLoads(IRRewriter &rewriter, Location loc, Value value,
+                        unsigned rank, Value tempAlloca, Operation *op) {
+  // Group uses by block and find the first use in each block
+  DenseMap<Block *, Operation *> blockToFirstUse;
+  for (OpOperand &use : value.getUses()) {
+    Operation *userOp = use.getOwner();
+    Block *userBlock = userOp->getBlock();
+    auto it = blockToFirstUse.find(userBlock);
+    if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second))
+      blockToFirstUse[userBlock] = userOp;
+  }
+
+  SmallVector<Value> zeroIndices = getZeroIndices(rewriter, loc, rank);
+
+  // Create one load per block, right before the first use in that block
+  DenseMap<Block *, Value> blockToLoad;
+  for (auto &[block, firstUse] : blockToFirstUse) {
+    rewriter.setInsertionPoint(firstUse);
+    Value load;
+    if (isa<memref::LoadOp>(op))
+      load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices);
+    else if (auto vecLoadOp = dyn_cast<vector::LoadOp>(op))
+      load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(),
+                                    tempAlloca, zeroIndices);
+    blockToLoad[block] = load;
+  }
+
+  // Replace uses with the appropriate load for their block
+  for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
+    Block *userBlock = use.getOwner()->getBlock();
+    use.set(blockToLoad[userBlock]);
+  }
+}
+
+/// Transform a single load operation to use register space copy.
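+/// Illustrative shape of the rewrite for a scalar load (names are
+/// placeholders, not taken from real IR):
+///
+///   %v = memref.load %src[%i] : memref<1024xf32>
+///
+/// becomes roughly
+///
+///   %sub = memref.subview %src[%i] [1] [1]
+///       : memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+///   %tmp = memref.alloca() : memref<1xf32, 128 : i32>
+///   memref.copy %sub, %tmp
+///       : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+///   %v = memref.load %tmp[%c0] : memref<1xf32, 128 : i32>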
+static LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { + Location loc = op->getLoc(); + rewriter.setInsertionPoint(op); + + // Extract memref, indices, and element type from either load type + Value memref, loadResult; + ValueRange indices; + Type elementType; + SmallVector loadShape; + + if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getMemRef(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + elementType = loadOp.getType(); + loadShape.resize(indices.size(), 1); + } else if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getBase(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + VectorType vecType = loadOp.getVectorType(); + elementType = vecType.getElementType(); + loadShape.resize(indices.size() - vecType.getRank(), 1); + llvm::append_range(loadShape, vecType.getShape()); + } else { + return op->emitError("unsupported load operation"); + } + + auto memrefType = cast(memref.getType()); + + // Create subview parameters + Attribute one = rewriter.getIndexAttr(1); + SmallVector offsets, sizes, strides; + for (auto [index, shape] : llvm::zip(indices, loadShape)) { + offsets.push_back(index); + sizes.push_back(rewriter.getIndexAttr(shape)); + strides.push_back(one); + } + + // Create subview of size [1, 1, ..., 1] at the load indices + auto subviewType = + memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides); + auto subviewMemRefType = cast(subviewType); + Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, + memref, offsets, sizes, strides); + + // Create temporary buffer in virtual register space (memspace 128) + auto regMemSpace = rewriter.getI32IntegerAttr(128); + auto tempType = + MemRefType::get(subviewMemRefType.getShape(), elementType, + /*layout=*/MemRefLayoutAttrInterface{}, regMemSpace); + Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType, + /*dynamicSizes=*/ValueRange{}, + /*alignment=*/IntegerAttr()); + + // Copy from subview to temp register buffer + memref::CopyOp::create(rewriter, loc, subview, tempAlloca); + + createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op); + + // Erase the original load + rewriter.eraseOp(op); + return success(); +} + +/// Hoist allocas from loops when their loads are yielded. 
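+/// When an scf.for yields a value loaded (at all-zero indices) from a
+/// memspace-128 alloca defined inside the loop body, the alloca is moved above
+/// the loop, the loop init value is stored into it, uses of the matching iter
+/// arg are rewritten to loads from the alloca, and the loop result is replaced
+/// by a load inserted after the loop (unless the load dominates a use of the
+/// iter arg, in which case hoisting is skipped). The register buffer then
+/// carries the value across iterations instead of the loop-carried SSA value.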
+static void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) { + auto yieldedValues = loop.getYieldedValuesMutable(); + if (!yieldedValues) + return; + + auto loopResults = loop.getLoopResults(); + if (!loopResults) + return; + + auto loopInits = loop.getInitsMutable(); + + Block *body = loop.getBody(); + Location loc = loop.getLoc(); + + DominanceInfo dom; + + // Find yielded values that come from loads of memspace 128 allocas + for (auto [idx, yieldedValue, iterArg, init, result] : llvm::enumerate( + *yieldedValues, loop.getRegionIterArgs(), loopInits, *loopResults)) { + // Check if this is a load from memspace 128 + Operation *defOp = yieldedValue.get().getDefiningOp(); + if (!defOp) + continue; + + Value alloca; + ValueRange loadIndices; + if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getMemRef(); + loadIndices = loadOp.getIndices(); + } else if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getBase(); + loadIndices = loadOp.getIndices(); + } else { + continue; + } + + // Check all indices are zero + if (llvm::any_of(loadIndices, + [](Value idx) { return getConstantIntValue(idx) != 0; })) + continue; + + // Check if loading from memspace 128 alloca defined in this loop + auto allocaOp = alloca.getDefiningOp(); + if (!allocaOp) + continue; + if (!isInRegisterSpace(cast(alloca.getType()))) + continue; + if (!body->findAncestorOpInBlock(*allocaOp)) + continue; + + // If load dominates any use of the iter arg, we can't hoist the alloca + // because the load would be invalidated by the store. + bool dominates = false; + for (Operation *user : iterArg.getUsers()) { + if (dom.dominates(defOp, user)) { + dominates = true; + break; + } + } + if (dominates) + continue; + + // Hoist the alloca before the loop + allocaOp->moveBefore(loop); + rewriter.setInsertionPointAfter(allocaOp); + + SmallVector zeroIndices = + getZeroIndices(rewriter, loc, loadIndices.size()); + + // Store the iter arg into the alloca + if (isa(defOp)) { + memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); + } + + // Create iter arg loads + createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp); + + // Create a load after the loop + rewriter.setInsertionPointAfter(loop); + zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size()); + Value loadAfterLoop; + if (isa(defOp)) { + loadAfterLoop = + memref::LoadOp::create(rewriter, loc, alloca, zeroIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + loadAfterLoop = vector::LoadOp::create( + rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices); + } + + // Replace uses of the loop result with the new load + result.replaceAllUsesWith(loadAfterLoop); + } +} + +/// Materialize register copies by routing memref.load through temporary +/// buffers in virtual register space (memspace 128). 
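+/// The memspace-128 allocas introduced here are later assigned physical VGPR
+/// numbers (water.vgpr_number / water.vgpr_count attributes) by
+/// WaterNumberRegisters, and the copies and loads that touch them are turned
+/// into inline assembly by WaterLowerMemoryOps.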
+class WaterMaterializeRegCopyPass
+    : public water::impl::WaterMaterializeRegCopyBase<
+          WaterMaterializeRegCopyPass> {
+public:
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+
+    // Collect all load operations to transform
+    SmallVector<Operation *> loadsToTransform;
+    getOperation()->walk([&](Operation *op) {
+      if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+        if (!isInRegisterSpace(cast<MemRefType>(loadOp.getMemRef().getType())))
+          loadsToTransform.push_back(op);
+      } else if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+        if (!isInRegisterSpace(cast<MemRefType>(loadOp.getBase().getType())))
+          loadsToTransform.push_back(op);
+      }
+    });
+
+    for (Operation *op : loadsToTransform) {
+      if (failed(materializeRegCopy(rewriter, op)))
+        return signalPassFailure();
+    }
+
+    // Hoist allocas out of loops when their loads are yielded
+    getOperation()->walk(
+        [&](scf::ForOp forOp) { hoistAllocasFromLoop(rewriter, forOp); });
+  }
+};
+
+} // namespace
diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp
new file mode 100644
index 0000000000..063ed7cf41
--- /dev/null
+++ b/water/lib/Transforms/WaterNumberRegisters.cpp
@@ -0,0 +1,109 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERNUMBERREGISTERS
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Check if a memref type is in virtual register space (memspace 128).
+static bool isInRegisterSpace(MemRefType memrefType) {
+  if (auto memSpace =
+          dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace()))
+    return memSpace.getInt() == 128;
+  return false;
+}
+
+/// Calculate the number of 32-bit registers needed for a memref type.
+static FailureOr<unsigned> getRegisterCount(MemRefType memrefType) {
+  // Calculate total size in bytes
+  unsigned elementSizeBytes = memrefType.getElementTypeBitWidth() / 8;
+  unsigned numElements = 1;
+  for (int64_t dim : memrefType.getShape()) {
+    if (dim == ShapedType::kDynamic)
+      return failure(); // Can't allocate dynamic sizes in registers.
+
+    numElements *= dim;
+  }
+
+  unsigned totalBytes = elementSizeBytes * numElements;
+
+  // Each register is 32 bits = 4 bytes.
+  // Round up to the next register boundary.
+  return (totalBytes + 3) / 4;
+}
+
+/// Assign physical registers to register space allocas.
+class WaterNumberRegistersPass
+    : public water::impl::WaterNumberRegistersBase<WaterNumberRegistersPass> {
+public:
+  void runOnOperation() override {
+    auto func = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    SmallVector<std::pair<unsigned, memref::AllocaOp>> regCounts;
+
+    Type i32 = IntegerType::get(ctx, 32);
+    WalkResult result = func->walk([&](memref::AllocaOp allocaOp) {
+      auto memrefType = allocaOp.getType();
+      if (!isInRegisterSpace(memrefType))
+        return WalkResult::advance();
+
+      auto regCount = getRegisterCount(memrefType);
+      if (failed(regCount)) {
+        allocaOp->emitError(
+            "Cannot allocate dynamic-sized memref in register space");
+        return WalkResult::interrupt();
+      }
+
+      regCounts.emplace_back(*regCount, allocaOp);
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return signalPassFailure();
+
+    // Sort by register size to reduce register alignment gaps.
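+    // Since each alloca below is aligned to its own register count, assigning
+    // in increasing size order avoids padding. E.g. counts {1, 4, 1, 2} taken
+    // in that order occupy registers 0, 4..7, 8 and 10..11 (12 registers, 4
+    // wasted), while the sorted order {1, 1, 2, 4} packs into 0, 1, 2..3, 4..7.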
+ llvm::stable_sort(regCounts, [](const std::pair &a, + const std::pair &b) { + return a.first < b.first; + }); + + // TODO: for now, just assign registers sequentially. In the future, + // we need a liveness analysis to assign registers. + unsigned nextRegister = 0; + + for (auto [regCount, op] : regCounts) { + // Align to regCount boundary. + nextRegister = ((nextRegister + regCount - 1) / regCount) * regCount; + + // Assign starting register number. + op->setAttr("water.vgpr_number", IntegerAttr::get(i32, nextRegister)); + + // Track how many registers this alloca uses. + op->setAttr("water.vgpr_count", IntegerAttr::get(i32, regCount)); + + // Advance to next available register. + nextRegister += regCount; + } + + // Attach metadata to function with total register count. + func->setAttr("water.total_vgprs", IntegerAttr::get(i32, nextRegister)); + } +}; + +} // namespace diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir new file mode 100644 index 0000000000..a8f82fce8f --- /dev/null +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -0,0 +1,624 @@ +// RUN: water-opt %s --water-insert-waitcnt | FileCheck %s + +// CHECK-LABEL: func.func @single_load_use +func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @two_loads_use_in_reverse_order +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) +func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] + // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] + %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] + %addA = arith.addf %loadA, %loadA : vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] + %addB = arith.addf %loadB, %addA : vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + + // CHECK: return %[[ADD_B]] + return %addB : vector<4xf32> +} + +// CHECK-LABEL: func.func @lds_barriers +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) +func.func @lds_barriers(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] + // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] + %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] + amdgpu.lds_barrier + %addA = arith.addf %loadA, %loadA : vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] + amdgpu.lds_barrier + %addB = arith.addf %loadB, %addA : vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + + // CHECK: return %[[ADD_B]] + 
return %addB : vector<4xf32> +} + +// CHECK-LABEL: func.func @raw_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to memory + // CHECK: vector.store %[[DATA]], %[[MEM]] + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load from same memory - RAW dependency, must wait for store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]] + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @raw_dependency_memref +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: f32, %{{.*}}: index) +func.func @raw_dependency_memref(%memref: memref<1024xf32>, %data: f32, %offset: index) -> f32 { + // Store to memory + // CHECK: memref.store %[[DATA]], %[[MEM]] + memref.store %data, %memref[%offset] : memref<1024xf32> + + // Load from same memory - RAW dependency, must wait for store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]] + %result = memref.load %memref[%offset] : memref<1024xf32> + + // CHECK: return %[[LOAD]] + return %result : f32 +} + +// CHECK-LABEL: func.func @war_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Load from memory + // CHECK: %[[LOAD:.*]] = vector.load %[[MEM]] + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to same memory - WAR dependency, must wait for load + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store %[[DATA]], %[[MEM]] + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @waw_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA1:.*]]: vector<4xf32>, %[[DATA2:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @waw_dependency(%memref: memref<1024xf32>, %data1: vector<4xf32>, %data2: vector<4xf32>, %offset: index) { + // First store + // CHECK: vector.store %[[DATA1]], %[[MEM]] + vector.store %data1, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Second store to same memory - WAW dependency, must wait for first store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store %[[DATA2]], %[[MEM]] + vector.store %data2, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @raw_dependency_non_zero_waitcnt +func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Allocate two distinct memrefs to guarantee no aliasing + // CHECK: %[[MEM_A:.*]] = memref.alloc() + %memrefA = memref.alloc() : memref<1024xf32> + // CHECK: %[[MEM_B:.*]] = memref.alloc() + %memrefB = memref.alloc() : memref<1024xf32> + + // Store to memory A + // CHECK: vector.store %{{.*}}, %[[MEM_A]] + vector.store %data, %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to memory B (intervening operation, different memref) + // CHECK: vector.store %{{.*}}, %[[MEM_B]] + vector.store %data, %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + 
// Load from memory A - RAW dependency with store to A at distance 1 + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM_A]] + %result = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @workgroup_memory_raw +func.func @workgroup_memory_raw(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Allocate workgroup (LDS) memory + // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + %lds = memref.alloc() : memref<1024xf32, #gpu.address_space> + + // Store to LDS + // CHECK: vector.store %{{.*}}, %[[LDS]] + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Load from LDS - RAW dependency, should use dsCnt not loadCnt + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[LDS]] + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @mixed_global_and_workgroup +// CHECK-SAME: (%[[GLOBAL:.*]]: memref<1024xf32>, %[[LDS:.*]]: memref<1024xf32, #gpu.address_space>, %{{.*}}: vector<4xf32>, %{{.*}}: index) +func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to global memory + // CHECK: vector.store %{{.*}}, %[[GLOBAL]] + vector.store %data, %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to LDS (different counter, no dependency) + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store %{{.*}}, %[[LDS]] + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Load from global - RAW dependency with global store at distance 0 + // (LDS store doesn't count because it's a different counter type) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[GLOBAL]] + %result = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @existing_waitcnt +func.func @existing_waitcnt(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to memory + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // CHECK: amdgpu.memory_counter_wait load(0) + amdgpu.memory_counter_wait load(0) + + // Another store after the wait + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load requires wait for the second store only (first was already waited on) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @existing_waitcnt_more_strict +func.func @existing_waitcnt_more_strict(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() 
: memref<1024xf32> + + // Store to memory + // CHECK: vector.store + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // Normally, the distance will be 1, but explicit amdgpu.memory_counter_wait + // overrides it. + // CHECK-NOT: amdgpu.memory_counter_wait load(1) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NOT: amdgpu.memory_counter_wait load(1) + amdgpu.memory_counter_wait load(0) + + // CHECK: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + + +// CHECK-LABEL: func.func @control_flow_merge +func.func @control_flow_merge(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : memref<1024xf32> + + // Common operation before branching + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // Extra operation in this path + // CHECK: vector.store + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb2: + // No extra operations, just branch to merge point + // CHECK: cf.br + cf.br ^bb3 + +^bb3: + // bb1 branch has distance 1 but bb2 has distance 0, so we need to conservatively + // take 0 + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @control_flow_merge_same_lists +func.func @control_flow_merge_same_lists(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : memref<1024xf32> + + // Common operation before branching + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: vector.store + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb2: + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb3: + // both branches has the same distance 1 + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @loop_carried_dependency +func.func @loop_carried_dependency(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Store in each iteration + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load in the same iteration - RAW dependency with store from this iteration + // 
In steady state, the backedge brings pending operations from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses the load result, which is async, so need to wait for it + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @loop_load_before_store +func.func @loop_load_before_store(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Load first - in steady state, has RAW dependency with store from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Store after load - WAR dependency with load in same iteration + // The wait for the load clears it from pending, so this wait is for the load + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses load result - load was already waited on by the store, no additional wait needed + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @memref_copy_raw_source +func.func @memref_copy_raw_source(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy from source - RAW dependency (reads from source that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @memref_copy_waw_target +func.func @memref_copy_waw_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAW dependency (writes to target that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @memref_copy_war_target +func.func @memref_copy_war_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // Load from destination + // CHECK: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAR dependency (writes to target that was just read) + // The copy's wait also synchronizes the load, so return doesn't need another wait + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @memref_copy_both_dependencies +func.func @memref_copy_both_dependencies(%src: 
memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy needs to wait for both stores: + // - RAW on source (copy reads from source) + // - WAW on target (copy writes to destination) + // Both stores alias with their respective memrefs, so we need wait(0) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // Load from destination after copy - RAW dependency with copy + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @gather_to_lds +func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %src_offset: index, %dst_offset: index) -> vector<4xf32> { + // Store to global memory + // CHECK: vector.store + vector.store %data, %global[%src_offset] : memref<1024xf32>, vector<4xf32> + + // Gather from global to LDS - has both RAW (reads from global) and acts as store to LDS + // Should wait for global store using load counter + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.gather_to_lds + amdgpu.gather_to_lds %global[%src_offset], %lds[%dst_offset] : f32, memref<1024xf32>, memref<1024xf32, #gpu.address_space> + + // Load from LDS - RAW dependency with gather writing to LDS + // Should wait for LDS operation using ds counter + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %lds[%dst_offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @double_buffering +func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the second buffer copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Cannot skip unfortunately + // CHECK: amdgpu.memory_counter_wait load(0) ds(0) + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + 
+ // CHECK: return + return +} + +// CHECK-LABEL: func.func @triple_buffering +func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // Skip the second buffer copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the prev copy + // CHECK: amdgpu.memory_counter_wait load(0) ds(1) + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + // CHECK: return + return +} + + +// CHECK-LABEL: func.func @triple_buffering_reg_space +func.func @triple_buffering_reg_space(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %c0 = arith.constant 0 : index + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %reg = memref.alloca() : memref<4xf32, 128 : i32> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the the prev copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.subview + %subview = memref.subview %current[%offset] [4] [1] : memref<1024xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + + // This copy only depends 
on buffer 2 iterations ago + // CHECK: amdgpu.memory_counter_wait ds(2) + // CHECK: memref.copy + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @load_store_repeated +func.func @load_store_repeated(%src0: memref<4xf32>, %src1: memref<4xf32>, %offset: index) { + %c0 = arith.constant 0 : index + %buff0 = memref.alloc() : memref<4xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<4xf32, #gpu.address_space> + %reg0 = memref.alloca() : memref<4xf32, 128 : i32> + %reg1 = memref.alloca() : memref<4xf32, 128 : i32> + %reg2 = memref.alloca() : memref<4xf32, 128 : i32> + %reg3 = memref.alloca() : memref<4xf32, 128 : i32> + + // CHECK-COUNT-4: memref.copy + memref.copy %src0, %reg0 : memref<4xf32> to memref<4xf32, 128 : i32> + memref.copy %src1, %reg1 : memref<4xf32> to memref<4xf32, 128 : i32> + + memref.copy %buff0, %reg2 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> + memref.copy %buff1, %reg3 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: vector.load + %data0 = vector.load %reg0[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.load + %data1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK-NEXT: vector.load + %data2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: vector.load + %data3 = vector.load %reg3[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + return +} diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir new file mode 100644 index 0000000000..0458bc81ea --- /dev/null +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -0,0 +1,557 @@ +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx1200}))' | FileCheck %s --check-prefixes=CHECK,GFX12 + +// Test lowering of vector memory operations to AMDGPU global_load/store inline assembly + +// CHECK-LABEL: func.func @simple_function +func.func @simple_function(%arg0: f32) -> f32 { + // CHECK: return %arg0 + return %arg0 : f32 +} + +// CHECK-LABEL: func.func @vector_load +func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: memref.extract_aligned_pointer_as_index + // CHECK: arith.index_cast + // CHECK: llvm.inttoptr + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store +func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // CHECK: memref.extract_aligned_pointer_as_index + // CHECK: arith.index_cast + // CHECK: llvm.inttoptr + // GFX9: llvm.inline_asm 
has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return +} + +// CHECK-LABEL: func.func @vector_load_b32 +func.func @vector_load_b32(%memref: memref<1024xf32>, %offset: index) -> vector<1xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @vector_load_b64 +func.func @vector_load_b64(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @vector_load_b96 +func.func @vector_load_b96(%memref: memref<1024xf32>, %offset: index) -> vector<3xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx3 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @vector_load_b128 +func.func @vector_load_b128(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store_b32 +func.func @vector_store_b32(%memref: memref<1024xf32>, %offset: index, %data: vector<1xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<1xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b64 +func.func @vector_store_b64(%memref: memref<1024xf32>, %offset: index, %data: vector<2xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<2xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b96 +func.func @vector_store_b96(%memref: memref<1024xf32>, %offset: index, %data: vector<3xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx3 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<3xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b128 +func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @load_store_sequence 
+func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { + // Test lowering of load/store sequence + + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> + + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} + +// ----- +// Buffer operations tests + +// CHECK-LABEL: func.func @buffer_load_b32 +func.func @buffer_load_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<1xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b64 +func.func @buffer_load_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<2xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b64 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b96 +func.func @buffer_load_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<3xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b96 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b128 +func.func @buffer_load_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<4xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @buffer_store_b32 +func.func @buffer_store_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<1xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> + return +} + +// CHECK-LABEL: func.func @buffer_store_b64 +func.func @buffer_store_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<2xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b64 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> + return +} + 
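+// Note: the GFX9 and GFX12 check lines in this file differ only in mnemonic
+// spelling: CDNA-style targets (gfx9xx) use dword-count suffixes (dword,
+// dwordx2, dwordx3, dwordx4) while gfx12 uses bit-width suffixes (b32, b64,
+// b96, b128) for the same access sizes; ds_read/ds_write use the b-width
+// suffixes on both.
+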
+// CHECK-LABEL: func.func @buffer_store_b96 +func.func @buffer_store_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<3xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b96 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> + return +} + +// CHECK-LABEL: func.func @buffer_store_b128 +func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<4xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @mixed_global_and_buffer +func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) { + // Load from global memory (should use global_load) + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to buffer memory (should use buffer_store) + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" + vector.store %global_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Load from buffer memory (should use buffer_load) + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" + %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Store to global memory (should use global_store) + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %buffer_data, %global[%offset] : memref<1024xf32>, vector<4xf32> + + return +} +// ----- +// DS operations tests + +// CHECK-LABEL: func.func @ds_load_b32 +func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<1xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @ds_load_b64 +func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<2xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b64 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @ds_load_b96 +func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<3xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b96 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @ds_load_b128 
+// CHECK-LABEL: func.func @ds_load_b128
+func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> vector<4xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v"
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @ds_store_b32
+func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<1xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b64
+func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<2xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b64 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b96
+func.func @ds_store_b96(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<3xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b96 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b128
+func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<4xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @mixed_global_buffer_and_ds
+func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) {
+  // Load from global (should use global_load)
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
+  %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Store to LDS (should use ds_write)
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v"
+  vector.store %global_data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Load from LDS (should use ds_read)
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v"
+  %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Store to buffer (should use buffer_store)
+  // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s"
+  // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %lds_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+  return
+}
+
+// -----
+// Scalar (memref) operations tests
+
+// CHECK-LABEL: func.func @scalar_load_global_f32
+func.func @scalar_load_global_f32(%memref: memref<1024xf32>, %offset: index) -> f32 {
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
+  %result = memref.load %memref[%offset] : memref<1024xf32>
+  return %result : f32
+}
+
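For scalar memref.load/store the suffix comes straight from the element width (f32 selects dword/b32; f64 selects dwordx2/b64 in the next test), and the byte offset handed to the inline asm is the ordinary strided-layout linearization. A sketch with a worked case, as a standalone form of the pass's computeMemrefByteOffset:

#include <cstddef>
#include <cstdint>

// offsetBytes = (layoutOffset + sum(index[i] * stride[i])) * elemBytes.
// E.g. memref<1024xf64>, stride 1, index 5: (0 + 5) * 8 = 40 bytes, and the
// 64-bit access selects global_load_dwordx2 (gfx9) / global_load_b64 (gfx12).
int64_t byteOffset(int64_t layoutOffset, const int64_t *index,
                   const int64_t *stride, size_t rank, unsigned elemBytes) {
  int64_t linear = layoutOffset;
  for (size_t i = 0; i < rank; ++i)
    linear += index[i] * stride[i];
  return linear * elemBytes;
}
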
has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + %result = memref.load %memref[%offset] : memref<1024xf64> + return %result : f64 +} + +// CHECK-LABEL: func.func @scalar_store_global_f32 +func.func @scalar_store_global_f32(%memref: memref<1024xf32>, %offset: index, %data: f32) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xf32> + return +} + +// CHECK-LABEL: func.func @scalar_store_global_f64 +func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %data: f64) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xf64> + return +} + +// CHECK-LABEL: func.func @scalar_load_buffer_f32 +func.func @scalar_load_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> f32 { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" + %result = memref.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> + return %result : f32 +} + +// CHECK-LABEL: func.func @scalar_store_buffer_f32 +func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: f32) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" + memref.store %data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> + return +} + +// CHECK-LABEL: func.func @scalar_load_ds_f32 +func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> f32 { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" + %result = memref.load %lds[%offset] : memref<1024xf32, #gpu.address_space> + return %result : f32 +} + +// CHECK-LABEL: func.func @scalar_store_ds_f32 +func.func @scalar_store_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: f32) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" + memref.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space> + return +} + +// CHECK-LABEL: func.func @mixed_scalar_and_vector +func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { + // Scalar load + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + %scalar = memref.load %memref[%offset] : memref<1024xf32> + + // Vector load + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Scalar store + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + memref.store %scalar, %memref[%offset] : memref<1024xf32> + + // Vector store + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // 
+// CHECK-LABEL: func.func @mixed_scalar_and_vector
+func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) {
+  // Scalar load
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
+  %scalar = memref.load %memref[%offset] : memref<1024xf32>
+
+  // Vector load
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
+  %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Scalar store
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  memref.store %scalar, %memref[%offset] : memref<1024xf32>
+
+  // Vector store
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %vector, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
+
+// Test copy to register space with pre-numbered allocas
+
+// CHECK-LABEL: func.func @copy_global_to_reg_scalar
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]]
+func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v"
+  memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v255", "={v255}"
+  // GFX12: llvm.inline_asm "; reg_load v255", "={v255}"
+  %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
+  // CHECK-NOT: memref.alloca
+  return %val : f32
+}
+
+// CHECK-LABEL: func.func @copy_global_to_reg_vector
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>>
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v"
+  memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT: memref.alloca
+  return %val : vector<4xf32>
+}
+
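The explicit register names in these constraints come from a small string builder: offset plus number selects the first VGPR, and count decides between vN and v[N:M]. This mirrors getVGPRRange/getVGPRConstraint in the pass; that the offset equals 256 - water.total_vgprs (252 here) is an inference from the CHECK lines, not something stated in the source:

#include <string>

// vgprConstraint(252, 0, 4, /*isOutput=*/true) == "={v[252:255]}"
std::string vgprConstraint(unsigned offset, unsigned num, unsigned count,
                           bool isOutput) {
  unsigned start = offset + num;
  std::string range =
      count == 1 ? "v" + std::to_string(start)
                 : "v[" + std::to_string(start) + ":" +
                       std::to_string(start + count - 1) + "]";
  return (isOutput ? "={" : "{") + range + "}";
}
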
+// CHECK-LABEL: func.func @copy_buffer_to_reg
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space<fat_raw_buffer>>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+  // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s"
+  // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[252:255]},v,s"
+  memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT: memref.alloca
+  return %val : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @copy_workgroup_to_reg
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space<workgroup>>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space<workgroup>> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>>
+  // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v"
+  // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v"
+  memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT: memref.alloca
+  return %val : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @store_to_reg
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]]
+func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32>
+  // GFX9: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0"
+  // GFX12: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0"
+  memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v255", "={v255}"
+  // GFX12: llvm.inline_asm "; reg_load v255", "={v255}"
+  %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
+  // CHECK-NOT: memref.alloca
+  return %result : f32
+}
+
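The next test makes the numbering arithmetic visible: with water.total_vgprs = 9, the reserved block sits at the top of the register file, amdgpu-num-vgpr is clamped to 247, and each alloca lands at 247 plus its vgpr_number. The formula below is inferred from the CHECK lines (v247, v[248:251], v[252:255]), not read from the pass source:

// Physical VGPR of an alloca's first register; assumes a 256-VGPR file,
// which matches every constraint in this test file.
unsigned physicalVgpr(unsigned totalVgprs, unsigned vgprNumber) {
  return 256u - totalVgprs + vgprNumber;
}
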
+// CHECK-LABEL: func.func @multiple_reg_allocas
+// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]]
+// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]]
+func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space<workgroup>>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32>
+  %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v247},v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v"
+  %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+  memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+  // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[248:251]},v"
+  // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v"
+  %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>>
+  memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v"
+  // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v"
+  %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space<workgroup>> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>>
+  memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v247", "={v247}"
+  // GFX12: llvm.inline_asm "; reg_load v247", "={v247}"
+  %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32>
+  // GFX9: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}"
+  // GFX12: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}"
+  %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT: memref.alloca
+  return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+// Test MFMA hazard handling with s_nop insertion
+
+// CHECK-LABEL: func.func @mfma_hazard_store
+func.func @mfma_hazard_store(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  // Perform MFMA operation
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // Store MFMA result - should trigger hazard handling
+  // CHECK: rocdl.sched.barrier
+  // CHECK: arith.constant 4 : i16
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
+
+// CHECK-LABEL: func.func @mfma_hazard_with_extract
+func.func @mfma_hazard_with_extract(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  // MFMA with vector extract - hazard checking should propagate through extract
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  %extracted = vector.extract %result[0] : f32 from vector<4xf32>
+
+  // Store extracted value - should still detect hazard through propagation
+  // CHECK: rocdl.sched.barrier
+  // CHECK: arith.constant 4 : i16
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  memref.store %extracted, %arg0[%offset] : memref<1024xf32>
+
+  return
+}
+
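What these two hazard tests encode: the pass traces a stored value back through vector extract ops; if it originates from an amdgpu.mfma in the same block and no llvm.amdgcn.s.nop has been issued in between, it emits a rocdl.sched.barrier followed by s_nop(4) (checkHazards returns a fixed count of 5, flagged as a HACK in the source, and the pass emits count - 1). Boiled down to a decision function:

// Returns how many wait states are still needed before the store; 0 means
// no s_nop is inserted. Simplified from checkHazards in the deleted pass.
unsigned requiredNops(bool tracesToMfma, bool sameBlock, bool nopAlreadySeen) {
  if (!tracesToMfma || !sameBlock || nopAlreadySeen)
    return 0;
  return 5; // pass emits s_nop(5 - 1), hence "arith.constant 4 : i16" above
}

The test that follows covers the nopAlreadySeen case: a manually inserted s_nop suppresses the insertion.
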
+// CHECK-LABEL: func.func @no_hazard_with_existing_nop
+func.func @no_hazard_with_existing_nop(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // Manually insert s.nop
+  %nop_count = arith.constant 4 : i16
+  llvm.call_intrinsic "llvm.amdgcn.s.nop"(%nop_count) : (i16) -> ()
+
+  // Store should NOT insert another s.nop since one already exists
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // CHECK-NOT: rocdl.sched.barrier
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir
new file mode 100644
index 0000000000..7f789e8c9e
--- /dev/null
+++ b/water/test/Transforms/materialize-reg-copy.mlir
@@ -0,0 +1,165 @@
+// RUN: water-opt %s --water-materialize-reg-copy | FileCheck %s
+
+// CHECK-LABEL: func @test_simple_load
+func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f32 {
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1]
+  // CHECK-SAME: memref<10x20xf32> to memref<1x1xf32, strided<[20, 1], offset: ?>>
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1xf32, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]]]
+  // CHECK: return %[[RESULT]]
+  %0 = memref.load %arg0[%i, %j] : memref<10x20xf32>
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @test_simple_vector_load
+func.func @test_simple_vector_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> vector<4xf32> {
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 4] [1, 1]
+  // CHECK-SAME: memref<10x20xf32> to memref<1x4xf32, strided<[20, 1], offset: ?>>
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x4xf32, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0]]]
+  // CHECK: return %[[RESULT]]
+  %0 = vector.load %arg0[%i, %j] : memref<10x20xf32>, vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @test_1d_load
+func.func @test_1d_load(%arg0: memref<100xf16>, %i: index) -> f16 {
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1] [1] [1]
+  // CHECK-SAME: memref<100xf16> to memref<1xf16, strided<[1], offset: ?>>
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf16, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]]]
+  // CHECK: return %[[RESULT]]
+  %0 = memref.load %arg0[%i] : memref<100xf16>
+  return %0 : f16
+}
+
+// CHECK-LABEL: func @test_3d_load
+func.func @test_3d_load(%arg0: memref<8x16x32xi32>, %i: index, %j: index, %k: index) -> i32 {
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1]
+  // CHECK-SAME: memref<8x16x32xi32> to memref<1x1x1xi32, strided<[512, 32, 1], offset: ?>>
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1x1xi32, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: return %[[RESULT]]
+  %0 = memref.load %arg0[%i, %j, %k] : memref<8x16x32xi32>
+  return %0 : i32
+}
+
+// CHECK-LABEL: func @test_multiple_loads
+func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) -> f32 {
+  // First load: subview, alloca, copy
+  // CHECK: memref.subview
+  // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32>
+  // CHECK: memref.copy
+  %0 = memref.load %arg0[%i, %j] : memref<10x10xf32>
+
+  // Second load: subview, alloca, copy
+  // CHECK: memref.subview
+  // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32>
+  // CHECK: memref.copy
+  %1 = memref.load %arg0[%j, %i] : memref<10x10xf32>
+
+  // Now the actual loads happen right before the addf (as late as possible)
+  // CHECK: memref.load
+  // CHECK: memref.load
+  // CHECK: arith.addf
+  %2 = arith.addf %0, %1 : f32
+  return %2 : f32
+}
+
+// CHECK-LABEL: func @test_skip_memspace_128
+func.func @test_skip_memspace_128(%arg0: memref<10xf32>, %arg1: memref<5xf32, 128 : i32>, %i: index) -> f32 {
+  // This load should be transformed (from default memspace)
+  // First: subview, alloca, copy
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1]
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %0 = memref.load %arg0[%i] : memref<10xf32>
+
+  // This load should NOT be transformed (already from memspace 128)
+  // It stays in place
+  // CHECK: %[[VAL1:.*]] = memref.load %arg1[%arg2] : memref<5xf32, 128 : i32>
+  %1 = memref.load %arg1[%i] : memref<5xf32, 128 : i32>
+
+  // The load from temp happens late (right before addf)
+  // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]]
+  // Note: operands may be reordered
+  // CHECK: arith.addf
+  %result = arith.addf %0, %1 : f32
+  // CHECK: return
+  return %result : f32
+}
+
+// CHECK-LABEL: func @test_control_flow
+func.func @test_control_flow(%arg0: memref<10xf32>, %cond: i1, %i: index) -> f32 {
+  // Load happens once, but value is used in multiple blocks
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1]
+  // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %val = memref.load %arg0[%i] : memref<10xf32>
+
+  // CHECK: cf.cond_br
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // First block: load happens here before the addf
+  // CHECK: ^bb1:
+  // CHECK: %[[CONST1:.*]] = arith.constant 1.0
+  // CHECK: %[[LOAD1:.*]] = memref.load %[[TEMP]][%[[C0]]]
+  // CHECK: %[[ADD1:.*]] = arith.addf %[[LOAD1]], %[[CONST1]]
+  %c1 = arith.constant 1.0 : f32
+  %sum1 = arith.addf %val, %c1 : f32
+  // CHECK: cf.br ^bb3(%[[ADD1]]
+  cf.br ^bb3(%sum1 : f32)
+
+^bb2:
+  // Second block: another load happens here before the mulf
+  // CHECK: ^bb2:
+  // CHECK: %[[CONST2:.*]] = arith.constant 2.0
+  // CHECK: %[[LOAD2:.*]] = memref.load %[[TEMP]][%[[C0]]]
+  // CHECK: %[[MUL:.*]] = arith.mulf %[[LOAD2]], %[[CONST2]]
+  %c2 = arith.constant 2.0 : f32
+  %prod = arith.mulf %val, %c2 : f32
+  // CHECK: cf.br ^bb3(%[[MUL]]
+  cf.br ^bb3(%prod : f32)
+
+^bb3(%result: f32):
+  // CHECK: ^bb3(%[[RESULT:.*]]: f32):
+  // CHECK: return %[[RESULT]]
+  return %result : f32
+}
+
+// CHECK-LABEL: func @test_loop_hoist
+func.func @test_loop_hoist(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index, %init: f32) -> f32 {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<1xf32, 128 : i32>
+  // CHECK: arith.constant 0 : index
+  // CHECK: memref.store %arg4, %[[ALLOCA]]
+  // CHECK: scf.for %[[IV:.*]] = %arg1 to %arg2 step %arg3 iter_args(%[[ITER_ARG:.*]] = %arg4)
+  %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) {
+    // CHECK: memref.load %[[ALLOCA]]
+    // CHECK: memref.store %{{.*}}, %arg0[%c0]
+    memref.store %arg, %arg0[%c0] : memref<100xf32>
+    %alloca = memref.alloca() : memref<1xf32, 128 : i32>
+    %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+    memref.copy %subview, %alloca : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+    %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32>
+    // CHECK: memref.subview
+    // CHECK: memref.copy
+    // CHECK: memref.load %[[ALLOCA]]
+    // CHECK: scf.yield
    scf.yield %val : f32
+  }
+  // CHECK: memref.load %[[ALLOCA]]
+  // CHECK: return
+  return %result : f32
+}
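The error test below follows directly from how register-space allocas are sized: each one must map to a fixed number of 32-bit VGPRs, so a dynamically shaped memref has no valid count. A sketch (the helper is hypothetical; the pass itself emits the diagnostic quoted in the test):

#include <optional>

// VGPR count of a register-space alloca: ceil(sizeBytes / 4), defined only
// for statically shaped memrefs.
std::optional<unsigned> vgprCount(bool hasStaticShape, unsigned sizeBytes) {
  if (!hasStaticShape)
    return std::nullopt; // "Cannot allocate dynamic-sized memref in register space"
  return (sizeBytes + 3) / 4;
}
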
diff --git a/water/test/Transforms/number-registers-error.mlir b/water/test/Transforms/number-registers-error.mlir
new file mode 100644
index 0000000000..4d9f77435f
--- /dev/null
+++ b/water/test/Transforms/number-registers-error.mlir
@@ -0,0 +1,7 @@
+// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' --verify-diagnostics
+
+func.func @test_dynamic_size_error(%n: index) {
+  // expected-error @+1 {{Cannot allocate dynamic-sized memref in register space}}
+  %reg = memref.alloca(%n) : memref<?xf32, 128 : i32>
+  return
+}
diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir
new file mode 100644
index 0000000000..ccc32447bb
--- /dev/null
+++ b/water/test/Transforms/number-registers.mlir
@@ -0,0 +1,80 @@
+// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s
+
+// CHECK-LABEL: func @test_simple_numbering
+// CHECK-SAME: attributes {water.total_vgprs = 8 : i32}
+func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+
+  // 1xf32 = 4 bytes = 1 register, starts at reg 0
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
+  %reg0 = memref.alloca() : memref<1xf32, 128 : i32>
+
+  // 4xf32 = 16 bytes = 4 registers; aligned to a 4-register boundary, so it starts at reg 4
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32}
+  %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // 1xf32 = 4 bytes = 1 register; reuses the gap at reg 1 left after reg0
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 1 : i32}
+  %reg2 = memref.alloca() : memref<1xf32, 128 : i32>
+
+  %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+  memref.copy %subview0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+
+  %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32>
+
+  return %val0 : f32
+}
+
+// CHECK-LABEL: func @test_loop_with_registers
+// CHECK-SAME: attributes {water.total_vgprs = 1 : i32}
+func.func @test_loop_with_registers(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : index
+
+  // Register allocated outside loop
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
+  %reg = memref.alloca() : memref<1xf32, 128 : i32>
+
+  scf.for %iv = %lb to %ub step %step {
+    %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+    memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+    %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
+    memref.store %val, %arg0[%iv] : memref<100xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func @test_triple_buffering_numbering
+// CHECK-SAME: attributes {water.total_vgprs = 12 : i32}
+func.func @test_triple_buffering_numbering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
+  %c0 = arith.constant 0 : index
+
+  // Three register blocks for triple buffering, each 4xf32 = 4 registers
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 0 : i32}
+  %reg0 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32}
+  %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 8 : i32}
+  %reg2 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  return
+}
+
+// CHECK-LABEL: func @test_mixed_memspaces
+// CHECK-SAME: attributes {water.total_vgprs = 1 : i32}
+func.func @test_mixed_memspaces(%arg0: memref<100xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // Non-register space alloca - should not be numbered
+  // CHECK: memref.alloca() : memref<10xf32>
+  // CHECK-NOT: water.vgpr_number
+  %local = memref.alloca() : memref<10xf32>
+
+  // Register space alloca - should be numbered
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
+  %reg = memref.alloca() : memref<1xf32, 128 : i32>
+
+  return
+}
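The vgpr_number pattern in test_simple_numbering (0, then 4, then 1) pins down the allocation policy: a block is aligned to its own size and earlier gaps are reused. A first-fit sketch that reproduces every CHECK in this file, including water.total_vgprs = used.size(); the policy is inferred from the tests, not read from the pass source:

#include <vector>

// First-fit over count-aligned slots: reg0(1)->0, reg1(4)->4 (slot 0 is
// blocked by reg0), reg2(1)->1 (gap reuse); total_vgprs = used.size() = 8.
unsigned assignVgpr(std::vector<bool> &used, unsigned count) {
  for (unsigned start = 0;; start += count) {
    bool free = true;
    for (unsigned i = 0; i < count; ++i)
      free &= !(start + i < used.size() && used[start + i]);
    if (!free)
      continue;
    if (used.size() < start + count)
      used.resize(start + count);
    for (unsigned i = 0; i < count; ++i)
      used[start + i] = true;
    return start;
  }
}
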
From ad20fc123d3e7ae6143704baf151a56634a2ad0f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 9 Jan 2026 21:27:48 +0100
Subject: [PATCH 02/38] del

Signed-off-by: Ivan Butygin
---
 water/include/water/Transforms/Passes.td      |   55 -
 water/lib/Transforms/CMakeLists.txt           |    3 -
 water/lib/Transforms/WaterLowerMemoryOps.cpp  | 1121 -----------------
 .../Transforms/WaterMaterializeRegCopy.cpp    |  268 ----
 water/lib/Transforms/WaterNumberRegisters.cpp |  109 --
 water/test/Transforms/lower-memory-ops.mlir   |  557 --------
 .../test/Transforms/materialize-reg-copy.mlir |  165 ---
 .../Transforms/number-registers-error.mlir    |    7 -
 water/test/Transforms/number-registers.mlir   |   80 --
 9 files changed, 2365 deletions(-)
 delete mode 100644 water/lib/Transforms/WaterLowerMemoryOps.cpp
 delete mode 100644 water/lib/Transforms/WaterMaterializeRegCopy.cpp
 delete mode 100644 water/lib/Transforms/WaterNumberRegisters.cpp
 delete mode 100644 water/test/Transforms/lower-memory-ops.mlir
 delete mode 100644 water/test/Transforms/materialize-reg-copy.mlir
 delete mode 100644 water/test/Transforms/number-registers-error.mlir
 delete mode 100644 water/test/Transforms/number-registers.mlir

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index 832771320f..fb83867037 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -190,59 +190,4 @@ def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> {
   ];
 }
-def WaterLowerMemoryOps : InterfacePass<"water-lower-memory-ops", "::mlir::FunctionOpInterface"> {
-  let summary = "Lower high-level memory operations to AMDGPU dialect";
-  let description = [{
-    This pass lowers high-level memory operations (vector.load, vector.store,
-    memref operations) to AMDGPU-specific memory operations (buffer loads/stores,
-    LDS operations, etc.).
-
-    This lowering prepares the IR for subsequent waitcnt insertion and
-    final code generation.
- }]; - let dependentDialects = [ - "::mlir::amdgpu::AMDGPUDialect", - "::mlir::gpu::GPUDialect", - "::mlir::LLVM::LLVMDialect", - "::mlir::memref::MemRefDialect", - "::mlir::ROCDL::ROCDLDialect", - "::mlir::vector::VectorDialect", - ]; - let options = [ - Option<"chipset", "chipset", "std::string", [{""}], - "Target chipset (e.g., gfx942, gfx1100)"> - ]; -} - -def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> { - let summary = "Materialize register copies for loads"; - let description = [{ - This pass materializes explicit register copies by transforming load - operations to route through a temporary buffer in the virtual register - memory space (memspace 128). For each load: - 1. Creates a subview of the source memref at the load indices - 2. Allocates a temporary buffer in memory space 128 (virtual register space) - 3. Copies from the subview to the temporary register buffer - 4. Loads from the temporary register buffer - - This transformation makes register traffic explicit in the IR, enabling - better analysis and optimization of register usage patterns. - }]; - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::memref::MemRefDialect", - ]; -} - -def WaterNumberRegisters : InterfacePass<"water-number-registers", "::mlir::FunctionOpInterface"> { - let summary = "Assign physical registers to register space allocas"; - let description = [{ - This pass performs register allocation by assigning physical register numbers - to memref.alloca operations in memory space 128 (virtual register space). - }]; - let dependentDialects = [ - "::mlir::memref::MemRefDialect", - ]; -} - #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index 4c14676fb7..0a3972cb76 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -9,9 +9,6 @@ add_mlir_dialect_library(MLIRWaterTransforms MemrefDecomposition.cpp SLPVectorizer.cpp WaterInsertWaitcnt.cpp - WaterLowerMemoryOps.cpp - WaterMaterializeRegCopy.cpp - WaterNumberRegisters.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/water diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp deleted file mode 100644 index 3fd4e6ad61..0000000000 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ /dev/null @@ -1,1121 +0,0 @@ -// Copyright 2025 The Wave Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "water/Transforms/Passes.h" - -#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; - -namespace mlir::water { -#define GEN_PASS_DEF_WATERLOWERMEMORYOPS -#include "water/Transforms/Passes.h.inc" -} // namespace mlir::water - -namespace { - -static unsigned getBitwidth(ShapedType type) { - assert(type.hasStaticShape() && "Shaped type must have static shape"); - return type.getNumElements() * type.getElementTypeBitWidth(); -} - -static unsigned getBitwidth(Type type) { - if (auto shaped = dyn_cast(type)) - return getBitwidth(shaped); - - return type.getIntOrFloatBitWidth(); -} - -static std::string getVGPRRange(unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount) { - assert(vgprCount > 0 && "VGPR count must be greater than 0"); - unsigned start = vgprOffset + vgprNum; - if (vgprCount == 1) { - return ("v" + llvm::Twine(start)).str(); - } else { - unsigned end = start + vgprCount - 1; - return ("v[" + llvm::Twine(start) + ":" + llvm::Twine(end) + "]").str(); - } -} - -static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount, bool isOutput) { - return (llvm::Twine(isOutput ? "=" : "") + "{" + - getVGPRRange(vgprOffset, vgprNum, vgprCount) + "}") - .str(); -} - -static FailureOr getLoadSizeSuffixRDNA(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("u8"); - case 16: - return StringRef("u16"); - case 32: - return StringRef("b32"); - case 64: - return StringRef("b64"); - case 96: - return StringRef("b96"); - case 128: - return StringRef("b128"); - default: - return failure(); - } -} - -static FailureOr getStoreSizeSuffixRDNA(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("b8"); - case 16: - return StringRef("b16"); - case 32: - return StringRef("b32"); - case 64: - return StringRef("b64"); - case 96: - return StringRef("b96"); - case 128: - return StringRef("b128"); - default: - return failure(); - } -} - -static FailureOr getLoadSizeSuffixCDNA(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("ubyte"); - case 16: - return StringRef("ushort"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); - } -} - -static FailureOr getStoreSizeSuffixCDNA(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("byte"); - case 16: - return StringRef("short"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); - } -} - -static FailureOr getBufferLoadSuffix(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - return getLoadSizeSuffixRDNA(bitWidth); - } else { - return getLoadSizeSuffixCDNA(bitWidth); - } -} - -static FailureOr getBufferStoreSuffix(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - return getStoreSizeSuffixRDNA(bitWidth); - } else { - 
return getStoreSizeSuffixCDNA(bitWidth); - } -} - -static FailureOr getGlobalLoadSuffix(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - return getLoadSizeSuffixRDNA(bitWidth); - } else { - return getLoadSizeSuffixCDNA(bitWidth); - } -} - -static FailureOr getGlobalStoreSuffix(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - return getStoreSizeSuffixRDNA(bitWidth); - } else { - return getStoreSizeSuffixCDNA(bitWidth); - } -} - -static FailureOr getDSLoadSuffix(unsigned bitWidth, - bool /*isRDNAArch*/) { - return getLoadSizeSuffixRDNA(bitWidth); -} - -static FailureOr getDSStoreSuffix(unsigned bitWidth, - bool /*isRDNAArch*/) { - return getStoreSizeSuffixRDNA(bitWidth); -} - -/// Create an LLVM inline assembly operation with standard attributes -static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, - TypeRange resultTypes, - ValueRange operands, StringRef asmStr, - StringRef constraints, - bool hasSideEffects) { - return LLVM::InlineAsmOp::create( - rewriter, loc, resultTypes, operands, asmStr, constraints, hasSideEffects, - /*is_align_stack=*/false, - /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, - /*asm_dialect=*/LLVM::AsmDialectAttr{}, - /*operand_attrs=*/ArrayAttr{}); -} - -/// Detect if chipset is RDNA vs CDNA architecture -static bool isRDNA(const amdgpu::Chipset &chipset) { - return chipset.majorVersion != 9; -} - -static Operation *propagateExtract(Operation *op) { - if (auto extract = dyn_cast(op)) - return extract.getSource().getDefiningOp(); - if (auto extract = dyn_cast(op)) - return extract.getSource().getDefiningOp(); - return nullptr; -} - -static unsigned checkHazards(Operation *currentOp, Value value) { - Operation *op = value.getDefiningOp(); - if (!op) - return 0; - - while (auto nextOp = propagateExtract(op)) - op = nextOp; - - if (op->getBlock() != currentOp->getBlock()) - return 0; - - if (!isa(op)) - return 0; - - while (op != currentOp) { - if (isa(op) && - cast(op).getIntrin() == "llvm.amdgcn.s.nop") - return 0; - op = op->getNextNode(); - } - - return 5; // HACK for now -} - -static void handleHazards(IRRewriter &rewriter, Location loc, Operation *op, - Value value) { - unsigned hazard = checkHazards(op, value); - if (hazard > 0) { - ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); - Value nopCount = - arith::ConstantIntOp::create(rewriter, loc, hazard - 1, 16); - StringAttr intrin = rewriter.getStringAttr("llvm.amdgcn.s.nop"); - LLVM::CallIntrinsicOp::create(rewriter, loc, {}, intrin, nopCount); - } -} - -/// Compute byte offset as iX for a memref access with indices -template -static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc, - Value memref, ValueRange indices, - unsigned elementBitWidth) { - // Extract strided metadata to get offset and strides - auto metadataOp = - memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); - Value offset = metadataOp.getOffset(); - - // Compute linear index from multidimensional indices - Value linearIndex = offset; - for (auto i : llvm::seq(0, indices.size())) { - Value stride = metadataOp.getStrides()[i]; - Value indexTimesStride = arith::MulIOp::create( - rewriter, loc, indices[i], stride, arith::IntegerOverflowFlags::nsw); - linearIndex = - arith::AddIOp::create(rewriter, loc, linearIndex, indexTimesStride, - arith::IntegerOverflowFlags::nsw); - } - - // Convert linear index to byte offset - unsigned elementBytes = elementBitWidth / 8; - Value elementSize = - arith::ConstantIndexOp::create(rewriter, loc, elementBytes); - Value 
byteOffset = - arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, - arith::IntegerOverflowFlags::nsw); - - Type indexType = IntegerType::get(rewriter.getContext(), Bits); - return arith::IndexCastOp::create(rewriter, loc, indexType, byteOffset); -} - -/// Compute the final address for a memref access with indices (for global -/// operations) -template -static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, - Value memref, ValueRange indices, - unsigned elementBitWidth) { - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), MemSpace); - auto intType = rewriter.getIntegerType(Bits); - - // Extract base pointer - auto metadataOp = - memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); - Value basePtr = metadataOp.getBaseBuffer(); - - // Convert base pointer to i64 - Value basePtrInt = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); - basePtrInt = arith::IndexCastOp::create(rewriter, loc, intType, basePtrInt); - - // Compute byte offset - Value byteOffsetI64 = computeMemrefByteOffset(rewriter, loc, memref, - indices, elementBitWidth); - - // Add byte offset to base pointer - Value finalAddr = - arith::AddIOp::create(rewriter, loc, basePtrInt, byteOffsetI64, - arith::IntegerOverflowFlags::nsw); - return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr); -} - -/// Extract buffer descriptor and base offset from a fat_raw_buffer memref -/// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) -/// Returns: {resource descriptor (i128), base offset (i32)} -static std::pair -extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { - // Create proper memref descriptor struct type: {ptr, ptr, offset, - // sizes[rank], strides[rank]} - auto memrefType = cast(memref.getType()); - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); - auto i32Type = rewriter.getI32Type(); - auto i64Type = rewriter.getI64Type(); - auto arrayType = LLVM::LLVMArrayType::get(i64Type, memrefType.getRank()); - Type descriptorFields[] = {ptrType, ptrType, i64Type, arrayType, arrayType}; - - auto memrefDescType = - LLVM::LLVMStructType::getLiteral(rewriter.getContext(), descriptorFields); - - Value memrefDescVal = - UnrealizedConversionCastOp::create(rewriter, loc, memrefDescType, memref) - .getResult(0); - - MemRefDescriptor memrefDesc(memrefDescVal); - Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc); - - // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32 - // offset} - auto i160Type = IntegerType::get(rewriter.getContext(), 160); - Value fullDesc = LLVM::PtrToIntOp::create(rewriter, loc, i160Type, bufferPtr); - - // Extract lower 32 bits for base offset - Value baseOffset = arith::TruncIOp::create(rewriter, loc, i32Type, fullDesc); - - // Extract upper 128 bits for resource descriptor - auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32); - Value rsrcBits160 = arith::ShRUIOp::create(rewriter, loc, fullDesc, c32); - auto i128Type = IntegerType::get(rewriter.getContext(), 128); - Value rsrcBits = - arith::TruncIOp::create(rewriter, loc, i128Type, rsrcBits160); - - return {rsrcBits, baseOffset}; -} - -/// Helper to get memref, result type, and bit width from load operation -template -static std::tuple getLoadOpInfo(LoadOpTy loadOp) { - if constexpr (std::is_same_v) { - auto vectorType = loadOp.getVectorType(); - unsigned bitWidth = getBitwidth(vectorType); - return {loadOp.getBase(), vectorType, bitWidth}; - } else { - auto elementType = 
loadOp.getResult().getType(); - unsigned bitWidth = getBitwidth(elementType); - return {loadOp.getMemRef(), elementType, bitWidth}; - } -} - -/// Helper to get memref, value type, and bit width from store operation -template -static std::tuple getStoreOpInfo(StoreOpTy storeOp) { - if constexpr (std::is_same_v) { - auto vectorType = cast(storeOp.getValueToStore().getType()); - unsigned bitWidth = getBitwidth(vectorType); - return {storeOp.getBase(), vectorType, bitWidth}; - } else { - auto elementType = storeOp.getValueToStore().getType(); - unsigned bitWidth = getBitwidth(elementType); - return {storeOp.getMemRef(), elementType, bitWidth}; - } -} - -/// Lower vector/scalar load to AMDGPU buffer load inline assembly -template -static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - - // TODO: for bitwidths less than 32, we will need to truncate the value to 32 - // immediately after the load, breaking the calculated dependencies. - // For now, just let llvm handle the loading - if (bitWidth < 32) - return success(); - - FailureOr suffix = getBufferLoadSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; - - Location loc = loadOp.getLoc(); - rewriter.setInsertionPoint(loadOp); - - // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" - std::string asmStr = - ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); - - // Constraints: "=v" for output (VGPR), "v" for offset (VGPR), "s" for - // descriptor (SGPR[4]) - StringRef constraints = "=v,v,s"; - - // Compute byte offset from indices - unsigned elementBitWidth = - std::is_same_v - ? cast(resultType).getElementTypeBitWidth() - : bitWidth; - Value offset = computeMemrefByteOffset<32>( - rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - - // Extract buffer descriptor and base offset from memref - auto [bufferDesc, baseOffset] = - extractBufferDescriptor(rewriter, loc, memref); - - // Add base offset to computed offset - Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, - arith::IntegerOverflowFlags::nsw); - - // Create inline assembly operation with result type directly - auto asmOp = createInlineAsm(rewriter, loc, resultType, - ValueRange{finalOffset, bufferDesc}, asmStr, - constraints, /*hasSideEffects=*/true); - - rewriter.replaceOp(loadOp, asmOp.getResult(0)); - return success(); -} - -/// Lower vector/scalar load to LLVM inline assembly (global_load_*) -template -static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - - if (bitWidth < 32) - return success(); - - FailureOr suffix = getGlobalLoadSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return loadOp.emitError("unsupported load bit width: ") << bitWidth; - - Location loc = loadOp.getLoc(); - - // Build the inline assembly string: "global_load_b64 $0, $1, off" - std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); - - // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) - StringRef constraints = "=v,v"; - - rewriter.setInsertionPoint(loadOp); - - // Compute the final address - unsigned elementBitWidth = - std::is_same_v - ? 
cast(resultType).getElementTypeBitWidth() - : bitWidth; - Value addr = computeMemrefAddress<64, 0>( - rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - - // Create the inline assembly operation with result type directly - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, - asmStr, constraints, /*hasSideEffects=*/true); - - rewriter.replaceOp(loadOp, asmOp.getResult(0)); - return success(); -} - -/// Lower vector/scalar load to AMDGPU DS load inline assembly -template -static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - - if (bitWidth < 32) - return success(); - - FailureOr suffix = getDSLoadSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; - - Location loc = loadOp.getLoc(); - rewriter.setInsertionPoint(loadOp); - - // Build inline assembly: "ds_read_b32 $0, $1" - std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); - - // Constraints: "=v" for output (VGPR), "v" for address (VGPR) - StringRef constraints = "=v,v"; - - // Compute byte offset as i64 - unsigned elementBitWidth = - std::is_same_v - ? cast(resultType).getElementTypeBitWidth() - : bitWidth; - Value offset = computeMemrefAddress<32, 3>( - rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - - // Create inline assembly operation (DS operations use 32-bit addresses) - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, - asmStr, constraints, /*hasSideEffects=*/true); - - rewriter.replaceOp(loadOp, asmOp.getResult(0)); - return success(); -} - -static Value extendToReg(Value value, IRRewriter &rewriter, Location loc) { - unsigned bitWidth = getBitwidth(value.getType()); - if (bitWidth >= 32) { - Type intType = rewriter.getIntegerType(bitWidth); - if (value.getType() != intType) - value = LLVM::BitcastOp::create(rewriter, loc, intType, value); - return value; - } - - // Sched barrier to prevent moving the expansion before the waitcnt. - ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); - - Type intType = rewriter.getIntegerType(bitWidth); - if (value.getType() != intType) - value = LLVM::BitcastOp::create(rewriter, loc, intType, value); - - return arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), value); -} - -/// Lower vector/scalar store to AMDGPU buffer store inline assembly -template -static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - - FailureOr suffix = getBufferStoreSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return storeOp.emitError("unsupported buffer store bit width: ") - << bitWidth; - - Location loc = storeOp.getLoc(); - rewriter.setInsertionPoint(storeOp); - handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); - - // Build inline assembly: "buffer_store_ $0, $1, $2, 0 offen" - std::string asmStr = - ("buffer_store_" + *suffix + " $0, $1, $2, 0 offen").str(); - - // Constraints: "v" for data (VGPR), "v" for offset (VGPR), "s" for descriptor - // (SGPR[4]) - StringRef constraints = "v,v,s"; - - // Compute byte offset from indices - unsigned elementBitWidth = - std::is_same_v - ? 
cast(valueType).getElementTypeBitWidth() - : bitWidth; - Value offset = computeMemrefByteOffset<32>( - rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - - // Extract buffer descriptor and base offset from memref - auto [bufferDesc, baseOffset] = - extractBufferDescriptor(rewriter, loc, memref); - - // Add base offset to computed offset - Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, - arith::IntegerOverflowFlags::nsw); - - Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); - - // Create inline assembly operation (no result for store) - createInlineAsm(rewriter, loc, TypeRange{}, - {valueToStore, finalOffset, bufferDesc}, asmStr, constraints, - /*hasSideEffects=*/true); - - rewriter.eraseOp(storeOp); - return success(); -} - -/// Lower vector/scalar store to LLVM inline assembly (global_store_*) -template -static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - - FailureOr suffix = getGlobalStoreSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return storeOp.emitError("unsupported store bit width: ") << bitWidth; - - Location loc = storeOp.getLoc(); - rewriter.setInsertionPoint(storeOp); - handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); - - // Build the inline assembly string: "global_store_b64 $0, $1, off" - std::string asmStr = ("global_store_" + *suffix + " $0, $1, off").str(); - - // Constraints: "v" for address (VGPR), "v" for data (VGPR) - StringRef constraints = "v,v"; - - // Compute the final address - unsigned elementBitWidth = - std::is_same_v - ? cast(valueType).getElementTypeBitWidth() - : bitWidth; - Value addr = computeMemrefAddress<64, 0>( - rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - - Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); - - // Create the inline assembly operation (no result for store) - createInlineAsm(rewriter, loc, {}, {addr, valueToStore}, asmStr, constraints, - /*hasSideEffects=*/true); - - rewriter.eraseOp(storeOp); - return success(); -} - -/// Lower vector/scalar store to AMDGPU DS store inline assembly -template -static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter, - bool isRDNAArch) { - auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - - FailureOr suffix = getDSStoreSuffix(bitWidth, isRDNAArch); - if (failed(suffix)) - return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; - - Location loc = storeOp.getLoc(); - rewriter.setInsertionPoint(storeOp); - handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); - - // Build inline assembly: "ds_write_b32 $0, $1" - std::string asmStr = ("ds_write_" + *suffix + " $0, $1").str(); - - // Constraints: "v" for address (VGPR), "v" for data (VGPR) - StringRef constraints = "v,v"; - - // Compute byte offset as i64 - unsigned elementBitWidth = - std::is_same_v - ? 
cast(valueType).getElementTypeBitWidth() - : bitWidth; - Value offset = computeMemrefAddress<32, 3>( - rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - - Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); - - // Create inline assembly operation (no result for store, DS uses 32-bit - // addresses) - createInlineAsm(rewriter, loc, {}, {offset, valueToStore}, asmStr, - constraints, - /*hasSideEffects=*/true); - - rewriter.eraseOp(storeOp); - return success(); -} - -/// Check if a memref uses AMDGPU fat_raw_buffer address space -static bool usesBufferAddressSpace(Value memref) { - auto memrefType = cast(memref.getType()); - auto memorySpace = memrefType.getMemorySpace(); - - if (!memorySpace) - return false; - - // Check for #amdgpu.address_space attribute - if (auto enumAttr = dyn_cast(memorySpace)) - return enumAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer; - - return false; -} - -/// Check if a memref uses workgroup (LDS) address space -static bool usesWorkgroupAddressSpace(Value memref) { - auto memrefType = cast(memref.getType()); - auto memorySpace = memrefType.getMemorySpace(); - - if (!memorySpace) - return false; - - // Check for #gpu.address_space attribute - if (auto enumAttr = dyn_cast(memorySpace)) - return enumAttr.getValue() == gpu::AddressSpace::Workgroup; - - return false; -} - -/// Check if a memref uses register space (memspace 128) -static bool usesRegisterSpace(Value memref) { - auto memrefType = cast(memref.getType()); - auto memorySpace = memrefType.getMemorySpace(); - - if (auto intAttr = dyn_cast_or_null(memorySpace)) - return intAttr.getInt() == 128; - - return false; -} - -/// Lower memref.copy when destination is in register space - buffer variant -static LogicalResult lowerCopyToRegBuffer(memref::CopyOp copyOp, - IRRewriter &rewriter, bool isRDNAArch, - unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount, - unsigned totalBits, Type resultType) { - Value src = copyOp.getSource(); - auto srcType = cast(src.getType()); - unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - - FailureOr suffix = getBufferLoadSuffix(totalBits, isRDNAArch); - if (failed(suffix)) - return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits; - - Location loc = copyOp.getLoc(); - rewriter.setInsertionPoint(copyOp); - - // Compute byte offset (no indices for full copy) - Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{}, - elementBitWidth); - - // Extract buffer descriptor and base offset - auto [bufferDesc, baseOffset] = extractBufferDescriptor(rewriter, loc, src); - Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, - arith::IntegerOverflowFlags::nsw); - - // Build constraint with specific VGPR - std::string constraints = - getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v,s"; - - // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" - std::string asmStr = - ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); - - createInlineAsm(rewriter, loc, resultType, - ValueRange{finalOffset, bufferDesc}, asmStr, constraints, - /*hasSideEffects=*/true); - - rewriter.eraseOp(copyOp); - return success(); -} - -/// Lower memref.copy when destination is in register space - DS variant -static LogicalResult lowerCopyToRegDS(memref::CopyOp copyOp, - IRRewriter &rewriter, bool isRDNAArch, - unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount, unsigned totalBits, - Type resultType) { - Value src = copyOp.getSource(); - auto srcType = 
cast(src.getType()); - unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - - FailureOr suffix = getDSLoadSuffix(totalBits, isRDNAArch); - if (failed(suffix)) - return copyOp.emitError("unsupported DS copy bit width: ") << totalBits; - - Location loc = copyOp.getLoc(); - rewriter.setInsertionPoint(copyOp); - - // Compute byte offset - Value offset = computeMemrefAddress<32, 3>(rewriter, loc, src, /*indices=*/{}, - elementBitWidth); - - // Build constraint with specific VGPR - std::string constraints = - getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; - - // Build inline assembly: "ds_read_b32 $0, $1" - std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); - - createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, asmStr, - constraints, /*hasSideEffects=*/true); - - rewriter.eraseOp(copyOp); - return success(); -} - -/// Lower memref.copy when destination is in register space - global variant -static LogicalResult lowerCopyToRegGlobal(memref::CopyOp copyOp, - IRRewriter &rewriter, bool isRDNAArch, - unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount, - unsigned totalBits, Type resultType) { - Value src = copyOp.getSource(); - auto srcType = cast(src.getType()); - unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - - FailureOr suffix = getGlobalLoadSuffix(totalBits, isRDNAArch); - if (failed(suffix)) - return copyOp.emitError("unsupported copy bit width: ") << totalBits; - - Location loc = copyOp.getLoc(); - rewriter.setInsertionPoint(copyOp); - - // Compute source address - Value addr = computeMemrefAddress<64, 0>(rewriter, loc, src, /*indices=*/{}, - elementBitWidth); - - // Build constraint with specific VGPR - std::string constraints = - getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; - - // Build inline assembly: "global_load_b128 $0, $1, off" - std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); - - createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, asmStr, - constraints, /*hasSideEffects=*/true); - - rewriter.eraseOp(copyOp); - return success(); -} - -/// Lower memref.copy when destination is in register space -static LogicalResult lowerCopyToReg(memref::CopyOp copyOp, IRRewriter &rewriter, - bool isRDNAArch, unsigned vgprOffset) { - Value src = copyOp.getSource(); - Value dst = copyOp.getTarget(); - - // Get destination alloca to find VGPR assignment - auto dstAlloca = dst.getDefiningOp(); - if (!dstAlloca) - return copyOp.emitError("destination must be a memref.alloca"); - - // Get VGPR number from destination alloca - auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); - auto vgprCountAttr = - dstAlloca->getAttrOfType("water.vgpr_count"); - if (!vgprNumAttr || !vgprCountAttr) - return copyOp.emitError("destination alloca missing VGPR attributes"); - - unsigned vgprNum = vgprNumAttr.getInt(); - unsigned vgprCount = vgprCountAttr.getInt(); - - // Get source type info - auto srcType = cast(src.getType()); - if (!srcType.hasStaticShape()) - return copyOp.emitError("source must have static shape"); - - unsigned totalBits = getBitwidth(srcType); - - // Get result type from destination - auto dstType = cast(dst.getType()); - if (!dstType.hasStaticShape()) - return copyOp.emitError("destination must have static shape"); - - unsigned resultBitWidth = getBitwidth(dstType); - unsigned resultNumElements = (resultBitWidth + 31) / 32; - Type resultType = - VectorType::get(resultNumElements, rewriter.getIntegerType(32)); - - // Dispatch based on source memory space - if 
(usesBufferAddressSpace(src)) - return lowerCopyToRegBuffer(copyOp, rewriter, isRDNAArch, vgprOffset, - vgprNum, vgprCount, totalBits, resultType); - if (usesWorkgroupAddressSpace(src)) - return lowerCopyToRegDS(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, - vgprCount, totalBits, resultType); - return lowerCopyToRegGlobal(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, - vgprCount, totalBits, resultType); -} - -/// Lower load from register space to inline assembly -template -static LogicalResult lowerLoadFromReg(LoadOpTy loadOp, IRRewriter &rewriter, - unsigned vgprOffset) { - Value memref; - if constexpr (std::is_same_v) - memref = loadOp.getBase(); - else - memref = loadOp.getMemRef(); - - // Get source alloca to find VGPR assignment - auto srcAlloca = memref.getDefiningOp(); - if (!srcAlloca) - return loadOp.emitError("source must be a memref.alloca"); - - // Get VGPR number from source alloca - auto vgprNumAttr = srcAlloca->getAttrOfType("water.vgpr_number"); - auto vgprCountAttr = - srcAlloca->getAttrOfType("water.vgpr_count"); - if (!vgprNumAttr || !vgprCountAttr) - return loadOp.emitError("source alloca missing VGPR attributes"); - - unsigned vgprNum = vgprNumAttr.getInt(); - unsigned vgprCount = vgprCountAttr.getInt(); - - Location loc = loadOp.getLoc(); - rewriter.setInsertionPoint(loadOp); - - // Build constraint for reading from specific VGPR(s) - std::string constraints = - getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true); - - // Simple v_mov to read from VGPR (compiler will optimize this away) - std::string asmStr = - "; reg_load " + getVGPRRange(vgprOffset, vgprNum, vgprCount); - - Type resultType = loadOp.getResult().getType(); - Type asmType = resultType; - unsigned bitWidth = getBitwidth(resultType); - if (bitWidth < 32) - asmType = rewriter.getIntegerType(32); - - ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); - - Value asmResult = createInlineAsm(rewriter, loc, asmType, {}, asmStr, - constraints, /*hasSideEffects=*/false) - .getResult(0); - - if (bitWidth < 32) { - auto narrowType = rewriter.getIntegerType(bitWidth); - asmResult = arith::TruncIOp::create(rewriter, loc, narrowType, asmResult); - asmResult = LLVM::BitcastOp::create(rewriter, loc, resultType, asmResult); - } - - rewriter.replaceOp(loadOp, asmResult); - return success(); -} - -/// Lower store to register space to inline assembly -template -static LogicalResult lowerStoreToReg(StoreOpTy storeOp, IRRewriter &rewriter, - unsigned vgprOffset) { - Value memref; - if constexpr (std::is_same_v) - memref = storeOp.getBase(); - else - memref = storeOp.getMemRef(); - - // Get destination alloca to find VGPR assignment - auto dstAlloca = memref.getDefiningOp(); - if (!dstAlloca) - return storeOp.emitError("destination must be a memref.alloca"); - - // Get VGPR number from destination alloca - auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); - auto vgprCountAttr = - dstAlloca->getAttrOfType("water.vgpr_count"); - if (!vgprNumAttr || !vgprCountAttr) - return storeOp.emitError("destination alloca missing VGPR attributes"); - - unsigned vgprNum = vgprNumAttr.getInt(); - unsigned vgprCount = vgprCountAttr.getInt(); - - Location loc = storeOp.getLoc(); - rewriter.setInsertionPoint(storeOp); - - // Build constraint for writing to specific VGPR(s) - std::string constraints = - getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",0"; - - // v_mov to write to VGPR (input constraint 0 ties to output) - std::string asmStr = - "; reg_store " + getVGPRRange(vgprOffset, vgprNum, vgprCount); 
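-  // For example (mirroring the lit tests in lower-memory-ops.mlir): with
-  // vgprOffset = 252, vgprNum = 0 and vgprCount = 4 this yields the
-  // constraint "={v[252:255]},0" and the asm text "; reg_store v[252:255]".
-  // The asm body is only a comment, so the tied "0" input constraint is what
-  // actually pins the stored value to the reserved VGPR range.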
- - Value valueToStore = storeOp.getValueToStore(); - unsigned bitWidth = getBitwidth(valueToStore.getType()); - if (bitWidth < 32) { - auto intType = rewriter.getIntegerType(bitWidth); - valueToStore = - LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); - auto i32Type = rewriter.getIntegerType(32); - valueToStore = arith::ExtUIOp::create(rewriter, loc, i32Type, valueToStore); - } - - createInlineAsm(rewriter, loc, valueToStore.getType(), valueToStore, asmStr, - constraints, - /*hasSideEffects=*/true); - - rewriter.eraseOp(storeOp); - return success(); -} - -class WaterLowerMemoryOpsPass - : public water::impl::WaterLowerMemoryOpsBase { -public: - using Base::Base; - - void runOnOperation() override { - auto func = getOperation(); - auto chip = amdgpu::Chipset::parse(chipset); - if (failed(chip)) { - func->emitError("invalid chipset: ") << chipset; - return signalPassFailure(); - } - - MLIRContext *ctx = &getContext(); - - unsigned totalVGPRs = - chip->majorVersion >= 12 && chip->minorVersion >= 5 ? 1024 : 256; - - // Check if function has VGPR allocation and insert inline asm directive. - auto vgprAttr = func->getAttrOfType("water.total_vgprs"); - unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0; - unsigned vgprStart = totalVGPRs - vgprCount; - - if (vgprCount > 0) { - // Add amdgpu-num-vgpr to passthrough attribute list - auto vgprStartAttr = StringAttr::get(ctx, std::to_string(vgprStart)); - auto nameAttr = StringAttr::get(ctx, "amdgpu-num-vgpr"); - - Attribute passthroughAttr; - // Get existing passthrough or create new one - if (auto existingPassthrough = - func->getAttrOfType("passthrough")) { - SmallVector attrs(existingPassthrough.begin(), - existingPassthrough.end()); - attrs.push_back(ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})); - passthroughAttr = ArrayAttr::get(ctx, attrs); - } else { - passthroughAttr = ArrayAttr::get( - ctx, {ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})}); - } - func->setAttr("passthrough", passthroughAttr); - } - - // Insert inline assembly at the beginning of the function. - Block &entryBlock = func.getFunctionBody().front(); - IRRewriter rewriter(ctx); - rewriter.setInsertionPointToStart(&entryBlock); - - if (vgprCount > 0) { - std::string asmStr = "; vgprCount = " + std::to_string(vgprCount) + - " vgprStart = " + std::to_string(vgprStart); - - createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{}, - /*operands=*/{}, asmStr, /*constraints=*/"", - /*hasSideEffects=*/true); - } - - // Determine if we're targeting RDNA vs CDNA architecture, CDNA has - // different buffer ops format. 
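-    // For example (as the lit tests in lower-memory-ops.mlir expect), a
-    // 128-bit load lowers to "buffer_load_dwordx4" / "global_load_dwordx4"
-    // on CDNA-style gfx9 targets but to "buffer_load_b128" /
-    // "global_load_b128" on RDNA-style gfx12 targets.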
- bool isRDNAArch = isRDNA(*chip); - - // Helper to dispatch to the appropriate lowering function based on address - // space - auto lowerMemoryOp = [&](Value base, auto lowerRegister, auto lowerBuffer, - auto lowerWorkgroup, - auto lowerGlobal) -> LogicalResult { - if (usesRegisterSpace(base)) - return lowerRegister(); - if (usesBufferAddressSpace(base)) - return lowerBuffer(); - if (usesWorkgroupAddressSpace(base)) - return lowerWorkgroup(); - return lowerGlobal(); - }; - - auto walkFn = [&](Operation *op) { - if (auto loadOp = dyn_cast(op)) { - LogicalResult result = lowerMemoryOp( - loadOp.getBase(), - [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, - [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); - if (failed(result)) - return WalkResult::interrupt(); - return WalkResult::advance(); - } - if (auto storeOp = dyn_cast(op)) { - LogicalResult result = lowerMemoryOp( - storeOp.getBase(), - [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, - [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); - if (failed(result)) - return WalkResult::interrupt(); - return WalkResult::advance(); - } - if (auto loadOp = dyn_cast(op)) { - LogicalResult result = lowerMemoryOp( - loadOp.getMemRef(), - [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, - [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); - if (failed(result)) - return WalkResult::interrupt(); - return WalkResult::advance(); - } - if (auto storeOp = dyn_cast(op)) { - LogicalResult result = lowerMemoryOp( - storeOp.getMemRef(), - [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, - [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); - if (failed(result)) - return WalkResult::interrupt(); - return WalkResult::advance(); - } - if (auto copyOp = dyn_cast(op)) { - // Only lower copy if destination is in register space - if (usesRegisterSpace(copyOp.getTarget())) { - if (failed(lowerCopyToReg(copyOp, rewriter, isRDNAArch, vgprStart))) - return WalkResult::interrupt(); - return WalkResult::advance(); - } - } - return WalkResult::advance(); - }; - - if (func.walk(walkFn).wasInterrupted()) - signalPassFailure(); - - // Clean up register space allocas - they should all be lowered by now - WalkResult cleanupResult = func.walk([&](memref::AllocaOp allocaOp) { - if (usesRegisterSpace(allocaOp.getMemref())) { - if (!allocaOp->use_empty()) { - allocaOp->emitError("register space alloca still has uses after " - "lowering - not all operations were lowered"); - return WalkResult::interrupt(); - } - rewriter.eraseOp(allocaOp); - } - return WalkResult::advance(); - }); - - if (cleanupResult.wasInterrupted()) - signalPassFailure(); - } -}; - -} // namespace diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp deleted file mode 100644 index 505fd31a05..0000000000 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2025 
The Wave Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "water/Transforms/Passes.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; - -namespace mlir::water { -#define GEN_PASS_DEF_WATERMATERIALIZEREGCOPY -#include "water/Transforms/Passes.h.inc" -} // namespace mlir::water - -namespace { - -/// Check if a memref type is in virtual register space (memspace 128). -static bool isInRegisterSpace(MemRefType memrefType) { - if (auto memSpace = - dyn_cast_or_null(memrefType.getMemorySpace())) - return memSpace.getInt() == 128; - return false; -} - -static SmallVector getZeroIndices(IRRewriter &rewriter, Location loc, - unsigned rank) { - return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)}; -} - -static void createLoads(IRRewriter &rewriter, Location loc, Value value, - unsigned rank, Value tempAlloca, Operation *op) { - // Group uses by block and find the first use in each block - DenseMap blockToFirstUse; - for (OpOperand &use : value.getUses()) { - Operation *userOp = use.getOwner(); - Block *userBlock = userOp->getBlock(); - auto it = blockToFirstUse.find(userBlock); - if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) - blockToFirstUse[userBlock] = userOp; - } - - SmallVector zeroIndices = getZeroIndices(rewriter, loc, rank); - - // Create one load per block, right before the first use in that block - DenseMap blockToLoad; - for (auto &[block, firstUse] : blockToFirstUse) { - rewriter.setInsertionPoint(firstUse); - Value load; - if (isa(op)) - load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); - else if (auto vecLoadOp = dyn_cast(op)) - load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), - tempAlloca, zeroIndices); - blockToLoad[block] = load; - } - - // Replace uses with the appropriate load for their block - for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) { - Block *userBlock = use.getOwner()->getBlock(); - use.set(blockToLoad[userBlock]); - } -} - -/// Transform a single load operation to use register space copy. 
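-/// As a rough sketch (types and layouts depend on the source memref), a load
-/// such as
-///   %v = memref.load %src[%i, %j] : memref<10x20xf32>
-/// becomes
-///   %sv  = memref.subview %src[%i, %j] [1, 1] [1, 1]
-///   %tmp = memref.alloca() : memref<1x1xf32, 128 : i32>
-///   memref.copy %sv, %tmp
-///   %v   = memref.load %tmp[%c0, %c0]
-/// with the rematerialized loads placed at the first use in each block (see
-/// createLoads above).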
-static LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { - Location loc = op->getLoc(); - rewriter.setInsertionPoint(op); - - // Extract memref, indices, and element type from either load type - Value memref, loadResult; - ValueRange indices; - Type elementType; - SmallVector loadShape; - - if (auto loadOp = dyn_cast(op)) { - memref = loadOp.getMemRef(); - indices = loadOp.getIndices(); - loadResult = loadOp.getResult(); - elementType = loadOp.getType(); - loadShape.resize(indices.size(), 1); - } else if (auto loadOp = dyn_cast(op)) { - memref = loadOp.getBase(); - indices = loadOp.getIndices(); - loadResult = loadOp.getResult(); - VectorType vecType = loadOp.getVectorType(); - elementType = vecType.getElementType(); - loadShape.resize(indices.size() - vecType.getRank(), 1); - llvm::append_range(loadShape, vecType.getShape()); - } else { - return op->emitError("unsupported load operation"); - } - - auto memrefType = cast(memref.getType()); - - // Create subview parameters - Attribute one = rewriter.getIndexAttr(1); - SmallVector offsets, sizes, strides; - for (auto [index, shape] : llvm::zip(indices, loadShape)) { - offsets.push_back(index); - sizes.push_back(rewriter.getIndexAttr(shape)); - strides.push_back(one); - } - - // Create subview of size [1, 1, ..., 1] at the load indices - auto subviewType = - memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides); - auto subviewMemRefType = cast(subviewType); - Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, - memref, offsets, sizes, strides); - - // Create temporary buffer in virtual register space (memspace 128) - auto regMemSpace = rewriter.getI32IntegerAttr(128); - auto tempType = - MemRefType::get(subviewMemRefType.getShape(), elementType, - /*layout=*/MemRefLayoutAttrInterface{}, regMemSpace); - Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType, - /*dynamicSizes=*/ValueRange{}, - /*alignment=*/IntegerAttr()); - - // Copy from subview to temp register buffer - memref::CopyOp::create(rewriter, loc, subview, tempAlloca); - - createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op); - - // Erase the original load - rewriter.eraseOp(op); - return success(); -} - -/// Hoist allocas from loops when their loads are yielded. 
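-/// If an scf.for yields a value loaded (at all-zero indices) from a
-/// memspace-128 alloca defined inside the loop, the alloca is moved above the
-/// loop, the loop's init value is stored into it, uses of the corresponding
-/// iter_arg are replaced by loads from it, and the loop result is replaced by
-/// a load inserted after the loop. The rewrite is skipped when the yielded
-/// load dominates a use of the iter_arg, because the load replacing that use
-/// would be invalidated by the store into the alloca.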
-static void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) { - auto yieldedValues = loop.getYieldedValuesMutable(); - if (!yieldedValues) - return; - - auto loopResults = loop.getLoopResults(); - if (!loopResults) - return; - - auto loopInits = loop.getInitsMutable(); - - Block *body = loop.getBody(); - Location loc = loop.getLoc(); - - DominanceInfo dom; - - // Find yielded values that come from loads of memspace 128 allocas - for (auto [idx, yieldedValue, iterArg, init, result] : llvm::enumerate( - *yieldedValues, loop.getRegionIterArgs(), loopInits, *loopResults)) { - // Check if this is a load from memspace 128 - Operation *defOp = yieldedValue.get().getDefiningOp(); - if (!defOp) - continue; - - Value alloca; - ValueRange loadIndices; - if (auto loadOp = dyn_cast(defOp)) { - alloca = loadOp.getMemRef(); - loadIndices = loadOp.getIndices(); - } else if (auto loadOp = dyn_cast(defOp)) { - alloca = loadOp.getBase(); - loadIndices = loadOp.getIndices(); - } else { - continue; - } - - // Check all indices are zero - if (llvm::any_of(loadIndices, - [](Value idx) { return getConstantIntValue(idx) != 0; })) - continue; - - // Check if loading from memspace 128 alloca defined in this loop - auto allocaOp = alloca.getDefiningOp(); - if (!allocaOp) - continue; - if (!isInRegisterSpace(cast(alloca.getType()))) - continue; - if (!body->findAncestorOpInBlock(*allocaOp)) - continue; - - // If load dominates any use of the iter arg, we can't hoist the alloca - // because the load would be invalidated by the store. - bool dominates = false; - for (Operation *user : iterArg.getUsers()) { - if (dom.dominates(defOp, user)) { - dominates = true; - break; - } - } - if (dominates) - continue; - - // Hoist the alloca before the loop - allocaOp->moveBefore(loop); - rewriter.setInsertionPointAfter(allocaOp); - - SmallVector zeroIndices = - getZeroIndices(rewriter, loc, loadIndices.size()); - - // Store the iter arg into the alloca - if (isa(defOp)) { - memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); - } else if (auto vectorLoad = dyn_cast(defOp)) { - vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); - } - - // Create iter arg loads - createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp); - - // Create a load after the loop - rewriter.setInsertionPointAfter(loop); - zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size()); - Value loadAfterLoop; - if (isa(defOp)) { - loadAfterLoop = - memref::LoadOp::create(rewriter, loc, alloca, zeroIndices); - } else if (auto vectorLoad = dyn_cast(defOp)) { - loadAfterLoop = vector::LoadOp::create( - rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices); - } - - // Replace uses of the loop result with the new load - result.replaceAllUsesWith(loadAfterLoop); - } -} - -/// Materialize register copies by routing memref.load through temporary -/// buffers in virtual register space (memspace 128). 
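-/// The pass first rewrites every load whose source is not already in
-/// memspace 128 (materializeRegCopy above), then hoists register-space
-/// allocas out of scf.for loops whose yielded values are loads from those
-/// allocas (hoistAllocasFromLoop above).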
-class WaterMaterializeRegCopyPass - : public water::impl::WaterMaterializeRegCopyBase< - WaterMaterializeRegCopyPass> { -public: - void runOnOperation() override { - IRRewriter rewriter(&getContext()); - - // Collect all load operations to transform - SmallVector loadsToTransform; - getOperation()->walk([&](Operation *op) { - if (auto loadOp = dyn_cast(op)) { - if (!isInRegisterSpace(cast(loadOp.getMemRef().getType()))) - loadsToTransform.push_back(op); - } else if (auto loadOp = dyn_cast(op)) { - if (!isInRegisterSpace(cast(loadOp.getBase().getType()))) - loadsToTransform.push_back(op); - } - }); - - for (Operation *op : loadsToTransform) { - if (failed(materializeRegCopy(rewriter, op))) - return signalPassFailure(); - } - - // Hoist allocas out of loops when their loads are yielded - getOperation()->walk( - [&](scf::ForOp forOp) { hoistAllocasFromLoop(rewriter, forOp); }); - } -}; - -} // namespace diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp deleted file mode 100644 index 063ed7cf41..0000000000 --- a/water/lib/Transforms/WaterNumberRegisters.cpp +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2025 The Wave Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "water/Transforms/Passes.h" - -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Builders.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; - -namespace mlir::water { -#define GEN_PASS_DEF_WATERNUMBERREGISTERS -#include "water/Transforms/Passes.h.inc" -} // namespace mlir::water - -namespace { - -/// Check if a memref type is in virtual register space (memspace 128). -static bool isInRegisterSpace(MemRefType memrefType) { - if (auto memSpace = - dyn_cast_or_null(memrefType.getMemorySpace())) - return memSpace.getInt() == 128; - return false; -} - -/// Calculate the number of 32-bit registers needed for a memref type. -static FailureOr getRegisterCount(MemRefType memrefType) { - // Calculate total size in bytes - unsigned elementSizeBytes = memrefType.getElementTypeBitWidth() / 8; - unsigned numElements = 1; - for (int64_t dim : memrefType.getShape()) { - if (dim == ShapedType::kDynamic) - return failure(); // Can't allocate dynamic sizes in registers. - - numElements *= dim; - } - - unsigned totalBytes = elementSizeBytes * numElements; - - // Each register is 32 bits = 4 bytes - // Round up to next register boundary. - return (totalBytes + 3) / 4; -} - -/// Assign physical registers to register space allocas. -class WaterNumberRegistersPass - : public water::impl::WaterNumberRegistersBase { -public: - void runOnOperation() override { - auto func = getOperation(); - MLIRContext *ctx = &getContext(); - - SmallVector> regCounts; - - Type i32 = IntegerType::get(ctx, 32); - WalkResult result = func->walk([&](memref::AllocaOp allocaOp) { - auto memrefType = allocaOp.getType(); - if (!isInRegisterSpace(memrefType)) - return WalkResult::advance(); - - auto regCount = getRegisterCount(memrefType); - if (failed(regCount)) { - allocaOp->emitError( - "Cannot allocate dynamic-sized memref in register space"); - return WalkResult::interrupt(); - } - - regCounts.emplace_back(*regCount, allocaOp); - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) - return signalPassFailure(); - - // Sort by register size to reduce register alignment gaps. 
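-    // For example (matching the number-registers.mlir test), sizes [1, 4, 1]
-    // are processed as [1, 1, 4] and pack into VGPRs 0, 1 and 4..7, i.e. 8
-    // registers total; in source order the 4-wide alloca would be aligned up
-    // to 4 and the trailing single register would land at 8, using 9.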
- llvm::stable_sort(regCounts, [](const std::pair &a, - const std::pair &b) { - return a.first < b.first; - }); - - // TODO: for now, just assign registers sequentially. In the future, - // we need a liveness analysis to assign registers. - unsigned nextRegister = 0; - - for (auto [regCount, op] : regCounts) { - // Align to regCount boundary. - nextRegister = ((nextRegister + regCount - 1) / regCount) * regCount; - - // Assign starting register number. - op->setAttr("water.vgpr_number", IntegerAttr::get(i32, nextRegister)); - - // Track how many registers this alloca uses. - op->setAttr("water.vgpr_count", IntegerAttr::get(i32, regCount)); - - // Advance to next available register. - nextRegister += regCount; - } - - // Attach metadata to function with total register count. - func->setAttr("water.total_vgprs", IntegerAttr::get(i32, nextRegister)); - } -}; - -} // namespace diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir deleted file mode 100644 index 0458bc81ea..0000000000 --- a/water/test/Transforms/lower-memory-ops.mlir +++ /dev/null @@ -1,557 +0,0 @@ -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s --check-prefixes=CHECK,GFX9 -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx1200}))' | FileCheck %s --check-prefixes=CHECK,GFX12 - -// Test lowering of vector memory operations to AMDGPU global_load/store inline assembly - -// CHECK-LABEL: func.func @simple_function -func.func @simple_function(%arg0: f32) -> f32 { - // CHECK: return %arg0 - return %arg0 : f32 -} - -// CHECK-LABEL: func.func @vector_load -func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // CHECK: memref.extract_aligned_pointer_as_index - // CHECK: arith.index_cast - // CHECK: llvm.inttoptr - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: return - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @vector_store -func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { - // CHECK: memref.extract_aligned_pointer_as_index - // CHECK: arith.index_cast - // CHECK: llvm.inttoptr - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: return - return -} - -// CHECK-LABEL: func.func @vector_load_b32 -func.func @vector_load_b32(%memref: memref<1024xf32>, %offset: index) -> vector<1xf32> { - // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<1xf32> - return %result : vector<1xf32> -} - -// CHECK-LABEL: func.func @vector_load_b64 -func.func @vector_load_b64(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> { - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32> - return %result : vector<2xf32> -} - -// 
CHECK-LABEL: func.func @vector_load_b96 -func.func @vector_load_b96(%memref: memref<1024xf32>, %offset: index) -> vector<3xf32> { - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx3 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v" - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<3xf32> - return %result : vector<3xf32> -} - -// CHECK-LABEL: func.func @vector_load_b128 -func.func @vector_load_b128(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @vector_store_b32 -func.func @vector_store_b32(%memref: memref<1024xf32>, %offset: index, %data: vector<1xf32>) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<1xf32> - return -} - -// CHECK-LABEL: func.func @vector_store_b64 -func.func @vector_store_b64(%memref: memref<1024xf32>, %offset: index, %data: vector<2xf32>) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<2xf32> - return -} - -// CHECK-LABEL: func.func @vector_store_b96 -func.func @vector_store_b96(%memref: memref<1024xf32>, %offset: index, %data: vector<3xf32>) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx3 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v" - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<3xf32> - return -} - -// CHECK-LABEL: func.func @vector_store_b128 -func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - return -} - -// CHECK-LABEL: func.func @load_store_sequence -func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { - // Test lowering of load/store sequence - - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> - - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: return - return -} - -// ----- -// Buffer operations tests - -// CHECK-LABEL: func.func @buffer_load_b32 -func.func @buffer_load_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<1xf32> { - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" - %result = vector.load %memref[%offset] : 
memref<1024xf32, #amdgpu.address_space>, vector<1xf32> - return %result : vector<1xf32> -} - -// CHECK-LABEL: func.func @buffer_load_b64 -func.func @buffer_load_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<2xf32> { - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b64 $0, $1, $2, 0 offen", "=v,v,s" - %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> - return %result : vector<2xf32> -} - -// CHECK-LABEL: func.func @buffer_load_b96 -func.func @buffer_load_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<3xf32> { - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b96 $0, $1, $2, 0 offen", "=v,v,s" - %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> - return %result : vector<3xf32> -} - -// CHECK-LABEL: func.func @buffer_load_b128 -func.func @buffer_load_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<4xf32> { - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" - %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @buffer_store_b32 -func.func @buffer_store_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<1xf32>) { - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" - vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> - return -} - -// CHECK-LABEL: func.func @buffer_store_b64 -func.func @buffer_store_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<2xf32>) { - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b64 $0, $1, $2, 0 offen", "v,v,s" - vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> - return -} - -// CHECK-LABEL: func.func @buffer_store_b96 -func.func @buffer_store_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<3xf32>) { - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b96 $0, $1, $2, 0 offen", "v,v,s" - vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> - return -} - -// CHECK-LABEL: func.func @buffer_store_b128 -func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<4xf32>) { - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" - vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> - return -} - -// CHECK-LABEL: func.func @mixed_global_and_buffer -func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<1024xf32, 
#amdgpu.address_space>, %offset: index) { - // Load from global memory (should use global_load) - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to buffer memory (should use buffer_store) - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" - vector.store %global_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> - - // Load from buffer memory (should use buffer_load) - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" - %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> - - // Store to global memory (should use global_store) - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %buffer_data, %global[%offset] : memref<1024xf32>, vector<4xf32> - - return -} -// ----- -// DS operations tests - -// CHECK-LABEL: func.func @ds_load_b32 -func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<1xf32> { - // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" - %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> - return %result : vector<1xf32> -} - -// CHECK-LABEL: func.func @ds_load_b64 -func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<2xf32> { - // CHECK: llvm.inline_asm has_side_effects "ds_read_b64 $0, $1", "=v,v" - %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> - return %result : vector<2xf32> -} - -// CHECK-LABEL: func.func @ds_load_b96 -func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<3xf32> { - // CHECK: llvm.inline_asm has_side_effects "ds_read_b96 $0, $1", "=v,v" - %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> - return %result : vector<3xf32> -} - -// CHECK-LABEL: func.func @ds_load_b128 -func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<4xf32> { - // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" - %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @ds_store_b32 -func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<1xf32>) { - // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> - return -} - -// CHECK-LABEL: func.func @ds_store_b64 -func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<2xf32>) { - // CHECK: llvm.inline_asm has_side_effects "ds_write_b64 $0, $1", "v,v" - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> - return -} - -// CHECK-LABEL: func.func @ds_store_b96 -func.func @ds_store_b96(%lds: memref<1024xf32, 
#gpu.address_space>, %offset: index, %data: vector<3xf32>) { - // CHECK: llvm.inline_asm has_side_effects "ds_write_b96 $0, $1", "v,v" - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> - return -} - -// CHECK-LABEL: func.func @ds_store_b128 -func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<4xf32>) { - // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - return -} - -// CHECK-LABEL: func.func @mixed_global_buffer_and_ds -func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %lds: memref<1024xf32, #gpu.address_space>, %offset: index) { - // Load from global (should use global_load) - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to LDS (should use ds_write) - // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" - vector.store %global_data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Load from LDS (should use ds_read) - // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" - %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Store to buffer (should use buffer_store) - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" - vector.store %lds_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> - - return -} - -// ----- -// Scalar (memref) operations tests - -// CHECK-LABEL: func.func @scalar_load_global_f32 -func.func @scalar_load_global_f32(%memref: memref<1024xf32>, %offset: index) -> f32 { - // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" - %result = memref.load %memref[%offset] : memref<1024xf32> - return %result : f32 -} - -// CHECK-LABEL: func.func @scalar_load_global_f64 -func.func @scalar_load_global_f64(%memref: memref<1024xf64>, %offset: index) -> f64 { - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" - %result = memref.load %memref[%offset] : memref<1024xf64> - return %result : f64 -} - -// CHECK-LABEL: func.func @scalar_store_global_f32 -func.func @scalar_store_global_f32(%memref: memref<1024xf32>, %offset: index, %data: f32) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" - memref.store %data, %memref[%offset] : memref<1024xf32> - return -} - -// CHECK-LABEL: func.func @scalar_store_global_f64 -func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %data: f64) { - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" - memref.store %data, %memref[%offset] : memref<1024xf64> - return -} - -// CHECK-LABEL: func.func @scalar_load_buffer_f32 -func.func 
@scalar_load_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> f32 { - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" - %result = memref.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> - return %result : f32 -} - -// CHECK-LABEL: func.func @scalar_store_buffer_f32 -func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: f32) { - // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" - memref.store %data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> - return -} - -// CHECK-LABEL: func.func @scalar_load_ds_f32 -func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> f32 { - // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" - %result = memref.load %lds[%offset] : memref<1024xf32, #gpu.address_space> - return %result : f32 -} - -// CHECK-LABEL: func.func @scalar_store_ds_f32 -func.func @scalar_store_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: f32) { - // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" - memref.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space> - return -} - -// CHECK-LABEL: func.func @mixed_scalar_and_vector -func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { - // Scalar load - // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" - %scalar = memref.load %memref[%offset] : memref<1024xf32> - - // Vector load - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" - %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Scalar store - // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" - memref.store %scalar, %memref[%offset] : memref<1024xf32> - - // Vector store - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %vector, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - return -} - -// Test copy to register space with pre-numbered allocas - -// CHECK-LABEL: func.func @copy_global_to_reg_scalar -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] -func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} { - %c0 = arith.constant 0 : index - %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> - %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" - memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // 
GFX9: llvm.inline_asm "; reg_load v255", "={v255}" - // GFX12: llvm.inline_asm "; reg_load v255", "={v255}" - %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> - // CHECK-NOT: memref.alloca - return %val : f32 -} - -// CHECK-LABEL: func.func @copy_global_to_reg_vector -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { - %c0 = arith.constant 0 : index - %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" - memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK-NOT: memref.alloca - return %val : vector<4xf32> -} - -// CHECK-LABEL: func.func @copy_buffer_to_reg -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { - %c0 = arith.constant 0 : index - %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> - // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[252:255]},v,s" - memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK-NOT: memref.alloca - return %val : vector<4xf32> -} - -// CHECK-LABEL: func.func @copy_workgroup_to_reg -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { - %c0 = arith.constant 0 : index - %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> - // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - // GFX12: 
llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK-NOT: memref.alloca - return %val : vector<4xf32> -} - -// CHECK-LABEL: func.func @store_to_reg -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] -func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i32} { - %c0 = arith.constant 0 : index - %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0" - // GFX12: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0" - memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v255", "={v255}" - // GFX12: llvm.inline_asm "; reg_load v255", "={v255}" - %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32> - // CHECK-NOT: memref.alloca - return %result : f32 -} - -// CHECK-LABEL: func.func @multiple_reg_allocas -// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] -func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} { - %c0 = arith.constant 0 : index - %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> - %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v247},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" - %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[248:251]},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" - %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> - memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> - memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v247", "={v247}" - // GFX12: llvm.inline_asm "; reg_load v247", "={v247}" - %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}" - // GFX12: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}" - %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" - %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, 
vector<4xf32> - // CHECK-NOT: memref.alloca - return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32> -} - -// ----- -// Test MFMA hazard handling with s_nop insertion - -// CHECK-LABEL: func.func @mfma_hazard_store -func.func @mfma_hazard_store(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { - %offset = arith.constant 0 : index - - // Perform MFMA operation - %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - - // Store MFMA result - should trigger hazard handling - // CHECK: rocdl.sched.barrier - // CHECK: arith.constant 4 : i16 - // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32> - - return -} - -// CHECK-LABEL: func.func @mfma_hazard_with_extract -func.func @mfma_hazard_with_extract(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { - %offset = arith.constant 0 : index - - // MFMA with vector extract - hazard checking should propagate through extract - %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - - %extracted = vector.extract %result[0] : f32 from vector<4xf32> - - // Store extracted value - should still detect hazard through propagation - // CHECK: rocdl.sched.barrier - // CHECK: arith.constant 4 : i16 - // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" - // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" - memref.store %extracted, %arg0[%offset] : memref<1024xf32> - - return -} - -// CHECK-LABEL: func.func @no_hazard_with_existing_nop -func.func @no_hazard_with_existing_nop(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { - %offset = arith.constant 0 : index - - %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - - // Manually insert s.nop - %nop_count = arith.constant 4 : i16 - llvm.call_intrinsic "llvm.amdgcn.s.nop"(%nop_count) : (i16) -> () - - // Store should NOT insert another s.nop since one already exists - // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" - // CHECK-NOT: rocdl.sched.barrier - // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" - // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" - vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32> - - return -} diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir deleted file mode 100644 index 7f789e8c9e..0000000000 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ /dev/null @@ -1,165 +0,0 @@ -// RUN: water-opt %s --water-materialize-reg-copy | FileCheck %s - -// CHECK-LABEL: func @test_simple_load -func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f32 { - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] - // CHECK-SAME: memref<10x20xf32> to memref<1x1xf32, strided<[20, 1], offset: ?>> - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1xf32, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = memref.load 
%[[TEMP]][%[[C0]], %[[C0]]] - // CHECK: return %[[RESULT]] - %0 = memref.load %arg0[%i, %j] : memref<10x20xf32> - return %0 : f32 -} - -// CHECK-LABEL: func @test_simple_vector_load -func.func @test_simple_vector_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> vector<4xf32> { - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 4] [1, 1] - // CHECK-SAME: memref<10x20xf32> to memref<1x4xf32, strided<[20, 1], offset: ?>> - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x4xf32, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0]]] - // CHECK: return %[[RESULT]] - %0 = vector.load %arg0[%i, %j] : memref<10x20xf32>, vector<4xf32> - return %0 : vector<4xf32> -} - -// CHECK-LABEL: func @test_1d_load -func.func @test_1d_load(%arg0: memref<100xf16>, %i: index) -> f16 { - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1] [1] [1] - // CHECK-SAME: memref<100xf16> to memref<1xf16, strided<[1], offset: ?>> - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf16, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]]] - // CHECK: return %[[RESULT]] - %0 = memref.load %arg0[%i] : memref<100xf16> - return %0 : f16 -} - -// CHECK-LABEL: func @test_3d_load -func.func @test_3d_load(%arg0: memref<8x16x32xi32>, %i: index, %j: index, %k: index) -> i32 { - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] - // CHECK-SAME: memref<8x16x32xi32> to memref<1x1x1xi32, strided<[512, 32, 1], offset: ?>> - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1x1xi32, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]], %[[C0]]] - // CHECK: return %[[RESULT]] - %0 = memref.load %arg0[%i, %j, %k] : memref<8x16x32xi32> - return %0 : i32 -} - -// CHECK-LABEL: func @test_multiple_loads -func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) -> f32 { - // First load: subview, alloca, copy - // CHECK: memref.subview - // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> - // CHECK: memref.copy - %0 = memref.load %arg0[%i, %j] : memref<10x10xf32> - - // Second load: subview, alloca, copy - // CHECK: memref.subview - // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> - // CHECK: memref.copy - %1 = memref.load %arg0[%j, %i] : memref<10x10xf32> - - // Now the actual loads happen right before the addf (late as possible) - // CHECK: memref.load - // CHECK: memref.load - // CHECK: arith.addf - %2 = arith.addf %0, %1 : f32 - return %2 : f32 -} - -// CHECK-LABEL: func @test_skip_memspace_128 -func.func @test_skip_memspace_128(%arg0: memref<10xf32>, %arg1: memref<5xf32, 128 : i32>, %i: index) -> f32 { - // This load should be transformed (from default memspace) - // First: subview, alloca, copy - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - %0 = memref.load %arg0[%i] : memref<10xf32> - - // This load should NOT be transformed (already from memspace 128) - // It stays in place - // CHECK: %[[VAL1:.*]] = memref.load %arg1[%arg2] : memref<5xf32, 128 : i32> - %1 = memref.load %arg1[%i] : 
memref<5xf32, 128 : i32> - - // The load from temp happens late (right before addf) - // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]] - // Note: operands may be reordered - // CHECK: arith.addf - %result = arith.addf %0, %1 : f32 - // CHECK: return - return %result : f32 -} - -// CHECK-LABEL: func @test_control_flow -func.func @test_control_flow(%arg0: memref<10xf32>, %cond: i1, %i: index) -> f32 { - // Load happens once, but value is used in multiple blocks - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] - // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> - // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] - // CHECK: %[[C0:.*]] = arith.constant 0 : index - %val = memref.load %arg0[%i] : memref<10xf32> - - // CHECK: cf.cond_br - cf.cond_br %cond, ^bb1, ^bb2 - -^bb1: - // First block: load happens here before the addf - // CHECK: ^bb1: - // CHECK: %[[CONST1:.*]] = arith.constant 1.0 - // CHECK: %[[LOAD1:.*]] = memref.load %[[TEMP]][%[[C0]]] - // CHECK: %[[ADD1:.*]] = arith.addf %[[LOAD1]], %[[CONST1]] - %c1 = arith.constant 1.0 : f32 - %sum1 = arith.addf %val, %c1 : f32 - // CHECK: cf.br ^bb3(%[[ADD1]] - cf.br ^bb3(%sum1 : f32) - -^bb2: - // Second block: another load happens here before the mulf - // CHECK: ^bb2: - // CHECK: %[[CONST2:.*]] = arith.constant 2.0 - // CHECK: %[[LOAD2:.*]] = memref.load %[[TEMP]][%[[C0]]] - // CHECK: %[[MUL:.*]] = arith.mulf %[[LOAD2]], %[[CONST2]] - %c2 = arith.constant 2.0 : f32 - %prod = arith.mulf %val, %c2 : f32 - // CHECK: cf.br ^bb3(%[[MUL]] - cf.br ^bb3(%prod : f32) - -^bb3(%result: f32): - // CHECK: ^bb3(%[[RESULT:.*]]: f32): - // CHECK: return %[[RESULT]] - return %result : f32 -} - -// CHECK-LABEL: func @test_loop_hoist -func.func @test_loop_hoist(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index, %init: f32) -> f32 { - %c0 = arith.constant 0 : index - // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<1xf32, 128 : i32> - // CHECK: arith.constant 0 : index - // CHECK: memref.store %arg4, %[[ALLOCA]] - // CHECK: scf.for %[[IV:.*]] = %arg1 to %arg2 step %arg3 iter_args(%[[ITER_ARG:.*]] = %arg4) - %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) { - // CHECK: memref.load %[[ALLOCA]] - // CHECK: memref.store %{{.*}}, %arg0[%c0] - memref.store %arg, %arg0[%c0] : memref<100xf32> - %alloca = memref.alloca() : memref<1xf32, 128 : i32> - %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - memref.copy %subview, %alloca : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32> - // CHECK: memref.subview - // CHECK: memref.copy - // CHECK: memref.load %[[ALLOCA]] - // CHECK: scf.yield - scf.yield %val : f32 - } - // CHECK: memref.load %[[ALLOCA]] - // CHECK: return - return %result : f32 -} diff --git a/water/test/Transforms/number-registers-error.mlir b/water/test/Transforms/number-registers-error.mlir deleted file mode 100644 index 4d9f77435f..0000000000 --- a/water/test/Transforms/number-registers-error.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' --verify-diagnostics - -func.func @test_dynamic_size_error(%n: index) { - // expected-error @+1 {{Cannot allocate dynamic-sized memref in register space}} - %reg = memref.alloca(%n) : memref - return -} diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir deleted file mode 100644 
index ccc32447bb..0000000000 --- a/water/test/Transforms/number-registers.mlir +++ /dev/null @@ -1,80 +0,0 @@ -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s - -// CHECK-LABEL: func @test_simple_numbering -// CHECK-SAME: attributes {water.total_vgprs = 8 : i32} -func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 { - %c0 = arith.constant 0 : index - - // 1xf32 = 4 bytes = 1 register, starts at reg 0 - // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} - %reg0 = memref.alloca() : memref<1xf32, 128 : i32> - - // 4xf32 = 16 bytes = 4 registers, starts at reg 4 - // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32} - %reg1 = memref.alloca() : memref<4xf32, 128 : i32> - - // 1xf32 = 4 bytes = 1 register, starts at reg 1 (after reg0) - // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 1 : i32} - %reg2 = memref.alloca() : memref<1xf32, 128 : i32> - - %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - memref.copy %subview0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - - %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> - - return %val0 : f32 -} - -// CHECK-LABEL: func @test_loop_with_registers -// CHECK-SAME: attributes {water.total_vgprs = 1 : i32} -func.func @test_loop_with_registers(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index) { - %c0 = arith.constant 0 : index - - // Register allocated outside loop - // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} - %reg = memref.alloca() : memref<1xf32, 128 : i32> - - scf.for %iv = %lb to %ub step %step { - %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> - memref.store %val, %arg0[%iv] : memref<100xf32> - } - - return -} - -// CHECK-LABEL: func @test_triple_buffering_numbering -// CHECK-SAME: attributes {water.total_vgprs = 12 : i32} -func.func @test_triple_buffering_numbering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { - %c0 = arith.constant 0 : index - - // Three registers for triple buffering, each 4xf32 = 4 registers - // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 0 : i32} - %reg0 = memref.alloca() : memref<4xf32, 128 : i32> - - // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32} - %reg1 = memref.alloca() : memref<4xf32, 128 : i32> - - // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 8 : i32} - %reg2 = memref.alloca() : memref<4xf32, 128 : i32> - - return -} - -// CHECK-LABEL: func @test_mixed_memspaces -// CHECK-SAME: attributes {water.total_vgprs = 1 : i32} -func.func @test_mixed_memspaces(%arg0: memref<100xf32>) { - %c0 = arith.constant 0 : index - - // Non-register space alloca - should not be numbered - // CHECK: memref.alloca() : memref<10xf32> - // CHECK-NOT: water.vgpr_number - %local = memref.alloca() : memref<10xf32> - - // Register space alloca - should be numbered - // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} - %reg = memref.alloca() : memref<1xf32, 128 : i32> - - return -} From ea23eb6be895258a606679b5fd5bd3bc88d0badc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 
9 Jan 2026 21:29:39 +0100 Subject: [PATCH 03/38] del ssa dep Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 55 --------------------- 1 file changed, 55 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 8e9eb5fbb9..dceb970e78 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -464,54 +464,6 @@ class WaitcntState : public AbstractDenseLattice { /// Check if there's a waitcnt requirement bool hasRequirement() const { return requirement.hasRequirement(); } - /// Check if a value depends on pending operations and compute required wait - WaitcntRequirement - checkSSADependency(Value val, - llvm::SmallSetVector &barriers) const { - // Check if val is produced by any pending operation - Operation *defOp = val.getDefiningOp(); - if (!defOp) - return {}; - - if (!isPendingOp(defOp)) - return {}; - - WaitcntRequirement result; - for (auto &pendingOps : pendingOpsLists) { - if (pendingOps->empty()) - continue; - - Operation *barrier = nullptr; - - // Search from the back to find the most recent dependency - bool found = false; - auto req = WaitcntRequirement::getOperationRequirement(defOp, true); - for (Operation *op : llvm::reverse(pendingOps->ops)) { - if (op == defOp) { - found = true; - break; - } - - if (!barrier && isBarrier(op)) - barrier = op; - - auto opReq = WaitcntRequirement::getOperationRequirement(op, false); - if (!req.isSameCounterType(opReq)) - continue; - - req = req + opReq; - } - - if (found) { - result.merge(req); - if (barrier) - barriers.insert(barrier); - } - } - - return result; - } - /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait WaitcntRequirement checkMemoryDependency(Operation *op, @@ -691,14 +643,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { llvm::SmallSetVector barriers; - // Check if any operands depend on pending operations (value dependency) WaitcntRequirement opRequirement = after->getRequirement(); - for (Value operand : op->getOperands()) { - if (auto req = before.checkSSADependency(operand, barriers)) { - // Merge this requirement (take minimum for conservative wait) - opRequirement.merge(req); - } - } // Check for memory dependencies (RAW, WAR, WAW) if (auto memReq = before.checkMemoryDependency(op, barriers)) { From e8a23942ff743b1dd5522a05d9a81cfd173c493d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 21:50:54 +0100 Subject: [PATCH 04/38] tensor lod Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 35 ++++++++++----------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index dceb970e78..f6dfbff580 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -68,32 +68,31 @@ static std::optional propagateViewOps(Value value) { return value; } +static std::optional propagateTensorDesc(Value value, bool isLoad) { + auto makeDesc = value.getDefiningOp(); + if (!makeDesc) + return value; + + value = makeDesc.getBase(); + auto makeBase = value.getDefiningOp(); + if (!makeBase) + return value; + + return propagateViewOps(isLoad ? makeBase.getGlobal() : makeBase.getLds()); +} + /// Check if the operation is a load operation and return the base memref. static std::optional isLoadOp(Operation *op) { - // TODO: replace with the interface when available. 
- if (auto load = dyn_cast(op)) - return propagateViewOps(load.getBase()); - if (auto load = dyn_cast(op)) - return propagateViewOps(load.getMemRef()); - if (auto copy = dyn_cast(op)) - return propagateViewOps(copy.getSource()); - if (auto gather = dyn_cast(op)) - return propagateViewOps(gather.getSrc()); + if (auto load = dyn_cast(op)) + return propagateTensorDesc(load.getDesc(), true); return std::nullopt; } /// Check if the operation is a store operation and return the base memref. static std::optional isStoreOp(Operation *op) { - // TODO: replace with the interface when available. - if (auto store = dyn_cast(op)) - return propagateViewOps(store.getBase()); - if (auto store = dyn_cast(op)) - return propagateViewOps(store.getMemRef()); - if (auto copy = dyn_cast(op)) - return propagateViewOps(copy.getTarget()); - if (auto gather = dyn_cast(op)) - return propagateViewOps(gather.getDst()); + if (auto store = dyn_cast(op)) + return propagateTensorDesc(store.getDesc(), false); return std::nullopt; } From 48215d93d774531b7305e5e9f8a83f045472870d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 21:58:00 +0100 Subject: [PATCH 05/38] del Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 95 +++++---------------- 1 file changed, 22 insertions(+), 73 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index f6dfbff580..1c9123aff3 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -33,33 +33,6 @@ static bool isBarrier(Operation *op) { return isa(op) || isa(op); } -static bool isRegisterAddressSpace(MemRefType type) { - auto attr = dyn_cast_or_null(type.getMemorySpace()); - return attr && attr.getInt() == 128; -} - -static bool isWorkgroupAddressSpace(MemRefType type) { - auto attr = dyn_cast_or_null(type.getMemorySpace()); - return attr && attr.getValue() == gpu::AddressSpace::Workgroup; -} - -static bool isWorkgroupAddressSpace(std::optional value) { - if (!value) - return false; - - auto memrefType = cast(value->getType()); - return isWorkgroupAddressSpace(memrefType); -} - -static bool isGlobalAddressSpace(std::optional value) { - if (!value) - return false; - - auto memrefType = cast(value->getType()); - return !isWorkgroupAddressSpace(memrefType) && - !isRegisterAddressSpace(memrefType); -} - /// Try to propagate view operations to the base memref. 
static std::optional propagateViewOps(Value value) { while (auto view = value.getDefiningOp()) @@ -192,21 +165,16 @@ struct PendingOperations { /// Waitcnt requirement for synchronization struct WaitcntRequirement { - std::optional load_cnt; - std::optional ds_cnt; + std::optional tensor_cnt; WaitcntRequirement() = default; WaitcntRequirement(amdgpu::MemoryCounterWaitOp waitOp) { - if (auto loadCnt = waitOp.getLoadAttr()) - load_cnt = loadCnt.getInt(); - if (auto dsCnt = waitOp.getDsAttr()) - ds_cnt = dsCnt.getInt(); + if (auto tensorCnt = waitOp.getTensorAttr()) + tensor_cnt = tensorCnt.getInt(); } - bool hasRequirement() const { - return load_cnt.has_value() || ds_cnt.has_value(); - } + bool hasRequirement() const { return tensor_cnt.has_value(); } /// Merge with another requirement (take minimum for conservative join) /// Returns true if this requirement changed @@ -214,15 +182,9 @@ struct WaitcntRequirement { bool changed = false; // Take minimum of each counter (lower value = more restrictive) - if (other.load_cnt.has_value()) { - if (!load_cnt.has_value() || *other.load_cnt < *load_cnt) { - load_cnt = other.load_cnt; - changed = true; - } - } - if (other.ds_cnt.has_value()) { - if (!ds_cnt.has_value() || *other.ds_cnt < *ds_cnt) { - ds_cnt = other.ds_cnt; + if (other.tensor_cnt.has_value()) { + if (!tensor_cnt.has_value() || *other.tensor_cnt < *tensor_cnt) { + tensor_cnt = other.tensor_cnt; changed = true; } } @@ -230,49 +192,36 @@ struct WaitcntRequirement { return changed; } - std::optional getLoadCnt() const { return load_cnt; } - std::optional getStoreCnt() const { return std::nullopt; } - std::optional getDsCnt() const { return ds_cnt; } + std::optional getTensorCnt() const { return tensor_cnt; } bool isSameCounterType(const WaitcntRequirement &other) const { - return load_cnt.has_value() == other.load_cnt.has_value() || - ds_cnt.has_value() == other.ds_cnt.has_value(); + return tensor_cnt.has_value() == other.tensor_cnt.has_value(); } static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) { WaitcntRequirement req; - std::optional loadBase = isLoadOp(op); - std::optional storeBase = isStoreOp(op); - if (isWorkgroupAddressSpace(loadBase) || - isWorkgroupAddressSpace(storeBase)) { - req.ds_cnt = zero ? 0 : 1; - } else if (isGlobalAddressSpace(loadBase) || - isGlobalAddressSpace(storeBase)) { - req.load_cnt = zero ? 0 : 1; - } + if (isa(op)) + req.tensor_cnt = zero ? 
0 : 1; + return req; } WaitcntRequirement operator+(const WaitcntRequirement &other) const { WaitcntRequirement result; - if (load_cnt || other.load_cnt) - result.load_cnt = load_cnt.value_or(0) + other.load_cnt.value_or(0); - if (ds_cnt || other.ds_cnt) - result.ds_cnt = ds_cnt.value_or(0) + other.ds_cnt.value_or(0); + if (tensor_cnt || other.tensor_cnt) + result.tensor_cnt = tensor_cnt.value_or(0) + other.tensor_cnt.value_or(0); return result; } bool operator>(const WaitcntRequirement &other) const { - if (load_cnt && other.load_cnt && *load_cnt > *other.load_cnt) - return true; - if (ds_cnt && other.ds_cnt && *ds_cnt > *other.ds_cnt) + if (tensor_cnt && other.tensor_cnt && *tensor_cnt > *other.tensor_cnt) return true; return false; } operator bool() const { return hasRequirement(); } void print(raw_ostream &os) const { - os << "WaitcntRequirement: load_cnt=" << load_cnt << " ds_cnt=" << ds_cnt; + os << "WaitcntRequirement: tensor_cnt=" << tensor_cnt; } }; @@ -283,8 +232,9 @@ inline raw_ostream &operator<<(raw_ostream &os, } static bool mayAlias(Value lhs, Value rhs, ArrayRef tokens) { - if (isWorkgroupAddressSpace(cast(lhs.getType())) != - isWorkgroupAddressSpace(cast(rhs.getType()))) + auto memref1 = cast(lhs.getType()); + auto memref2 = cast(rhs.getType()); + if (memref1.getMemorySpace() != memref2.getMemorySpace()) return false; return llvm::is_contained(tokens, lhs); @@ -803,10 +753,9 @@ class WaterInsertWaitcntPass // If the current operation is already a memory_counter_wait operation // they will be merged later. rewriter.setInsertionPoint(operation); - amdgpu::MemoryCounterWaitOp::create( - rewriter, operation->getLoc(), getAttr(req.getLoadCnt()), - getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr, - nullptr); + amdgpu::MemoryCounterWaitOp::create(rewriter, operation->getLoc(), + nullptr, nullptr, nullptr, nullptr, + getAttr(req.getTensorCnt())); }); } }; From 0eb692a02b97489c555a8060b7e1567f0a26cb4b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 22:08:50 +0100 Subject: [PATCH 06/38] del test Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 623 ---------------------- 1 file changed, 623 deletions(-) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index a8f82fce8f..987743c2a2 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -1,624 +1 @@ // RUN: water-opt %s --water-insert-waitcnt | FileCheck %s - -// CHECK-LABEL: func.func @single_load_use -func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // CHECK: vector.load - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @two_loads_use_in_reverse_order -// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) -func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] - // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] - %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> - %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(1) - // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] - %addA = arith.addf %loadA, %loadA : 
vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] - %addB = arith.addf %loadB, %addA : vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - - // CHECK: return %[[ADD_B]] - return %addB : vector<4xf32> -} - -// CHECK-LABEL: func.func @lds_barriers -// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) -func.func @lds_barriers(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] - // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] - %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> - %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(1) - // CHECK-NEXT: amdgpu.lds_barrier - // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] - amdgpu.lds_barrier - %addA = arith.addf %loadA, %loadA : vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: amdgpu.lds_barrier - // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] - amdgpu.lds_barrier - %addB = arith.addf %loadB, %addA : vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - - // CHECK: return %[[ADD_B]] - return %addB : vector<4xf32> -} - -// CHECK-LABEL: func.func @raw_dependency -// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) -func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Store to memory - // CHECK: vector.store %[[DATA]], %[[MEM]] - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Load from same memory - RAW dependency, must wait for store - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]] - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @raw_dependency_memref -// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: f32, %{{.*}}: index) -func.func @raw_dependency_memref(%memref: memref<1024xf32>, %data: f32, %offset: index) -> f32 { - // Store to memory - // CHECK: memref.store %[[DATA]], %[[MEM]] - memref.store %data, %memref[%offset] : memref<1024xf32> - - // Load from same memory - RAW dependency, must wait for store - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]] - %result = memref.load %memref[%offset] : memref<1024xf32> - - // CHECK: return %[[LOAD]] - return %result : f32 -} - -// CHECK-LABEL: func.func @war_dependency -// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) -func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Load from memory - // CHECK: %[[LOAD:.*]] = vector.load %[[MEM]] - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to same memory - WAR dependency, must wait for load - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: vector.store %[[DATA]], %[[MEM]] - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @waw_dependency -// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA1:.*]]: 
vector<4xf32>, %[[DATA2:.*]]: vector<4xf32>, %{{.*}}: index) -func.func @waw_dependency(%memref: memref<1024xf32>, %data1: vector<4xf32>, %data2: vector<4xf32>, %offset: index) { - // First store - // CHECK: vector.store %[[DATA1]], %[[MEM]] - vector.store %data1, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Second store to same memory - WAW dependency, must wait for first store - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: vector.store %[[DATA2]], %[[MEM]] - vector.store %data2, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: return - return -} - -// CHECK-LABEL: func.func @raw_dependency_non_zero_waitcnt -func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Allocate two distinct memrefs to guarantee no aliasing - // CHECK: %[[MEM_A:.*]] = memref.alloc() - %memrefA = memref.alloc() : memref<1024xf32> - // CHECK: %[[MEM_B:.*]] = memref.alloc() - %memrefB = memref.alloc() : memref<1024xf32> - - // Store to memory A - // CHECK: vector.store %{{.*}}, %[[MEM_A]] - vector.store %data, %memrefA[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to memory B (intervening operation, different memref) - // CHECK: vector.store %{{.*}}, %[[MEM_B]] - vector.store %data, %memrefB[%offset] : memref<1024xf32>, vector<4xf32> - - // Load from memory A - RAW dependency with store to A at distance 1 - // CHECK: amdgpu.memory_counter_wait load(1) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM_A]] - %result = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @workgroup_memory_raw -func.func @workgroup_memory_raw(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Allocate workgroup (LDS) memory - // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> - %lds = memref.alloc() : memref<1024xf32, #gpu.address_space> - - // Store to LDS - // CHECK: vector.store %{{.*}}, %[[LDS]] - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Load from LDS - RAW dependency, should use dsCnt not loadCnt - // CHECK: amdgpu.memory_counter_wait ds(0) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[LDS]] - %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait ds(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @mixed_global_and_workgroup -// CHECK-SAME: (%[[GLOBAL:.*]]: memref<1024xf32>, %[[LDS:.*]]: memref<1024xf32, #gpu.address_space>, %{{.*}}: vector<4xf32>, %{{.*}}: index) -func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Store to global memory - // CHECK: vector.store %{{.*}}, %[[GLOBAL]] - vector.store %data, %global[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to LDS (different counter, no dependency) - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: vector.store %{{.*}}, %[[LDS]] - vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Load from global - RAW dependency with global store at distance 0 - // (LDS store doesn't count because it's a different counter type) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[GLOBAL]] - %result = 
vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @existing_waitcnt -func.func @existing_waitcnt(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Store to memory - // CHECK: vector.store - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Existing wait operation - should clear pending operations - // CHECK: amdgpu.memory_counter_wait load(0) - amdgpu.memory_counter_wait load(0) - - // Another store after the wait - // CHECK: vector.store - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Load requires wait for the second store only (first was already waited on) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load - %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @existing_waitcnt_more_strict -func.func @existing_waitcnt_more_strict(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { - %memref1 = memref.alloc() : memref<1024xf32> - %memref2 = memref.alloc() : memref<1024xf32> - - // Store to memory - // CHECK: vector.store - // CHECK: vector.store - vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> - vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> - - // Existing wait operation - should clear pending operations - // Normally, the distance will be 1, but explicit amdgpu.memory_counter_wait - // overrides it. - // CHECK-NOT: amdgpu.memory_counter_wait load(1) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NOT: amdgpu.memory_counter_wait load(1) - amdgpu.memory_counter_wait load(0) - - // CHECK: %[[LOAD:.*]] = vector.load - %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - - -// CHECK-LABEL: func.func @control_flow_merge -func.func @control_flow_merge(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - %memref1 = memref.alloc() : memref<1024xf32> - %memref2 = memref.alloc() : memref<1024xf32> - - // Common operation before branching - // CHECK: vector.store - vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: cf.cond_br - cf.cond_br %cond, ^bb1, ^bb2 - -^bb1: - // Extra operation in this path - // CHECK: vector.store - vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: cf.br - cf.br ^bb3 - -^bb2: - // No extra operations, just branch to merge point - // CHECK: cf.br - cf.br ^bb3 - -^bb3: - // bb1 branch has distance 1 but bb2 has distance 0, so we need to conservatively - // take 0 - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load - %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @control_flow_merge_same_lists -func.func @control_flow_merge_same_lists(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - %memref1 = memref.alloc() : memref<1024xf32> - %memref2 = memref.alloc() : memref<1024xf32> - - // Common operation 
before branching - // CHECK: vector.store - vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: cf.cond_br - cf.cond_br %cond, ^bb1, ^bb2 - -^bb1: - // CHECK: vector.store - vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: cf.br - cf.br ^bb3 - -^bb2: - vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: cf.br - cf.br ^bb3 - -^bb3: - // both branches has the same distance 1 - // CHECK: amdgpu.memory_counter_wait load(1) - // CHECK-NEXT: %[[LOAD:.*]] = vector.load - %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[LOAD]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @loop_carried_dependency -func.func @loop_carried_dependency(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // CHECK: scf.for - %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { - // Store in each iteration - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: vector.store - vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Load in the same iteration - RAW dependency with store from this iteration - // In steady state, the backedge brings pending operations from previous iteration - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOADED:.*]] = vector.load - %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Yield uses the load result, which is async, so need to wait for it - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: scf.yield %[[LOADED]] - scf.yield %loaded : vector<4xf32> - } - - // CHECK: return - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @loop_load_before_store -func.func @loop_load_before_store(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // CHECK: scf.for - %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { - // Load first - in steady state, has RAW dependency with store from previous iteration - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[LOADED:.*]] = vector.load - %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Store after load - WAR dependency with load in same iteration - // The wait for the load clears it from pending, so this wait is for the load - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: vector.store - vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> - - // Yield uses load result - load was already waited on by the store, no additional wait needed - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: scf.yield %[[LOADED]] - scf.yield %loaded : vector<4xf32> - } - - // CHECK: return - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @memref_copy_raw_source -func.func @memref_copy_raw_source(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { - // Store to source - // CHECK: vector.store - vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> - - // Copy from source - RAW dependency (reads from source that was just written) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: memref.copy - memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> - - // CHECK: return - return -} - -// 
CHECK-LABEL: func.func @memref_copy_waw_target -func.func @memref_copy_waw_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { - // Store to destination - // CHECK: vector.store - vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // Copy to destination - WAW dependency (writes to target that was just written) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: memref.copy - memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> - - // CHECK: return - return -} - -// CHECK-LABEL: func.func @memref_copy_war_target -func.func @memref_copy_war_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // Load from destination - // CHECK: %[[RESULT:.*]] = vector.load - %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // Copy to destination - WAR dependency (writes to target that was just read) - // The copy's wait also synchronizes the load, so return doesn't need another wait - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: memref.copy - memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> - - // CHECK: return %[[RESULT]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @memref_copy_both_dependencies -func.func @memref_copy_both_dependencies(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { - // Store to source - // CHECK: vector.store - vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> - - // Store to destination - // CHECK: vector.store - vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // Copy needs to wait for both stores: - // - RAW on source (copy reads from source) - // - WAW on target (copy writes to destination) - // Both stores alias with their respective memrefs, so we need wait(0) - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: memref.copy - memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> - - // Load from destination after copy - RAW dependency with copy - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: %[[RESULT:.*]] = vector.load - %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: return %[[RESULT]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func @gather_to_lds -func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %src_offset: index, %dst_offset: index) -> vector<4xf32> { - // Store to global memory - // CHECK: vector.store - vector.store %data, %global[%src_offset] : memref<1024xf32>, vector<4xf32> - - // Gather from global to LDS - has both RAW (reads from global) and acts as store to LDS - // Should wait for global store using load counter - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: amdgpu.gather_to_lds - amdgpu.gather_to_lds %global[%src_offset], %lds[%dst_offset] : f32, memref<1024xf32>, memref<1024xf32, #gpu.address_space> - - // Load from LDS - RAW dependency with gather writing to LDS - // Should wait for LDS operation using ds counter - // CHECK: amdgpu.memory_counter_wait ds(0) - // CHECK-NEXT: %[[RESULT:.*]] = vector.load - %result = vector.load %lds[%dst_offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait ds(0) - // CHECK-NEXT: return %[[RESULT]] - return %result : vector<4xf32> -} - -// CHECK-LABEL: func.func 
@double_buffering -func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { - %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> - - %out = memref.alloc() : memref<1024xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK: scf.for - scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // Skip the second buffer copy - // CHECK: amdgpu.memory_counter_wait ds(1) - // CHECK: vector.load - %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Cannot skip unfortunately - // CHECK: amdgpu.memory_counter_wait load(0) ds(0) - // CHECK: vector.store - vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: scf.yield - scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> - } - - // CHECK: return - return -} - -// CHECK-LABEL: func.func @triple_buffering -func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { - %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> - - %out = memref.alloc() : memref<1024xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK: scf.for - scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { - // Skip the second buffer copy - // CHECK: amdgpu.memory_counter_wait ds(1) - // CHECK: vector.load - %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // Skip the prev copy - // CHECK: amdgpu.memory_counter_wait load(0) ds(1) - // CHECK: vector.store - vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: scf.yield - scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> - } - - // CHECK: return - return -} - - -// CHECK-LABEL: func.func @triple_buffering_reg_space -func.func @triple_buffering_reg_space(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { - %c0 = arith.constant 0 : index - %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %reg = memref.alloca() : 
memref<4xf32, 128 : i32> - - %out = memref.alloc() : memref<1024xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK: scf.for - scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // Skip the the prev copy - // CHECK: amdgpu.memory_counter_wait ds(1) - // CHECK: vector.load - %data = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: vector.store - vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.subview - %subview = memref.subview %current[%offset] [4] [1] : memref<1024xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> - - // This copy only depends on buffer 2 iterations ago - // CHECK: amdgpu.memory_counter_wait ds(2) - // CHECK: memref.copy - memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> - - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: scf.yield - scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> - } - - // CHECK: return - return -} - -// CHECK-LABEL: func.func @load_store_repeated -func.func @load_store_repeated(%src0: memref<4xf32>, %src1: memref<4xf32>, %offset: index) { - %c0 = arith.constant 0 : index - %buff0 = memref.alloc() : memref<4xf32, #gpu.address_space> - %buff1 = memref.alloc() : memref<4xf32, #gpu.address_space> - %reg0 = memref.alloca() : memref<4xf32, 128 : i32> - %reg1 = memref.alloca() : memref<4xf32, 128 : i32> - %reg2 = memref.alloca() : memref<4xf32, 128 : i32> - %reg3 = memref.alloca() : memref<4xf32, 128 : i32> - - // CHECK-COUNT-4: memref.copy - memref.copy %src0, %reg0 : memref<4xf32> to memref<4xf32, 128 : i32> - memref.copy %src1, %reg1 : memref<4xf32> to memref<4xf32, 128 : i32> - - memref.copy %buff0, %reg2 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> - memref.copy %buff1, %reg3 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> - - // CHECK: amdgpu.memory_counter_wait load(1) - // CHECK-NEXT: vector.load - %data0 = vector.load %reg0[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK: amdgpu.memory_counter_wait load(0) - // CHECK-NEXT: vector.load - %data1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - - // CHECK: amdgpu.memory_counter_wait ds(1) - // CHECK-NEXT: vector.load - %data2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK: amdgpu.memory_counter_wait ds(0) - // CHECK-NEXT: vector.load - %data3 = vector.load %reg3[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - - return -} From 5520e643db0da506cd630565845f9345d2576ebb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 22:45:45 +0100 Subject: [PATCH 07/38] test Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 
3 ++- water/test/Transforms/insert-waitcnt.mlir | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 1c9123aff3..635c094518 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -450,9 +450,10 @@ class WaitcntState : public AbstractDenseLattice { // RAW: current load after pending store // WAR: current store after pending load // WAW: current store after pending store + // We don't care about WAW dependencies for now. bool hasRAW = isCurrentLoad && isPendingStore; bool hasWAR = isCurrentStore && isPendingLoad; - bool hasWAW = isCurrentStore && isPendingStore; + bool hasWAW = false; // isCurrentStore && isPendingStore; if (hasRAW || hasWAR || hasWAW) { // Found dependency - compute requirement by counting forward from diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 987743c2a2..338517d824 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -1 +1,22 @@ // RUN: water-opt %s --water-insert-waitcnt | FileCheck %s + +// CHECK-LABEL: func.func @no_dependency +func.func @no_dependency(%global1: memref<64x64xf32>, %global2: memref<64x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // First load to LDS1. + %base1 = amdgpu.make_dma_base %global1[%c0, %c0], %lds1[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc1 = amdgpu.make_dma_descriptor %base1 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc1 : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Second load to different LDS2 (no dependency). + %base2 = amdgpu.make_dma_base %global2[%c0, %c0], %lds2[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc2 = amdgpu.make_dma_descriptor %base2 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + // CHECK-NOT: amdgpu.memory_counter_wait + amdgpu.tensor_load_to_lds %desc2 : !amdgpu.tdm_descriptor + + return +} From 0204e982afc3f8a3cebfde7ee737ff0b39857754 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 22:52:39 +0100 Subject: [PATCH 08/38] test Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 7 ++++++- water/test/Transforms/insert-waitcnt.mlir | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 635c094518..2703d4a36c 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -56,6 +56,11 @@ static std::optional propagateTensorDesc(Value value, bool isLoad) { /// Check if the operation is a load operation and return the base memref. 
static std::optional isLoadOp(Operation *op) { + // Use interface + if (auto load = dyn_cast(op)) + return load.getBase(); + if (auto load = dyn_cast(op)) + return load.getMemref(); if (auto load = dyn_cast(op)) return propagateTensorDesc(load.getDesc(), true); @@ -200,7 +205,7 @@ struct WaitcntRequirement { static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) { WaitcntRequirement req; - if (isa(op)) + if (isa(op)) req.tensor_cnt = zero ? 0 : 1; return req; diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 338517d824..68aea26891 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -20,3 +20,23 @@ func.func @no_dependency(%global1: memref<64x64xf32>, %global2: memref<64x64xf32 return } + +// CHECK-LABEL: func.func @raw_dependency_vector_load +func.func @raw_dependency_vector_load(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // Tensor load to LDS (write to LDS). + %base = amdgpu.make_dma_base %global[%c0, %c0], %lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Vector load from LDS (read from LDS) - creates RAW dependency. + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK: amdgpu.lds_barrier + %vec = vector.load %lds[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} From 08febd1f058867fc14237eb0a8f580bc76fa0e22 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 22:54:21 +0100 Subject: [PATCH 09/38] test Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 68aea26891..fd1b1f2bbc 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -40,3 +40,28 @@ func.func @raw_dependency_vector_load(%global: memref<64x64xf32>, %lds: memref<6 return } + +// CHECK-LABEL: func.func @multiple_pending_ops +func.func @multiple_pending_ops(%global: memref<64x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // First tensor load to LDS1. + %base1 = amdgpu.make_dma_base %global[%c0, %c0], %lds1[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc1 = amdgpu.make_dma_descriptor %base1 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc1 : !amdgpu.tdm_descriptor + + // Second tensor load to LDS2. + %base2 = amdgpu.make_dma_base %global[%c0, %c0], %lds2[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc2 = amdgpu.make_dma_descriptor %base2 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc2 : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Vector load from LDS1 - has RAW dependency with first load, should wait for 1 (second op is still pending). 
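+  // (Here "tensor(1)" reads as: stall until at most one tensor DMA is still
+  // outstanding, i.e. the earlier load into %lds1 has retired while the
+  // unrelated load into %lds2 may still be in flight.)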
+ // CHECK: amdgpu.memory_counter_wait tensor(1) + // CHECK: amdgpu.lds_barrier + %vec = vector.load %lds1[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} From 9e20f10df3f19f4f519f178f8d9585759199fbac Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 22:55:41 +0100 Subject: [PATCH 10/38] loop test Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index fd1b1f2bbc..a92150c353 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -65,3 +65,28 @@ func.func @multiple_pending_ops(%global: memref<64x64xf32>, %lds1: memref<64x64x return } + +// CHECK-LABEL: func.func @scf_for_loop +func.func @scf_for_loop(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + + // Tensor load to LDS before loop. + %base = amdgpu.make_dma_base %global[%c0, %c0], %lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Loop that reads from LDS - should insert wait before loop. + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK: amdgpu.lds_barrier + // CHECK: scf.for + scf.for %i = %c0 to %c4 step %c1 { + %vec = vector.load %lds[%i, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + } + + return +} From 58c716fb3152767283761ba28eb4da8ae4f772d7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 23:01:17 +0100 Subject: [PATCH 11/38] double buffer Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index a92150c353..7f1ef26d95 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -90,3 +90,40 @@ func.func @scf_for_loop(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu return } + +// CHECK-LABEL: func.func @double_buffer_loop +func.func @double_buffer_loop(%global: memref<512x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + + // Initial load to buffer 1. + %base_init = amdgpu.make_dma_base %global[%c0, %c0], %lds1[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_init = amdgpu.make_dma_descriptor %base_init globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_init : !amdgpu.tdm_descriptor + + // Double buffer loop: load next chunk while processing current chunk. + // CHECK: scf.for + scf.for %i = %c0 to %c8 step %c1 iter_args(%current_buf = %lds1, %next_buf = %lds2) -> (memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>) { + // Load next data to next_buf. 
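+    // (One new tensor load is issued per iteration; the tensor(1) wait below only
+    // has to drain to a single outstanding load, so this copy into %next_buf
+    // overlaps with processing of %current_buf.)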
+ %next_idx = arith.addi %i, %c1 : index + %global_offset = arith.muli %next_idx, %c64 : index + %base_next = amdgpu.make_dma_base %global[%global_offset, %c0], %next_buf[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_next = amdgpu.make_dma_descriptor %base_next globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_next : !amdgpu.tdm_descriptor + + // Barrier to ensure previous load completed. + // CHECK: amdgpu.memory_counter_wait tensor(1) + // CHECK-NEXT: amdgpu.lds_barrier + amdgpu.lds_barrier + + // Process current buffer (read from the buffer loaded in previous iteration). + %vec = vector.load %current_buf[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + // Swap buffers for next iteration. + scf.yield %next_buf, %current_buf : memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space> + } + + return +} From fcd3b4fba46ee7cd6a8428e77c496c5f10b560fd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 23:08:07 +0100 Subject: [PATCH 12/38] triple buffering Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 7f1ef26d95..b257c1521b 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -127,3 +127,47 @@ func.func @double_buffer_loop(%global: memref<512x64xf32>, %lds1: memref<64x64xf return } + +// CHECK-LABEL: func.func @triple_buffer_loop +func.func @triple_buffer_loop(%global: memref<512x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>, %lds3: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + + // Initial load to buffer 1. + %base_init1 = amdgpu.make_dma_base %global[%c0, %c0], %lds1[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_init1 = amdgpu.make_dma_descriptor %base_init1 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_init1 : !amdgpu.tdm_descriptor + + // Initial load to buffer 2. + %base_init2 = amdgpu.make_dma_base %global[%c64, %c0], %lds2[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_init2 = amdgpu.make_dma_descriptor %base_init2 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_init2 : !amdgpu.tdm_descriptor + + // Triple buffer loop: load next chunk while processing current chunk, with two extra buffers in flight. + // CHECK: scf.for + scf.for %i = %c0 to %c8 step %c1 iter_args(%current_buf = %lds1, %next_buf = %lds2, %next_next_buf = %lds3) -> (memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>) { + // Load next data to next_next_buf (2 iterations ahead). 
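+    // (Through the rotation in the scf.yield below, the buffer written here is
+    // only read back as %current_buf two iterations later, which is what lets
+    // the wait relax to tensor(2).)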
+ %next_idx = arith.addi %i, %c1 : index + %global_offset = arith.muli %next_idx, %c64 : index + %base_next = amdgpu.make_dma_base %global[%global_offset, %c0], %next_next_buf[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_next = amdgpu.make_dma_descriptor %base_next globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_next : !amdgpu.tdm_descriptor + + // Barrier to ensure load from 2 iterations ago completed. + // With triple buffering, we can have 2 loads in flight (current + previous iteration). + // Wait until at most 2 remain, ensuring the oldest (from 2 iters ago) is done. + // CHECK: amdgpu.memory_counter_wait tensor(2) + // CHECK-NEXT: amdgpu.lds_barrier + amdgpu.lds_barrier + + // Process current buffer (read from the buffer loaded 2 iterations ago). + %vec = vector.load %current_buf[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + // Rotate buffers for next iteration. + scf.yield %next_buf, %next_next_buf, %current_buf : memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space> + } + + return +} From c8ec468eaf94fdc86ffd0a76173ecf8dc7d46edd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 9 Jan 2026 23:18:41 +0100 Subject: [PATCH 13/38] only track tensor ops Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 9 +++++--- water/test/Transforms/insert-waitcnt.mlir | 24 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 2703d4a36c..e5dda1d96b 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -41,6 +41,10 @@ static std::optional propagateViewOps(Value value) { return value; } +static bool trackOp(Operation *op) { + return isa(op); +} + static std::optional propagateTensorDesc(Value value, bool isLoad) { auto makeDesc = value.getDefiningOp(); if (!makeDesc) @@ -639,9 +643,8 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { LDBG() << " No operation requirement"; } - // Check if this is an async memory operation (vector load/store) - if (WaitcntRequirement::getOperationRequirement(op, false) - .hasRequirement()) { + // Check if this is an async memory operation + if (trackOp(op)) { // Add this operation to the pending list newState.addPendingOp(op); } diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index b257c1521b..a07908d52f 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -171,3 +171,27 @@ func.func @triple_buffer_loop(%global: memref<512x64xf32>, %lds1: memref<64x64xf return } + +// CHECK-LABEL: func.func @vector_ops_only +func.func @vector_ops_only(%lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %cst = arith.constant dense<0.0> : vector<4xf32> + + // Loop with vector load/store but no tensor operations. + // CHECK: scf.for + // CHECK-NOT: amdgpu.memory_counter_wait + scf.for %i = %c0 to %c4 step %c1 { + // Vector load from LDS1. + %vec = vector.load %lds1[%i, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + // Barrier. + amdgpu.lds_barrier + + // Vector store to LDS2. 
+ vector.store %vec, %lds2[%i, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + } + + return +} From 8e4a1b9013f257fcaa66e7087d5d479760c172b1 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 10 Jan 2026 23:25:53 +0100 Subject: [PATCH 14/38] comments Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index e5dda1d96b..fdd3a73e4d 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -41,6 +41,7 @@ static std::optional propagateViewOps(Value value) { return value; } +/// Check if we need to track the operation for waitcnt requirements. static bool trackOp(Operation *op) { return isa(op); } @@ -504,15 +505,17 @@ class WaitcntState : public AbstractDenseLattice { } private: - /// Pending asynchronous operations + /// Pending asynchronous operations. SmallVector, 4> pendingOpsLists; - /// Required waitcnt after this state + /// Required waitcnt after this state. WaitcntRequirement requirement; + /// Cached sets of pending operations and tokens for quick lookup. mutable llvm::SmallDenseSet pendingOpsSet; mutable llvm::SmallDenseSet pendingOpsTokens; + /// Copy on write for pending operations lists. void cow() { for (auto &pendingOps : pendingOpsLists) { if (pendingOps.use_count() > 1) { @@ -524,6 +527,7 @@ class WaitcntState : public AbstractDenseLattice { } } + /// Check if the operation or value is in pending operations lists. bool isPendingOp(llvm::PointerUnion opOrVal) const { if (pendingOpsLists.empty()) return false; From efd0a978f430def3333678eb0fab5645a3e3caf7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 10 Jan 2026 23:31:44 +0100 Subject: [PATCH 15/38] list of values Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 63 +++++++++++---------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index fdd3a73e4d..91d9de2736 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -59,25 +59,30 @@ static std::optional propagateTensorDesc(Value value, bool isLoad) { return propagateViewOps(isLoad ? makeBase.getGlobal() : makeBase.getLds()); } -/// Check if the operation is a load operation and return the base memref. -static std::optional isLoadOp(Operation *op) { - // Use interface - if (auto load = dyn_cast(op)) - return load.getBase(); - if (auto load = dyn_cast(op)) - return load.getMemref(); - if (auto load = dyn_cast(op)) - return propagateTensorDesc(load.getDesc(), true); - - return std::nullopt; +/// Check if the operation is a load operation and return list of base memrefs +/// to track. +static SmallVector isLoadOp(Operation *op) { + SmallVector result; + if (auto load = dyn_cast(op)) { + result.push_back(load.getBase()); + } else if (auto load = dyn_cast(op)) { + result.push_back(load.getMemref()); + } else if (auto load = dyn_cast(op)) { + if (auto memref = propagateTensorDesc(load.getDesc(), true)) + result.push_back(*memref); + } + return result; } -/// Check if the operation is a store operation and return the base memref. 
-static std::optional isStoreOp(Operation *op) { - if (auto store = dyn_cast(op)) - return propagateTensorDesc(store.getDesc(), false); - - return std::nullopt; +/// Check if the operation is a store operation and return list of base memrefs +/// to track. +static SmallVector isStoreOp(Operation *op) { + SmallVector result; + if (auto store = dyn_cast(op)) { + if (auto memref = propagateTensorDesc(store.getDesc(), false)) + result.push_back(*memref); + } + return result; } template @@ -105,11 +110,11 @@ struct PendingOperations { ops.push_back(op); auto &back = opsTokens.emplace_back(); - if (auto memref = isStoreOp(op)) - back.push_back(*memref); + for (Value memref : isStoreOp(op)) + back.push_back(memref); - if (auto memref = isLoadOp(op)) - back.push_back(*memref); + for (Value memref : isLoadOp(op)) + back.push_back(memref); return back; } @@ -486,10 +491,10 @@ class WaitcntState : public AbstractDenseLattice { return pendingResult; }; - if (auto loadBase = isLoadOp(pendingOp)) - result.merge(checkPendingMemref(*loadBase, true, false)); - if (auto storeBase = isStoreOp(pendingOp)) - result.merge(checkPendingMemref(*storeBase, false, true)); + for (Value loadBase : isLoadOp(pendingOp)) + result.merge(checkPendingMemref(loadBase, true, false)); + for (Value storeBase : isStoreOp(pendingOp)) + result.merge(checkPendingMemref(storeBase, false, true)); } } @@ -497,10 +502,10 @@ class WaitcntState : public AbstractDenseLattice { }; // TODO: atomics will have both load and store flags set WaitcntRequirement result; - if (auto loadBase = isLoadOp(op)) - result.merge(checkMemref(*loadBase, true, false)); - if (auto storeBase = isStoreOp(op)) - result.merge(checkMemref(*storeBase, false, true)); + for (Value loadBase : isLoadOp(op)) + result.merge(checkMemref(loadBase, true, false)); + for (Value storeBase : isStoreOp(op)) + result.merge(checkMemref(storeBase, false, true)); return result; } From 65e2987fc77ddb3bd31298d1b9d2ae665d595e15 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 10 Jan 2026 23:37:05 +0100 Subject: [PATCH 16/38] select Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 29 ++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 91d9de2736..db6464088f 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -9,6 +9,7 @@ #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -34,13 +35,30 @@ static bool isBarrier(Operation *op) { } /// Try to propagate view operations to the base memref. -static std::optional propagateViewOps(Value value) { +static Value propagateViewOps(Value value) { while (auto view = value.getDefiningOp()) value = view.getViewSource(); return value; } +/// Collect all underlying values through view and select operations. 
+static SmallVector collectUnderlyingValues(Value value) { + SmallVector result; + SmallVector worklist; + worklist.push_back(value); + while (!worklist.empty()) { + Value current = propagateViewOps(worklist.pop_back_val()); + if (auto select = current.getDefiningOp()) { + worklist.push_back(select.getTrueValue()); + worklist.push_back(select.getFalseValue()); + } else { + result.push_back(current); + } + } + return result; +} + /// Check if we need to track the operation for waitcnt requirements. static bool trackOp(Operation *op) { return isa(op); @@ -62,16 +80,15 @@ static std::optional propagateTensorDesc(Value value, bool isLoad) { /// Check if the operation is a load operation and return list of base memrefs /// to track. static SmallVector isLoadOp(Operation *op) { - SmallVector result; if (auto load = dyn_cast(op)) { - result.push_back(load.getBase()); + return collectUnderlyingValues(load.getBase()); } else if (auto load = dyn_cast(op)) { - result.push_back(load.getMemref()); + return collectUnderlyingValues(load.getMemref()); } else if (auto load = dyn_cast(op)) { if (auto memref = propagateTensorDesc(load.getDesc(), true)) - result.push_back(*memref); + return collectUnderlyingValues(*memref); } - return result; + return {}; } /// Check if the operation is a store operation and return list of base memrefs From adc37c53d4631b3c5e9f7fb328ec962c47349460 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 10 Jan 2026 23:38:42 +0100 Subject: [PATCH 17/38] cleanup Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index db6464088f..d632eea722 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -64,7 +64,8 @@ static bool trackOp(Operation *op) { return isa(op); } -static std::optional propagateTensorDesc(Value value, bool isLoad) { +/// Try to get the base memref from the tensor descriptor. 
+static Value propagateTensorDesc(Value value, bool isLoad) { auto makeDesc = value.getDefiningOp(); if (!makeDesc) return value; @@ -85,8 +86,8 @@ static SmallVector isLoadOp(Operation *op) { } else if (auto load = dyn_cast(op)) { return collectUnderlyingValues(load.getMemref()); } else if (auto load = dyn_cast(op)) { - if (auto memref = propagateTensorDesc(load.getDesc(), true)) - return collectUnderlyingValues(*memref); + Value memref = propagateTensorDesc(load.getDesc(), true); + return collectUnderlyingValues(memref); } return {}; } @@ -96,8 +97,8 @@ static SmallVector isLoadOp(Operation *op) { static SmallVector isStoreOp(Operation *op) { SmallVector result; if (auto store = dyn_cast(op)) { - if (auto memref = propagateTensorDesc(store.getDesc(), false)) - result.push_back(*memref); + Value memref = propagateTensorDesc(store.getDesc(), false); + return collectUnderlyingValues(memref); } return result; } From 345cf041720115734d44afd6d8112f4d55dbae77 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 10 Jan 2026 23:42:15 +0100 Subject: [PATCH 18/38] select tests Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 46 +++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index a07908d52f..55ea5361a4 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -195,3 +195,49 @@ func.func @vector_ops_only(%lds1: memref<64x64xf32, #gpu.address_space, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>, %cond: i1) { + %c0 = arith.constant 0 : index + + // Tensor load to LDS1. + %base1 = amdgpu.make_dma_base %global[%c0, %c0], %lds1[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc1 = amdgpu.make_dma_descriptor %base1 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc1 : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Select between LDS1 and LDS2 based on condition. + %selected = arith.select %cond, %lds1, %lds2 : memref<64x64xf32, #gpu.address_space> + + // Vector load from selected buffer - should detect dependency with LDS1. + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK-NEXT: amdgpu.lds_barrier + %vec = vector.load %selected[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} + +// CHECK-LABEL: func.func @select_tensor_base +func.func @select_tensor_base(%global: memref<64x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>, %cond: i1) { + %c0 = arith.constant 0 : index + + // Select which LDS buffer to use for tensor load. + %selected_lds = arith.select %cond, %lds1, %lds2 : memref<64x64xf32, #gpu.address_space> + + // Tensor load using selected LDS buffer - writes to either LDS1 or LDS2. + %base = amdgpu.make_dma_base %global[%c0, %c0], %selected_lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // Vector load from LDS1 - should detect dependency since tensor load might have written here. 
+ // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK-NEXT: amdgpu.lds_barrier + %vec = vector.load %lds1[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} From 66f9850af0a7b941a4170de6d624dc68753a3edb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 11 Jan 2026 00:02:43 +0100 Subject: [PATCH 19/38] desc Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 128 ++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index d632eea722..446c3394dd 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -4,6 +4,134 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// WaterInsertWaitcnt Pass - Algorithm Summary +//===----------------------------------------------------------------------===// +// +// This pass inserts memory counter wait instructions +// (amdgpu.memory_counter_wait) to ensure proper synchronization between +// asynchronous memory operations on AMD GPUs. It is analogous to LLVM's +// SIInsertWaitcnts pass but operates at the MLIR level. +// +// ## Overview +// +// AMD GPUs execute certain memory operations asynchronously, particularly +// tensor loads (global memory → LDS via DMA). These operations can complete +// out of order, requiring explicit synchronization via wait instructions. +// This pass detects memory dependencies and inserts the minimal set of waits +// needed for correctness. +// +// ## Key Concepts +// +// 1. **Tracked Operations**: Currently tracks `amdgpu.tensor_load_to_lds` +// operations as asynchronous. These operations read from global memory +// and write to LDS (Local Data Share). +// +// 2. **Memory Dependencies**: Detects three types of hazards: +// - RAW (Read After Write): Reading from a location being written by a +// pending async operation +// - WAR (Write After Read): Writing to a location being read by a pending +// async operation +// - WAW (Write After Write): Currently disabled, as tensor operations to +// the same LDS location can be allowed to overlap +// +// 3. **Barriers**: Operations like `amdgpu.lds_barrier` and `gpu.barrier` +// serve as synchronization points. Waits are inserted at barriers to +// ensure pending operations complete before proceeding. +// +// 4. **Wait Counts**: The `tensor_cnt` parameter specifies how many operations +// should remain pending. For example: +// - `tensor(0)`: Wait for all pending operations to complete +// - `tensor(1)`: Wait until at most 1 operation remains pending +// - `tensor(2)`: Wait until at most 2 operations remain pending +// +// ## Algorithm Details +// +// ### Phase 1: Dataflow Analysis +// +// Uses dense forward dataflow analysis (DenseForwardDataFlowAnalysis) to +// propagate state through the program: +// +// **State (WaitcntState)**: +// - `pendingOpsLists`: Lists of pending asynchronous operations along with +// memory tokens (memrefs they touch). Multiple lists handle different +// control flow paths. +// - `requirement`: The waitcnt requirement needed after this program point. +// +// **Transfer Function** (visitOperation): +// 1. For barriers: Add to pending operations list (barriers separate groups +// of async operations) +// 2. 
For memory operations: Check if they access memory touched by pending +// operations +// 3. If dependency found: Compute wait count needed and propagate requirement +// backwards to the barrier +// 4. For tracked operations: Add to pending list for subsequent operations +// +// **Join Operation**: +// - Merges states from different control flow paths +// - Takes conservative approach: keeps all unique pending operation sequences +// - Merges requirements by taking minimum (most restrictive) +// +// **Control Flow Handling** (visitRegionBranchControlFlowTransfer): +// - Propagates memory tokens through loop iter_args and branch results +// - Maps values across region boundaries +// - Maintains dominance information to determine which tokens remain valid +// +// ### Phase 2: Memory Reference Resolution +// +// To track which memory locations operations touch: +// +// 1. **View Propagation** (`propagateViewOps`): Strips away view operations +// (subview, reinterpret_cast) to get base memrefs. +// +// 2. **Select Handling** (`collectUnderlyingValues`): When memory references +// flow through `arith.select`, conservatively tracks all possible values +// (both true and false branches). This handles dynamic buffer selection in +// double/triple buffering patterns. +// +// 3. **Tensor Descriptor Unwrapping** (`propagateTensorDesc`): For tensor +// operations, extracts the actual memref from the DMA descriptor chain: +// TensorLoadToLDSOp → MakeDmaDescriptorOp → MakeDmaBaseOp → memref +// +// ### Phase 3: Dependency Detection +// +// For each operation accessing memory (`checkMemoryDependency`): +// 1. Extract all memory references (handling selects) +// 2. For each pending operation in reverse order (most recent first): +// - Check if memories may alias (same address space + memory in tokens) +// - Detect RAW/WAR hazards +// - If hazard found: Count operations from dependency point to end of list +// (this count determines how many operations can remain pending) +// - Track which barrier separates the operations +// +// ### Phase 4: Wait Insertion +// +// After analysis completes: +// 1. Walk all operations +// 2. Check if operation has a waitcnt requirement (from analysis) +// 3. Insert `amdgpu.memory_counter_wait` with computed tensor_cnt before +// the operation (typically a barrier) +// +// ## Example: Double Buffering +// +// ```mlir +// tensor_load_to_lds %desc1 // Load to buffer1 +// scf.for ... iter_args(%curr = %buf1, %next = %buf2) { +// tensor_load_to_lds %desc_next // Load to next buffer +// // Now 2 operations pending: desc1 and desc_next +// amdgpu.lds_barrier +// // Need tensor(1): wait until only 1 remains (ensures desc1 done) +// %vec = vector.load %curr +// scf.yield %next, %curr +// } +// ``` +// +// The analysis detects RAW between `tensor_load_to_lds %desc1` and +// `vector.load %curr` (when curr=buf1), counts 1 operation after desc1, +// and inserts `memory_counter_wait tensor(1)` at the barrier. 
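+//
+// As a further worked illustration (schematic; the buffer and descriptor
+// names are placeholders, and this mirrors the triple-buffer tests in
+// insert-waitcnt.mlir rather than describing new behavior), the same
+// counting extends to deeper pipelines:
+//
+// ```mlir
+// tensor_load_to_lds %desc1   // Load to buffer1
+// tensor_load_to_lds %desc2   // Load to buffer2
+// scf.for ... iter_args(%curr = %buf1, %next = %buf2, %next2 = %buf3) {
+//   tensor_load_to_lds %desc_next   // Load two iterations ahead
+//   // Up to 3 operations pending; the oldest one filled %curr
+//   amdgpu.lds_barrier
+//   // Need tensor(2): wait until at most 2 remain (the oldest is done)
+//   %vec = vector.load %curr
+//   scf.yield %next, %next2, %curr
+// }
+// ```
+//
+// In general, rotating N buffers through iter_args with one new load issued
+// per iteration leaves N-1 younger pending loads after the one that produced
+// the buffer being read, so the barrier ends up with
+// `memory_counter_wait tensor(N-1)`.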
+// +//===----------------------------------------------------------------------===// + #include "water/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" From 64358d334424238ced3d6e2731e9acf9e30a66ea Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 11 Jan 2026 15:00:59 +0100 Subject: [PATCH 20/38] fix Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 446c3394dd..19bd87230e 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -16,8 +16,8 @@ // ## Overview // // AMD GPUs execute certain memory operations asynchronously, particularly -// tensor loads (global memory → LDS via DMA). These operations can complete -// out of order, requiring explicit synchronization via wait instructions. +// tensor loads (global memory → LDS via DMA). These operations require explicit +// synchronization via wait instructions. // This pass detects memory dependencies and inserts the minimal set of waits // needed for correctness. // From 9ef94cef8a437d3720a3f6d87601102b8802a708 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 11 Jan 2026 15:31:26 +0100 Subject: [PATCH 21/38] barrier wait Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 10 +++++---- water/test/Transforms/insert-waitcnt.mlir | 24 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 19bd87230e..b62956a8bb 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -35,9 +35,10 @@ // - WAW (Write After Write): Currently disabled, as tensor operations to // the same LDS location can be allowed to overlap // -// 3. **Barriers**: Operations like `amdgpu.lds_barrier` and `gpu.barrier` -// serve as synchronization points. Waits are inserted at barriers to -// ensure pending operations complete before proceeding. +// 3. **Barriers**: Operations like `amdgpu.lds_barrier`, `gpu.barrier`, and +// `rocdl.s.barrier.signal` serve as synchronization points. Waits are +// inserted at barriers to ensure pending operations complete before +// proceeding. // // 4. **Wait Counts**: The `tensor_cnt` parameter specifies how many operations // should remain pending. For example: @@ -139,6 +140,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dominance.h" @@ -159,7 +161,7 @@ namespace mlir::water { namespace { static bool isBarrier(Operation *op) { - return isa(op) || isa(op); + return isa(op); } /// Try to propagate view operations to the base memref. 
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 55ea5361a4..ef58d68b9a 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -66,6 +66,30 @@ func.func @multiple_pending_ops(%global: memref<64x64xf32>, %lds1: memref<64x64x return } +// CHECK-LABEL: func.func @rocdl_barrier_signal +func.func @rocdl_barrier_signal(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // Tensor load to LDS. + %base = amdgpu.make_dma_base %global[%c0, %c0], %lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // ROCDL barrier signal/wait. + rocdl.s.barrier.signal id = 0 + rocdl.s.barrier.wait id = 0 + + // Vector load from LDS - creates RAW dependency. + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK-NEXT: rocdl.s.barrier.signal id = 0 + // CHECK-NEXT: rocdl.s.barrier.wait id = 0 + %vec = vector.load %lds[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + + return +} + + // CHECK-LABEL: func.func @scf_for_loop func.func @scf_for_loop(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { %c0 = arith.constant 0 : index From c21e208ce81b129e90eb543a6d47619066b024cb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 12 Jan 2026 16:07:42 +0100 Subject: [PATCH 22/38] integrate Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index b9b277e469..b81dfd09ba 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -428,6 +428,7 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] toolkit_path = get_water_mlir_pkg_path() pipeline = [ + "water-insert-waitcnt", "water-memref-decomposition", *add_opt(canonicalize_cse), "lower-affine", From c830748bf93c1a0bae27ba8902cb818045e0ab8f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 12 Jan 2026 16:18:44 +0100 Subject: [PATCH 23/38] update desc Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index fb83867037..06b9d2e0da 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -176,14 +176,10 @@ def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> { wait/synchronization instructions to ensure memory operations complete before their results are used. - The pass tracks dependencies between memory operations and register uses, + The pass tracks dependencies between async memory operations, maintaining scoreboards to determine when waits are necessary. It handles: - Read-after-write (RAW) dependencies - - Write-after-write (WAW) dependencies - Write-after-read (WAR) dependencies - - This is analogous to LLVM's SIInsertWaitcnts pass but operates at the - MLIR level for AMDGPU dialect operations. 
}]; let dependentDialects = [ "::mlir::amdgpu::AMDGPUDialect", From 51de1ba12db2b9b4cddf332525a269c42cbe8b54 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 13 Jan 2026 01:07:48 +0100 Subject: [PATCH 24/38] fix deduplication Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 9 ++- water/test/Transforms/insert-waitcnt.mlir | 83 +++++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index b62956a8bb..3f3a5e6e74 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -514,6 +514,7 @@ class WaitcntState : public AbstractDenseLattice { /// Set the required waitcnt values void setRequirement(const WaitcntRequirement &req) { requirement = req; + cow(); for (auto &pendingOps : pendingOpsLists) { SmallVector newPending; SmallVector newPendingTokens; @@ -564,6 +565,7 @@ class WaitcntState : public AbstractDenseLattice { void updateTokens( llvm::function_ref &)> updateFunc) { + cow(); for (auto &pendingOps : pendingOpsLists) pendingOps->updateTokens(updateFunc); } @@ -668,7 +670,9 @@ class WaitcntState : public AbstractDenseLattice { mutable llvm::SmallDenseSet pendingOpsSet; mutable llvm::SmallDenseSet pendingOpsTokens; - /// Copy on write for pending operations lists. + /// List of pending ops are shared between multiple states to reduce memory + /// footprint. Call this before modifying the pending operations lists to + /// deduplicate them if necessary. void cow() { for (auto &pendingOps : pendingOpsLists) { if (pendingOps.use_count() > 1) { @@ -803,6 +807,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Check if this is an async memory operation if (trackOp(op)) { // Add this operation to the pending list + LDBG() << " Adding pending operation: " << *op; newState.addPendingOp(op); } @@ -872,7 +877,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Add token propagated through region control flow. if (Value mappedValue = valuesMapping.lookup(value)) - if (newTokens.empty() || newTokens.back() != mappedValue) + if (!llvm::is_contained(newTokens, mappedValue)) newTokens.push_back(mappedValue); }; newState.updateTokens(tokenUpdateFunc); diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index ef58d68b9a..45cc882860 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -196,6 +196,89 @@ func.func @triple_buffer_loop(%global: memref<512x64xf32>, %lds1: memref<64x64xf return } +// CHECK-LABEL: func.func @triple_buffer_loop_sync_before +func.func @triple_buffer_loop_sync_before(%global: memref<512x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>, %lds3: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + + // Initial load to buffer 1. + %base_init1 = amdgpu.make_dma_base %global[%c0, %c0], %lds1[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_init1 = amdgpu.make_dma_descriptor %base_init1 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_init1 : !amdgpu.tdm_descriptor + + // Initial load to buffer 2. 
+ %base_init2 = amdgpu.make_dma_base %global[%c64, %c0], %lds2[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_init2 = amdgpu.make_dma_descriptor %base_init2 globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_init2 : !amdgpu.tdm_descriptor + + // Sync before the loop + amdgpu.memory_counter_wait tensor(0) + + // Triple buffer loop: load next chunk while processing current chunk, with two extra buffers in flight. + // CHECK: scf.for + scf.for %i = %c0 to %c8 step %c1 iter_args(%current_buf = %lds1, %next_buf = %lds2, %next_next_buf = %lds3) -> (memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>) { + // Load next data to next_next_buf (2 iterations ahead). + %next_idx = arith.addi %i, %c1 : index + %global_offset = arith.muli %next_idx, %c64 : index + %base_next = amdgpu.make_dma_base %global[%global_offset, %c0], %next_next_buf[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_next = amdgpu.make_dma_descriptor %base_next globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_next : !amdgpu.tdm_descriptor + + // Barrier to ensure load from 2 iterations ago completed. + // With triple buffering, we can have 2 loads in flight (current + previous iteration). + // Wait until at most 2 remain, ensuring the oldest (from 2 iters ago) is done. + // CHECK: amdgpu.memory_counter_wait tensor(2) + // CHECK-NEXT: amdgpu.lds_barrier + amdgpu.lds_barrier + + // Process current buffer (read from the buffer loaded 2 iterations ago). + %vec = vector.load %current_buf[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + // Rotate buffers for next iteration. + scf.yield %next_buf, %next_next_buf, %current_buf : memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space> + } + + return +} + +// CHECK-LABEL: func.func @triple_buffer_loop_no_init +func.func @triple_buffer_loop_no_init(%global: memref<512x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>, %lds3: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + + // Triple buffering without intializing the initial buffers + + // Triple buffer loop: load next chunk while processing current chunk, with two extra buffers in flight. + // CHECK: scf.for + scf.for %i = %c0 to %c8 step %c1 iter_args(%current_buf = %lds1, %next_buf = %lds2, %next_next_buf = %lds3) -> (memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>) { + // Load next data to next_next_buf (2 iterations ahead). + %next_idx = arith.addi %i, %c1 : index + %global_offset = arith.muli %next_idx, %c64 : index + %base_next = amdgpu.make_dma_base %global[%global_offset, %c0], %next_next_buf[%c0, %c0] : memref<512x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc_next = amdgpu.make_dma_descriptor %base_next globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc_next : !amdgpu.tdm_descriptor + + // Barrier to ensure load from 2 iterations ago completed. 
+ // With triple buffering, we can have 2 loads in flight (current + previous iteration). + // Wait until at most 2 remain, ensuring the oldest (from 2 iters ago) is done. + // CHECK: amdgpu.memory_counter_wait tensor(2) + // CHECK-NEXT: amdgpu.lds_barrier + amdgpu.lds_barrier + + // Process current buffer (read from the buffer loaded 2 iterations ago). + %vec = vector.load %current_buf[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + // Rotate buffers for next iteration. + scf.yield %next_buf, %next_next_buf, %current_buf : memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space>, memref<64x64xf32, #gpu.address_space> + } + + return +} + // CHECK-LABEL: func.func @vector_ops_only func.func @vector_ops_only(%lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { %c0 = arith.constant 0 : index From 8d013b18487b15cd831138f2a5dca3d2ce766bd3 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 13 Jan 2026 02:04:09 +0100 Subject: [PATCH 25/38] nicer print Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 3f3a5e6e74..abb022fc25 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -233,9 +233,10 @@ static SmallVector isStoreOp(Operation *op) { return result; } -template -static raw_ostream &print_range(raw_ostream &os, T &&range) { - llvm::interleaveComma(range, os, [&](const auto &item) { os << item; }); +static raw_ostream &print_range(raw_ostream &os, ValueRange range) { + llvm::interleaveComma(range, os, [&](Value item) { + item.print(os, OpPrintingFlags().skipRegions()); + }); return os; } @@ -747,14 +748,15 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { LogicalResult visitOperation(Operation *op, const WaitcntState &before, WaitcntState *after) override { - LDBG() << "Visiting: " << *op; + LDBG() << "Visiting: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); LDBG() << " Before: " << before; // Start with the state before this operation WaitcntState newState = before; if (isBarrier(op)) { - LDBG() << " Barrier: " << *op; + LDBG() << " Barrier: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); newState.addPendingOp(op); LDBG() << " New state: " << newState; propagateIfChanged(after, after->join(newState)); @@ -826,7 +828,8 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { std::optional regionTo, const WaitcntState &before, WaitcntState *after) override { - LDBG() << "Visiting region branch control flow transfer: " << *branch; + LDBG() << "Visiting region branch control flow transfer: " + << OpWithFlags(branch, OpPrintingFlags().skipRegions()); LDBG() << " Region from: " << regionFrom; LDBG() << " Region to: " << regionTo; LDBG() << " Before: " << before; From f0173745f41f2d41e1a15cb36ed64322352fed11 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 13 Jan 2026 02:14:05 +0100 Subject: [PATCH 26/38] nicer print Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index abb022fc25..7bebca8b72 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -736,6 +736,10 
@@ static RegionSuccessor getRegionResults(ArrayRef successors, llvm_unreachable("Region not found, malformed SCF op?"); } +static OpWithFlags withoutRegions(Operation *op) { + return OpWithFlags(op, OpPrintingFlags().skipRegions()); +} + /// Dense forward dataflow analysis for waitcnt insertion class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { public: @@ -748,15 +752,14 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { LogicalResult visitOperation(Operation *op, const WaitcntState &before, WaitcntState *after) override { - LDBG() << "Visiting: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + LDBG() << "Visiting: " << withoutRegions(op); LDBG() << " Before: " << before; // Start with the state before this operation WaitcntState newState = before; if (isBarrier(op)) { - LDBG() << " Barrier: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); + LDBG() << " Barrier: " << withoutRegions(op); newState.addPendingOp(op); LDBG() << " New state: " << newState; propagateIfChanged(after, after->join(newState)); @@ -779,7 +782,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // newState.setRequirement(opRequirement); LDBG() << " Barriers found, requirement: " << opRequirement; for (Operation *barrier : barriers) { - LDBG() << " " << *barrier; + LDBG() << " " << withoutRegions(barrier); WaitcntState *beforeState = getOrCreate(getProgramPointBefore(barrier)); WaitcntState *afterState = @@ -793,7 +796,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Check if this is an existing memory_counter_wait operation if (auto waitOp = dyn_cast(op)) { - LDBG() << " Existing waitcnt operation: " << *waitOp; + LDBG() << " Existing waitcnt operation: " << withoutRegions(waitOp); opRequirement.merge(WaitcntRequirement(waitOp)); } @@ -809,7 +812,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Check if this is an async memory operation if (trackOp(op)) { // Add this operation to the pending list - LDBG() << " Adding pending operation: " << *op; + LDBG() << " Adding pending operation: " << withoutRegions(op); newState.addPendingOp(op); } @@ -829,7 +832,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { const WaitcntState &before, WaitcntState *after) override { LDBG() << "Visiting region branch control flow transfer: " - << OpWithFlags(branch, OpPrintingFlags().skipRegions()); + << withoutRegions(branch); LDBG() << " Region from: " << regionFrom; LDBG() << " Region to: " << regionTo; LDBG() << " Before: " << before; From 3616365ec164bec837678d0fe811571895db1646 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 13 Jan 2026 02:41:53 +0100 Subject: [PATCH 27/38] better barrier handling Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 15 ++++++++++--- water/test/Transforms/insert-waitcnt.mlir | 24 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 7bebca8b72..2c2496e247 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -254,8 +254,11 @@ struct PendingOperations { if (size() >= 256) llvm::report_fatal_error("Pending operations list is too long"); - if (!ops.empty() && isBarrier(op) && isBarrier(ops.back())) + // If we have c onsecutive barriers, keep only the later one. 
+ if (!ops.empty() && isBarrier(op) && isBarrier(ops.back())) { + ops.back() = op; return opsTokens.back(); + } ops.push_back(op); auto &back = opsTokens.emplace_back(); @@ -520,13 +523,19 @@ class WaitcntState : public AbstractDenseLattice { SmallVector newPending; SmallVector newPendingTokens; WaitcntRequirement runningRequirement; - for (const auto &[op, tok] : llvm::reverse(pendingOps->opsAndTokens())) { + for (const auto &[op, tok] : pendingOps->opsAndTokensReverse()) { WaitcntRequirement opReq = WaitcntRequirement::getOperationRequirement(op, false); runningRequirement = runningRequirement + opReq; if (runningRequirement > requirement) continue; + // If we have consecutive barriers, skip the new one so only the later + // one is kept. + if (!newPending.empty() && isBarrier(op) && + isBarrier(newPending.back())) + continue; + newPending.push_back(op); newPendingTokens.push_back(tok); } @@ -778,7 +787,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { LDBG() << " No memory dependency"; } - if (opRequirement.hasRequirement() && !barriers.empty()) { + if (!isBarrier(op) && opRequirement.hasRequirement() && !barriers.empty()) { // newState.setRequirement(opRequirement); LDBG() << " Barriers found, requirement: " << opRequirement; for (Operation *barrier : barriers) { diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 45cc882860..d47c48a19e 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -41,6 +41,30 @@ func.func @raw_dependency_vector_load(%global: memref<64x64xf32>, %lds: memref<6 return } +// CHECK-LABEL: func.func @raw_dependency_vector_load_2_barriers +func.func @raw_dependency_vector_load_2_barriers(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // Tensor load to LDS (write to LDS). + %base = amdgpu.make_dma_base %global[%c0, %c0], %lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // Barrier. 
+ amdgpu.lds_barrier + amdgpu.lds_barrier + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: amdgpu.lds_barrier + + // Wait only before last barrier + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK: amdgpu.lds_barrier + %vec = vector.load %lds[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} + // CHECK-LABEL: func.func @multiple_pending_ops func.func @multiple_pending_ops(%global: memref<64x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { %c0 = arith.constant 0 : index From 598edc45dc62cd7a87501c1cff6ec7b389de548f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 13 Jan 2026 14:06:48 +0100 Subject: [PATCH 28/38] remove redundant requirements Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 58 +++++++++++++++++++-- water/test/Transforms/insert-waitcnt.mlir | 24 +++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 2c2496e247..07980c9ec2 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -359,6 +359,15 @@ struct WaitcntRequirement { return changed; } + /// Join with another requirement, taking the maximum of the two. + WaitcntRequirement join(const WaitcntRequirement &other) const { + WaitcntRequirement result; + if (tensor_cnt.has_value() || other.tensor_cnt.has_value()) + result.tensor_cnt = + std::max(tensor_cnt.value_or(0), other.tensor_cnt.value_or(0)); + return result; + } + std::optional getTensorCnt() const { return tensor_cnt; } bool isSameCounterType(const WaitcntRequirement &other) const { @@ -582,13 +591,25 @@ class WaitcntState : public AbstractDenseLattice { void resetRequirement() { requirement = {}; } - /// Get the required waitcnt values + /// Get the required waitcnt values. const WaitcntRequirement &getRequirement() const { return requirement; } - /// Check if there's a waitcnt requirement + /// Check if there's a waitcnt requirement. bool hasRequirement() const { return requirement.hasRequirement(); } - /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait + /// Get all pending ops requirements. + WaitcntRequirement getAllPendingOpsRequirements() const { + WaitcntRequirement result; + for (auto &pendingOps : pendingOpsLists) { + for (Operation *op : pendingOps->ops) { + result = + result.join(WaitcntRequirement::getOperationRequirement(op, false)); + } + } + return result; + } + + /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait. 
WaitcntRequirement checkMemoryDependency(Operation *op, llvm::SmallSetVector &barriers) const { @@ -921,6 +942,7 @@ class WaterInsertWaitcntPass // Insert waitcnt operations based on analysis results IRRewriter rewriter(&getContext()); + LDBG() << "Inserting waitcnt operations"; op->walk([&](Operation *operation) { const WaitcntState *state = solver.lookupState( solver.getProgramPointAfter(operation)); @@ -928,6 +950,8 @@ class WaterInsertWaitcntPass return; const WaitcntRequirement &req = state->getRequirement(); + LDBG() << " Operation " << withoutRegions(operation) + << " requirement: " << req; auto getAttr = [&](std::optional cnt) -> IntegerAttr { if (!cnt.has_value()) @@ -943,6 +967,34 @@ class WaterInsertWaitcntPass nullptr, nullptr, nullptr, nullptr, getAttr(req.getTensorCnt())); }); + + // There may be a redundant requirement which were already added before the + // pass or when we merge control flow requiriment conservetively. Run the + // analysis again to detect them. + solver.eraseAllStates(); + + LDBG() << "Running analysis again to detect redundant waitcnts"; + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + LDBG() << "Checking redundant waitcnts"; + op->walk([&](amdgpu::MemoryCounterWaitOp wait) { + LDBG() << " Checking redundant waitcnt: " << withoutRegions(wait); + const WaitcntState *state = + solver.lookupState(solver.getProgramPointBefore(wait)); + if (!state) + return; + + WaitcntRequirement req = state->getAllPendingOpsRequirements(); + WaitcntRequirement opReq(wait); + LDBG() << " All pending ops requirements: " << req; + LDBG() << " Waitcnt requirement: " << opReq; + + // No requirement means no in flight ops, so it's noop. + // If max number if in flight ops is less than the requirement, it's noop. + if (!req.hasRequirement() || req < opReq) + rewriter.eraseOp(wait); + }); } }; diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index d47c48a19e..becf4a10ab 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -65,6 +65,30 @@ func.func @raw_dependency_vector_load_2_barriers(%global: memref<64x64xf32>, %ld return } +// CHECK-LABEL: func.func @raw_dependency_vector_load_redundant_waitcnt +func.func @raw_dependency_vector_load_redundant_waitcnt(%global: memref<64x64xf32>, %lds: memref<64x64xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + + // Tensor load to LDS (write to LDS). + %base = amdgpu.make_dma_base %global[%c0, %c0], %lds[%c0, %c0] : memref<64x64xf32>, memref<64x64xf32, #gpu.address_space> -> !amdgpu.tdm_base + %desc = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %desc : !amdgpu.tdm_descriptor + + // Barrier. + amdgpu.lds_barrier + + // This wait is redundant after we insert wait before barrier. 
+ amdgpu.memory_counter_wait tensor(2) + + // Wait only before last barrier + // CHECK: amdgpu.memory_counter_wait tensor(0) + // CHECK: amdgpu.lds_barrier + // CHECK-NOT: amdgpu.memory_counter_wait + %vec = vector.load %lds[%c0, %c0] : memref<64x64xf32, #gpu.address_space>, vector<4xf32> + + return +} + // CHECK-LABEL: func.func @multiple_pending_ops func.func @multiple_pending_ops(%global: memref<64x64xf32>, %lds1: memref<64x64xf32, #gpu.address_space>, %lds2: memref<64x64xf32, #gpu.address_space>) { %c0 = arith.constant 0 : index From bca2456080334f5356574059e2b38b2c96b7a6ab Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 29 Jan 2026 19:34:23 +0100 Subject: [PATCH 29/38] adapt to api changes Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 07980c9ec2..ef92855637 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -869,7 +869,11 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { LDBG() << " After: " << *after; SmallVector successors; - branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + if (regionFrom) { + branch.getSuccessorRegions(branch->getRegion(*regionFrom), successors); + } else { + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + } auto destSuccessor = [&]() -> RegionSuccessor { if (regionTo) { @@ -880,7 +884,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { } }(); // Dest values are either nested block args or branch op results. - ValueRange destValues = destSuccessor.getSuccessorInputs(); + ValueRange destValues = branch.getSuccessorInputs(destSuccessor); // Map from input values to dest values. llvm::SmallDenseMap valuesMapping; From ee1f0c3e7830da505d8bb23e4483e35864ab6bca Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 29 Jan 2026 19:39:55 +0100 Subject: [PATCH 30/38] skip tensor_cnt Signed-off-by: Ivan Butygin --- wave_lang/kernel/compiler/wave_codegen/handlers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 226a8f9fad..a526337254 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1758,7 +1758,8 @@ def handle_shared_memory_barrier_signal(emitter: WaveEmitter, node: fx.Node): except ValueError as e: raise ValidationError("Malformed arguments") from e - if tensor_wait: + # Water backend have a dedicated tensor waitcount insertion pass. 
+    if tensor_wait and not emitter.options.use_water_backend:
         rocdl_d.s_wait_tensorcnt(0)
     if ds_wait and barId != CLUSTER_BARRIER_ID:

From 659bbdff2d6c5adae950bf9a2d006f1747599055 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 19:56:43 +0100
Subject: [PATCH 31/38] fix tensor desc propagation

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 37 ++++++++-------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index ef92855637..861a4aae97 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -172,8 +172,8 @@ static Value propagateViewOps(Value value) {
   return value;
 }
 
-/// Collect all underlying values through view and select operations.
-static SmallVector collectUnderlyingValues(Value value) {
+/// Collect all underlying values through view, select and tensor operations.
+static SmallVector collectUnderlyingValues(Value value, bool isLoad) {
   SmallVector result;
   SmallVector worklist;
   worklist.push_back(value);
@@ -182,6 +182,11 @@ static SmallVector collectUnderlyingValues(Value value) {
     if (auto select = current.getDefiningOp()) {
       worklist.push_back(select.getTrueValue());
       worklist.push_back(select.getFalseValue());
+    } else if (auto makeDesc =
+                   current.getDefiningOp()) {
+      worklist.push_back(makeDesc.getBase());
+    } else if (auto makeBase = current.getDefiningOp()) {
+      worklist.push_back(isLoad ? makeBase.getGlobal() : makeBase.getLds());
     } else {
       result.push_back(current);
     }
@@ -194,30 +199,15 @@ static bool trackOp(Operation *op) {
   return isa(op);
 }
 
-/// Try to get the base memref from the tensor descriptor.
-static Value propagateTensorDesc(Value value, bool isLoad) {
-  auto makeDesc = value.getDefiningOp();
-  if (!makeDesc)
-    return value;
-
-  value = makeDesc.getBase();
-  auto makeBase = value.getDefiningOp();
-  if (!makeBase)
-    return value;
-
-  return propagateViewOps(isLoad ? makeBase.getGlobal() : makeBase.getLds());
-}
-
 /// Check if the operation is a load operation and return list of base memrefs
 /// to track.
 static SmallVector isLoadOp(Operation *op) {
   if (auto load = dyn_cast(op)) {
-    return collectUnderlyingValues(load.getBase());
+    return collectUnderlyingValues(load.getBase(), true);
   } else if (auto load = dyn_cast(op)) {
-    return collectUnderlyingValues(load.getMemref());
+    return collectUnderlyingValues(load.getMemref(), true);
   } else if (auto load = dyn_cast(op)) {
-    Value memref = propagateTensorDesc(load.getDesc(), true);
-    return collectUnderlyingValues(memref);
+    return collectUnderlyingValues(load.getDesc(), true);
   }
   return {};
 }
@@ -226,10 +216,9 @@ static SmallVector isLoadOp(Operation *op) {
 /// to track.
 static SmallVector isStoreOp(Operation *op) {
   SmallVector result;
-  if (auto store = dyn_cast(op)) {
-    Value memref = propagateTensorDesc(store.getDesc(), false);
-    return collectUnderlyingValues(memref);
-  }
+  if (auto store = dyn_cast(op))
+    return collectUnderlyingValues(store.getDesc(), false);
+
   return result;
 }
 

From 95e665f8205ff4c753c5c7ec2822f9032edb6dd2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 20:01:19 +0100
Subject: [PATCH 32/38] update tensor count lowering

Signed-off-by: Ivan Butygin
---
 wave_lang/kernel/compiler/wave_codegen/handlers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py
index a526337254..a133cf77b8 100644
--- a/wave_lang/kernel/compiler/wave_codegen/handlers.py
+++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py
@@ -1760,7 +1760,7 @@ def handle_shared_memory_barrier_signal(emitter: WaveEmitter, node: fx.Node):
 
     # Water backend have a dedicated tensor waitcount insertion pass.
     if tensor_wait and not emitter.options.use_water_backend:
-        rocdl_d.s_wait_tensorcnt(0)
+        amdgpu_d.memory_counter_wait(tensor=0)
 
     if ds_wait and barId != CLUSTER_BARRIER_ID:
         rocdl_d.s_wait_dscnt(0)
@@ -1892,7 +1892,7 @@ def handle_tensor_counter_wait(emitter: WaveEmitter, node: fx.Node):
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
 
-    rocdl_d.s_wait_tensorcnt(count)
+    amdgpu_d.memory_counter_wait(tensor=count)
 
 
 @handle_op(workgroup_barrier)

From a3b891afec5b42404d646286cf9d73672125351f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 20:06:44 +0100
Subject: [PATCH 33/38] disable manual tensor count in schedule

Signed-off-by: Ivan Butygin
---
 wave_lang/kernel/wave/schedules/gemm_triple_buffer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
index 7ec46ca505..74e4ee5c14 100644
--- a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
+++ b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
@@ -243,7 +243,7 @@ def gfx1250_tbuf_gemm_schedule():
         tkw.cluster(
             [
                 prologue_global_to_shared_fused,
-                tkw.TensorCounterWait(1),
+                # tkw.TensorCounterWait(1),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=False),
                 tkw.SharedMemoryBarrierWait(-1),
             ],
@@ -285,7 +285,7 @@ def gfx1250_tbuf_gemm_schedule():
                 loop_shared_load_b,
                 # Barrier pattern after shared loads
                 tkw.SetWavePrio(0),
-                tkw.TensorCounterWait(1),
+                # tkw.TensorCounterWait(1),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=True),
                 tkw.SchedulingBarrier([]),
                 tkw.SharedMemoryBarrierWait(-1),
@@ -329,7 +329,7 @@ def gfx1250_tbuf_gemm_schedule():
         tkw.cluster(
             [
                 # First set of loads (B and A together)
-                tkw.TensorCounterWait(1),
+                # tkw.TensorCounterWait(1),
                 epilogue_shared_load_b_chunks[0],
                 epilogue_shared_load_a_chunks[0],
                 # Stagger barrier before first MMAs (no ds_wait)
@@ -341,7 +341,7 @@ def gfx1250_tbuf_gemm_schedule():
                 epilogue_mma_chunks[0],
                 # Stagger barrier before second loads
                 tkw.SetWavePrio(1),
-                tkw.TensorCounterWait(0),
+                # tkw.TensorCounterWait(0),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=False),
                 tkw.SchedulingBarrier([]),
                 tkw.SharedMemoryBarrierWait(-1),

From 6fe7081087fe6356b0bdab70a4784a34461ba0d5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 20:08:12 +0100
Subject: [PATCH 34/38] update test

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave_gemm_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py
index 873d2118e0..3de43a3cee 100644
--- a/tests/kernel/wave_gemm_test.py
+++ b/tests/kernel/wave_gemm_test.py
@@ -3441,7 +3441,6 @@ def test_gfx1250_tbuf_gemm_codegen(use_water_backend: bool, tmp_path: Path):
         "s_wait_xcnt 0x0",
         "s_wait_kmcnt 0x0",
         "s_wait_tensorcnt 0x1",
-        "s_wait_tensorcnt 0x1",
         "s_wait_dscnt 0x0",
         "s_wait_tensorcnt 0x1",
         "s_wait_dscnt 0xe",

From 6ab5f81ad0f192a4c1a019cabe4c802bfc9cd4a7 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 20:12:44 +0100
Subject: [PATCH 35/38] test water backend

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave_gemm_test.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py
index 3de43a3cee..ceee06bf34 100644
--- a/tests/kernel/wave_gemm_test.py
+++ b/tests/kernel/wave_gemm_test.py
@@ -3385,7 +3385,10 @@ def testSpecializeGemm(
 @require_gfx1250
 @pytest.mark.parametrize("shape", [(1024, 1024, 1024)])
 @pytest.mark.parametrize("mfma_variant", [MMAType.GFX1250_F32_16x16x32_F16])
-def test_gfx1250_tbuf_gemm(shape: tuple[int], mfma_variant: MMAType):
+@use_water_backend_bool("use_water_backend")
+def test_gfx1250_tbuf_gemm(
+    shape: tuple[int, int, int], mfma_variant: MMAType, use_water_backend: bool
+):
     gemm, options = get_tagged_BxA_T_gemm(
         shape=shape,
         block_shape=(256, 256, 64),
@@ -3397,6 +3400,7 @@ def test_gfx1250_tbuf_gemm(shape: tuple[int], mfma_variant: MMAType):
 
     schedule = get_gfx1250_tbuf_gemm_schedule()
     options = set_default_run_config(options)
+    options.use_water_backend = use_water_backend
     gemm = wave_compile(options, gemm, schedule)
 
     a = device_randn(shape[0], shape[2], dtype=torch.float16)

From 41d011eec45ab243c2750218375aded45e7ce299 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 20:29:45 +0100
Subject: [PATCH 36/38] update lit tests

Signed-off-by: Ivan Butygin
---
 lit_tests/kernel/wave/codegen.py | 2 +-
 lit_tests/kernel/wave/gemm.py    | 6 +++---
 lit_tests/kernel/wave/mma.py     | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index c3d527eedb..c7524bde23 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -1176,7 +1176,7 @@ def schedule_ops(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
     print(schedule_ops.asm)
 
     # CHECK-LABEL: func.func @schedule_ops
-    # CHECK: rocdl.s.wait.tensorcnt 0
+    # CHECK: amdgpu.memory_counter_wait tensor(0)
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
     # CHECK: rocdl.s.barrier.wait id = -1
diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py
index ec3c0112c0..147965389a 100644
--- a/lit_tests/kernel/wave/gemm.py
+++ b/lit_tests/kernel/wave/gemm.py
@@ -1421,7 +1421,7 @@ def test_gemm_four_stage_global_to_lds():
 
     # Verify prologue stores to shared memory
     # CHECK: amdgpu.tensor_load_to_lds
-    # CHECK: rocdl.s.wait.tensorcnt 0
+    # CHECK: amdgpu.memory_counter_wait tensor(0)
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
    # CHECK: rocdl.s.barrier.wait id = -1
@@ -1440,7 +1440,7 @@ def test_gemm_four_stage_global_to_lds():
 
     # Verify WMMA exists
     # CHECK: rocdl.wmma.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}
-    # CHECK: rocdl.s.wait.tensorcnt 0
+    # CHECK: amdgpu.memory_counter_wait tensor(0)
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
     # CHECK: rocdl.s.barrier.wait id = -1
@@ -1459,7 +1459,7 @@ def test_gemm_four_stage_global_to_lds():
 
     # Epilogue:
     # CHECK: rocdl.wmma.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}
-    # CHECK: rocdl.s.wait.tensorcnt 0
+    # CHECK: amdgpu.memory_counter_wait tensor(0)
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
     # CHECK: rocdl.s.barrier.wait id = -1
diff --git a/lit_tests/kernel/wave/mma.py b/lit_tests/kernel/wave/mma.py
index 8a1c0f3ee1..fff0d4396e 100644
--- a/lit_tests/kernel/wave/mma.py
+++ b/lit_tests/kernel/wave/mma.py
@@ -684,7 +684,7 @@ def mma(
 
     ### resource provider
     # CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED:.*]]
-    # CHECK: rocdl.s.wait.tensorcnt 0
+    # CHECK: amdgpu.memory_counter_wait tensor(0)
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1

From 2e700cb7d2956ab5d88ebb9441066b51efb390cf Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 21:08:51 +0100
Subject: [PATCH 37/38] optional tensor waitcount

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave_gemm_test.py                   |  8 ++++++--
 .../kernel/wave/schedules/gemm_triple_buffer.py  | 15 ++++++++++-----
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py
index ceee06bf34..5a4dcbe6ab 100644
--- a/tests/kernel/wave_gemm_test.py
+++ b/tests/kernel/wave_gemm_test.py
@@ -3398,7 +3398,9 @@ def test_gfx1250_tbuf_gemm(
         compile_to_mlir=False,
     )
 
-    schedule = get_gfx1250_tbuf_gemm_schedule()
+    schedule = get_gfx1250_tbuf_gemm_schedule(
+        insert_tensor_waitcount=not use_water_backend
+    )
     options = set_default_run_config(options)
     options.use_water_backend = use_water_backend
     gemm = wave_compile(options, gemm, schedule)
@@ -3424,7 +3426,9 @@ def test_gfx1250_tbuf_gemm_codegen(use_water_backend: bool, tmp_path: Path):
         compile_to_mlir=False,
     )
 
-    schedule = get_gfx1250_tbuf_gemm_schedule()
+    schedule = get_gfx1250_tbuf_gemm_schedule(
+        insert_tensor_waitcount=not use_water_backend
+    )
     options.target = "gfx1250"
     options.dump_intermediates = tmp_path
     options.use_water_backend = use_water_backend
diff --git a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
index 74e4ee5c14..b39d7ed840 100644
--- a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
+++ b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
@@ -153,7 +153,7 @@ def async_two_cluster_three_stage_schedule():
 
     return async_two_cluster_three_stage_schedule
 
 
-def get_gfx1250_tbuf_gemm_schedule():
+def get_gfx1250_tbuf_gemm_schedule(insert_tensor_waitcount: bool):
     """
     Returns a schedule function that implements a 3-stage pipelined prefetch
     with async global_to_shared operations, cluster-based reordering, and wave staggering.
@@ -221,6 +221,11 @@ def gfx1250_tbuf_gemm_schedule():
             ],
         )
 
+        def tensor_waitcount(count: int) -> list[tkw.CustomOp]:
+            if insert_tensor_waitcount:
+                return [tkw.TensorCounterWait(count)]
+            return []
+
         # Filter nodes for PROLOGUE stage (before the loop)
         prologue_global_to_shared_fused = tkw.filter_nodes(
             tkw.get_node_by_tag("read_a,read_b"), subgraph=pipeline_loop.PROLOGUE
@@ -243,7 +248,7 @@ def gfx1250_tbuf_gemm_schedule():
         tkw.cluster(
             [
                 prologue_global_to_shared_fused,
-                # tkw.TensorCounterWait(1),
+                *tensor_waitcount(1),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=False),
                 tkw.SharedMemoryBarrierWait(-1),
             ],
@@ -285,7 +290,7 @@ def gfx1250_tbuf_gemm_schedule():
                 loop_shared_load_b,
                 # Barrier pattern after shared loads
                 tkw.SetWavePrio(0),
-                # tkw.TensorCounterWait(1),
+                *tensor_waitcount(1),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=True),
                 tkw.SchedulingBarrier([]),
                 tkw.SharedMemoryBarrierWait(-1),
@@ -329,7 +334,7 @@ def gfx1250_tbuf_gemm_schedule():
         tkw.cluster(
             [
                 # First set of loads (B and A together)
-                # tkw.TensorCounterWait(1),
+                *tensor_waitcount(1),
                 epilogue_shared_load_b_chunks[0],
                 epilogue_shared_load_a_chunks[0],
                 # Stagger barrier before first MMAs (no ds_wait)
@@ -341,7 +346,7 @@ def gfx1250_tbuf_gemm_schedule():
                 epilogue_mma_chunks[0],
                 # Stagger barrier before second loads
                 tkw.SetWavePrio(1),
-                # tkw.TensorCounterWait(0),
+                *tensor_waitcount(0),
                 tkw.SharedMemoryBarrierSignal(-1, ds_wait=False),
                 tkw.SchedulingBarrier([]),
                 tkw.SharedMemoryBarrierWait(-1),

From 8807ee48b07185c01b2b106ab1046c9d60cece71 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 29 Jan 2026 21:26:28 +0100
Subject: [PATCH 38/38] update test

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave_gemm_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py
index 5a4dcbe6ab..e28a96f0a7 100644
--- a/tests/kernel/wave_gemm_test.py
+++ b/tests/kernel/wave_gemm_test.py
@@ -2874,7 +2874,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
 
     asm = gemm.asm
     assert (
-        "wait.tensorcnt" in asm
+        "memory_counter_wait tensor" in asm
    ), "tensor waitcnts are not found in asm: required for tensor load instructions."
 
     validate_gemm_result(a, b, c, options)