Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,

// Clean up attributes passing over schedules across stages in pipelining
void removePipeliningAttributes(ModuleOp moduleOp);

// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if
// they should be pipelined.
bool isPipeliningBeneficial(Operation *op,
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
bool filterSmall = true);

} // namespace triton
} // namespace mlir

Expand Down
61 changes: 1 addition & 60 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,64 +88,6 @@ class AssignLoadLatencies {
scf::ForOp forOp;
int numStages;
DenseMap<Operation *, int> &opLatency;

public:
static bool canHaveSharedEncoding(tt::LoadOp op) {
// If used by an user with DotOp encoding, all the uses must be compatible.
bool incompatible = false;
getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible);
return !incompatible;
}

static bool
isPipeliningBeneficial(Operation *op, Operation *finalUser,
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
bool filterSmall) {
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
LDBG("Load " << *loadOp << " is too small for pipelining");
return false;
}
}
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
return true;
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
LDBG("Load " << *op << " cannot have shared encoding");
return false;
}

ttg::SharedEncodingTrait localAllocEnc;
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
return isa<ttg::LocalAllocOp>(user);
})) {
for (auto user : op->getUsers()) {
auto localAlloc = dyn_cast<ttg::LocalAllocOp>(user);
if (!localAlloc)
continue;
auto enc = mlir::cast<ttg::SharedEncodingTrait>(
localAlloc.getType().getEncoding());
if (!localAllocEnc) {
localAllocEnc = enc;
}
if (enc != localAllocEnc) {
// If the load is used by a LocalAllocOp, all the users need to have
// the same encoding.
return false;
}
}
}

if (localAllocEnc) {
auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc);
if (filterSmall && vecBytes < 4) {
// At least 4 bytes need to be consecutive for cp.async
return false;
}
}

return true;
}
};

class AssignMMALatencies {
Expand Down Expand Up @@ -280,8 +222,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
if (!seen.insert(op).second || excluded.count(op))
return;
if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
if (!AssignLoadLatencies::isPipeliningBeneficial(
op, finalUser, axisInfoAnalysis, filterSmall))
if (!isPipeliningBeneficial(op, axisInfoAnalysis, filterSmall))
return;
if (loadOpToIndLevel.count(op)) {
int level = loadOpToIndLevel[op].first;
Expand Down
27 changes: 9 additions & 18 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,26 +453,17 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
continue;
}
SharedEncodingTrait sharedEncoding;
bool canUseAsyncCp = false;
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
forOp.getContext(), 1, 1, 1, {0},
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
if (canUseAsyncCp) {
bool canUseAsyncCp =
triton::isPipeliningBeneficial(&op, axisInfoAnalysis);
if (canUseAsyncCp) {
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
forOp.getContext(), 1, 1, 1, {0},
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
scalarLoads.push_back(&op);
} else {
sharedEncoding = getSharedEncoding(&op);
}
} else {
sharedEncoding = getSharedEncoding(&op);
// Do not create async loads for small loads (cp.async requires at least
// 4 bytes)
canUseAsyncCp =
isa<tt::LoadOp>(op) &&
canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
int copyVecBytes = getCopyVecBytes(
cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);

canUseAsyncCp &= copyVecBytes >= 4;
}
if (canUseAsyncCp || isTMALoad(&op)) {
if (loadRequiresAdditionalBuffer(&op)) {
Expand Down
39 changes: 39 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,10 @@ ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) {
}

ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) {
if (!isa<RankedTensorType>(op->getResultTypes()[0])) {
return nullptr;
}

// Try to use local alloc encoding if possible.
ttg::SharedEncodingTrait localAllocEnc;
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
Expand Down Expand Up @@ -933,3 +937,38 @@ void triton::removePipeliningAttributes(ModuleOp moduleOp) {
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
});
}

// Returns true when a shared-memory encoding can be assigned to the load's
// result: all users with a DotOp encoding must agree on a compatible
// encoding (reported through the out-parameter of
// getSharedEncIfAllUsersAreDotEnc).
static bool canHaveSharedEncoding(tt::LoadOp op) {
  bool hasIncompatibleUser = false;
  getSharedEncIfAllUsersAreDotEnc(op.getResult(), hasIncompatibleUser);
  return !hasIncompatibleUser;
}

