11#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
22#include " intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
33#include " triton/Dialect/Triton/IR/Utility.h"
4+ #include " triton/Dialect/TritonGPU/Transforms/Utility.h"
45#include " llvm/ADT/PriorityWorklist.h"
56
67namespace ttg = mlir::triton::gpu;
@@ -16,45 +17,11 @@ namespace gpu::intel {
1617
1718namespace {
1819
19- SmallVector<Value> getTiedArgs (Operation *op, int resultIdx) {
20- if (auto forOp = dyn_cast<scf::ForOp>(op)) {
21- auto iterArg = forOp.getRegionIterArg (resultIdx);
22- auto result = forOp.getResult (resultIdx);
23- auto yieldVal = forOp.getBody ()->getTerminator ()->getOperand (resultIdx);
24- auto initVal = forOp.getInitArgs ()[resultIdx];
25- return {iterArg, result, yieldVal, initVal};
26- } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
27- auto iterArg = whileOp.getBeforeArguments ()[resultIdx];
28- auto result = whileOp.getResults ()[resultIdx];
29- auto yieldVal = whileOp.getConditionOp ().getArgs ()[resultIdx];
30- auto initVal = whileOp.getOperands ()[resultIdx];
31- auto bodyArg = whileOp.getAfterArguments ()[resultIdx];
32- return {iterArg, result, yieldVal, initVal, bodyArg};
33- } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
34- SmallVector<Value> values;
35- for (auto &block : ifOp.getThenRegion ().getBlocks ()) {
36- auto terminator = block.getTerminator ();
37- if (isa<scf::YieldOp>(terminator))
38- values.push_back (terminator->getOperands ()[resultIdx]);
39- }
40- for (auto &block : ifOp.getElseRegion ().getBlocks ()) {
41- auto terminator = block.getTerminator ();
42- if (isa<scf::YieldOp>(terminator))
43- values.push_back (terminator->getOperands ()[resultIdx]);
44- }
45- values.push_back (ifOp->getResults ()[resultIdx]);
46- return values;
47- }
48- return {};
49- }
50-
5120struct EncodingInfo {
5221 Attribute desiredEncoding;
53- bool requiresConvert = false ;
5422
5523 bool operator ==(const EncodingInfo &other) const {
56- return desiredEncoding == other.desiredEncoding &&
57- requiresConvert == other.requiresConvert ;
24+ return desiredEncoding == other.desiredEncoding ;
5825 }
5926};
6027
@@ -77,10 +44,6 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
7744
7845 auto updateEncoding = [&](ArrayRef<Value> ptrValues, EncodingInfo info) {
7946 for (auto value : ptrValues) {
80- bool requiresConvert = llvm::any_of (
81- value.getUsers (), [](auto user) { return isa<LoadOp>(user); });
82- info.requiresConvert = requiresConvert;
83-
8447 auto typedVal = cast<TypedValue<PointerType>>(value);
8548 auto itr = valueToEncodingInfo.find (typedVal);
8649 if (itr == valueToEncodingInfo.end ()) {
@@ -157,24 +120,22 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
157120 oldTensorTy.getShape (), oldTensorTy.getElementType (), newEncoding);
158121
159122 val.setType (PointerType::get (newTensorTy, oldType.getAddressSpace ()));
160- if (einfo.requiresConvert ) {
161- for (auto user : val.getUsers ()) {
162- if (auto loadOp = dyn_cast<LoadOp>(user)) {
163-
164- OpBuilder builder (loadOp);
165- auto oldLoadType = loadOp.getType ();
166- Value result = loadOp.getResult ();
167-
168- builder.setInsertionPointAfter (loadOp);
169- auto cvt = builder.create <ConvertLayoutOp>(loadOp.getLoc (),
170- result.getType (), result);
171- LLVM_DEBUG (DBGS () << " Added convert Op:\n "
172- << cvt << " after Load Op:\n "
173- << loadOp << " \n " );
174- result.setType (newTensorTy);
175-
176- result.replaceAllUsesExcept (cvt.getResult (), cvt.getOperation ());
177- }
123+ for (auto user : val.getUsers ()) {
124+ if (auto loadOp = dyn_cast<LoadOp>(user)) {
125+
126+ OpBuilder builder (loadOp);
127+ auto oldLoadType = loadOp.getType ();
128+ Value result = loadOp.getResult ();
129+
130+ builder.setInsertionPointAfter (loadOp);
131+ auto cvt = builder.create <ConvertLayoutOp>(loadOp.getLoc (),
132+ result.getType (), result);
133+ LLVM_DEBUG (DBGS () << " Added convert Op:\n "
134+ << cvt << " after Load Op:\n "
135+ << loadOp << " \n " );
136+ result.setType (newTensorTy);
137+
138+ result.replaceAllUsesExcept (cvt.getResult (), cvt.getOperation ());
178139 }
179140 }
180141 }
0 commit comments