// Decide whether a load-like op (tt::LoadOp, tt::DescriptorLoadOp, or
// tt::DescriptorGatherOp) is worth pipelining. When `filterSmall` is set,
// loads too small to be lowered to cp.async are rejected.
bool triton::isPipeliningBeneficial(
    Operation *op, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
    bool filterSmall) {
  // Plain loads must be convertible to an async copy (unless filtering of
  // small loads is disabled).
  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
    if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
      LDBG("Load " << *loadOp << " is too small for pipelining");
      return false;
    }
  }
  // Descriptor (TMA-style) loads are always beneficial. NOTE: this early
  // return must stay before the cast below — past this point `op` is
  // assumed to be a tt::LoadOp.
  if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
    return true;
  if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
    LDBG("Load " << *op << " cannot have shared encoding");
    return false;
  }

  // When a shared encoding is known (getSharedEncoding returns null for
  // non-tensor results), require at least a 4-byte contiguous vector so the
  // copy can use cp.async.
  if (auto localAllocEnc = getSharedEncoding(op)) {
    auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
    auto vecBytes = mlir::triton::getCopyVecBytes(registerTy, localAllocEnc);
    if (filterSmall && vecBytes < 4) {
      // At least 4 bytes need to be consecutive for cp.async
      return false;
    }
  }

  return true;
}
24 changes: 16 additions & 8 deletions python/examples/gluon/01-attention-forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ class BarrierCounter:
phase: gl.tensor
num_barriers: gl.constexpr

@gluon.constexpr_function
def __init__(self, index, phase, num_barriers):
    """Initialize the barrier counter.

    `index` and `phase` are runtime tensor values; `num_barriers` is wrapped
    in gl.constexpr to match the class-level `num_barriers: gl.constexpr`
    annotation.
    """
    self.index = index
    self.phase = phase
    self.num_barriers = gl.constexpr(num_barriers)

@gluon.must_use_result
@gluon.jit
Expand All @@ -79,6 +80,7 @@ class ChannelType:
num_buffers: gl.constexpr
num_consumers: gl.constexpr

@gluon.constexpr_function
def __init__(self, mem, ready_bars, empty_bars, num_buffers, num_consumers):
self.mem = mem
self.ready_bars = ready_bars
Expand Down Expand Up @@ -143,6 +145,7 @@ class Producer:
channel: ChannelType
counter: BarrierCounter

@gluon.constexpr_function
def __init__(self, channel, counter):
    """Bind a producer to its channel and barrier counter."""
    self.channel = channel
    self.counter = counter
Expand All @@ -158,6 +161,7 @@ class Consumer:
channel: ChannelType
counter: BarrierCounter

@gluon.constexpr_function
def __init__(self, channel, counter):
    """Bind a consumer to its channel and barrier counter."""
    self.channel = channel
    self.counter = counter
Expand Down Expand Up @@ -234,6 +238,7 @@ class AttentionConfig:
num_kv_buffers: gl.constexpr
use_exp2_turnstile: gl.constexpr

@gluon.constexpr_function
def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype,
num_warps):
self.qk_scale = qk_scale
Expand All @@ -250,7 +255,7 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
self.num_warps = gl.constexpr(num_warps)

self.SPLIT_D_FACTOR = gl.constexpr(2)
self.SPLIT_EXP_FACTOR = 256 // HEAD_DIM
self.SPLIT_EXP_FACTOR = gl.constexpr(256 // HEAD_DIM)
self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)
Expand Down Expand Up @@ -305,6 +310,7 @@ class ProgramScheduler:
num_pid_in_group: gl.tensor
num_tiles: gl.tensor

@gluon.constexpr_function
def __init__(self, config, start_pid, num_pid_n, num_pid_in_group, num_tiles):
self.config = config
self.start_pid = start_pid
Expand Down Expand Up @@ -339,6 +345,7 @@ class AttentionProgram:
offset_y: gl.tensor
qo_offset_y: gl.tensor

@gluon.constexpr_function
def __init__(self, config, start_m, off_hz, offset_y, qo_offset_y):
self.config = config
self.start_m = start_m
Expand Down Expand Up @@ -840,12 +847,13 @@ def attention_kernel( #

chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
descs = (desc_q, desc_k, desc_v, desc_o)
gl.warp_specialize((config, chnls, descs, M, STAGE), _attn_fwd_correction, (config, chnls, descs, M, STAGE), [
_attn_fwd_softmax0,
_attn_fwd_softmax1,
_attn_fwd_mma,
_attn_fwd_load,
_attn_fwd_epilogue,
gl.warp_specialize([
(_attn_fwd_correction, (config, chnls, descs, M, STAGE)),
(_attn_fwd_softmax0, (config, chnls, descs, M, STAGE)),
(_attn_fwd_softmax1, (config, chnls, descs, M, STAGE)),
(_attn_fwd_mma, (config, chnls, descs, M, STAGE)),
(_attn_fwd_load, (config, chnls, descs, M, STAGE)),
(_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)),
], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24])

q_chnl.release()
Expand Down
19 changes: 19 additions & 0 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,16 @@ void init_gluon_ir(py::module &&m) {
ctx, version, warpsPerCta, instrShape, transposed, ctaLayout,
tilesPerWarp, elementBitWidth);
})
// Binding: compute the linear layout for the scale operand (selected by
// opIdx) of an AMD scaled MFMA, wrap it in a LinearEncodingAttr, and
// convert it to the Gluon-side layout object.
.def("get_amd_mfma_scale_layout",
     [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
        unsigned mfmaMDim, std::vector<unsigned> &tilesPerWarp,
        std::vector<unsigned> &warpsPerCTA) -> py::object {
       auto ctx = self.getContext();
       auto ll = ttg::chooseScaledMfmaScaleLayout(
           ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
       auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
       return layoutToGluon(attr);
     })
.def("get_amd_wmma_layout",
[](GluonOpBuilder &self, unsigned version, bool transposed,
std::vector<unsigned> &warpsPerCta,
Expand All @@ -397,6 +407,15 @@ void init_gluon_ir(py::module &&m) {
return ttg::AMDWmmaEncodingAttr::get(
ctx, version, transposed, warpsPerCta, ctaLayout, instrShape);
})
// Binding: compute the linear layout for the scale operand (selected by
// opIdx) of an AMD scaled WMMA and return it as a Gluon layout object.
// NOTE: chooseScaledWmmaScaleLayout takes warpsPerCTA before shape, unlike
// the MFMA variant above.
.def("get_amd_wmma_scale_layout",
     [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
        std::vector<unsigned> &warpsPerCTA) -> py::object {
       auto ctx = self.getContext();
       auto ll = ttg::chooseScaledWmmaScaleLayout(ctx, opIdx, warpsPerCTA,
                                                  shape);
       auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
       return layoutToGluon(attr);
     })
.def("get_intel_dpas_layout",
[](GluonOpBuilder &self, unsigned repeatCount,
unsigned systolicDepth, unsigned executionSize,
Expand Down
22 changes: 16 additions & 6 deletions python/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class TritonOpBuilder {
if (!block.empty())
setLastLoc(block.begin()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
setLastLoc(getLocForBlock(&block));
builder->setInsertionPointToStart(&block);
}

void setInsertionPointToEnd(mlir::Block &block) {
if (!block.empty())
setLastLoc(block.back().getLoc());
else
setLastLoc(builder->getUnknownLoc());
setLastLoc(getLocForBlock(&block));
builder->setInsertionPointToEnd(&block);
}

Expand All @@ -53,10 +53,14 @@ class TritonOpBuilder {
}

void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) {
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
setLastLoc(pt.getPoint()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
setLastLoc(builder->getUnknownLoc());
if (pt.isSet()) {
if (pt.getPoint() != pt.getBlock()->end())
setLastLoc(pt.getPoint()->getLoc());
else
setLastLoc(getLocForBlock(pt.getBlock()));
}

builder->restoreInsertionPoint(pt);
}

Expand Down Expand Up @@ -87,4 +91,10 @@ class TritonOpBuilder {
std::unique_ptr<mlir::Location> lastLoc;
bool lineInfoEnabled =
!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");

// Best-effort source location for a block: use the location of the op that
// owns the block, falling back to the unknown location for a detached block.
mlir::Location getLocForBlock(mlir::Block *block) {
  auto *parentOp = block->getParentOp();
  return parentOp ? parentOp->getLoc() : builder->getUnknownLoc();
}
};
Loading