@@ -149,19 +149,22 @@ class LayoutRematerialization {
149149 getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
150150 SetVector<Value> &slice,
151151 DenseMap<Value, Attribute> &layout,
152- std::function<bool (Operation *)> stopPropagation);
152+ std::function<bool (Operation *)> stopPropagation,
153+ bool includeForOp = false );
153154
154155 LogicalResult getRematerializableSlice (
155156 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
156157 DenseMap<Value, Attribute> &layout,
157- std::function<bool (Operation *)> stopPropagation = nullptr);
158+ std::function<bool (Operation *)> stopPropagation = nullptr,
159+ bool includeForOp = false);
158160
159161private:
160162 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
161163 // Existing tuples of (value, layout) that needs to be updated when recreating
162164 // scf ops. This prevents keeping track of Values that have been delete when
163- // rewriting slices.
164- DenseMap<Value, Attribute> mappedValues;
165+ // rewriting slices. The Value maybe mapped to different attributes in remove
166+ // layout.
167+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
165168 // map of the values remat based on encoding.
166169 DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
167170 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
174177 Value newV) {
175178 LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
176179 rematMapping[{old, encoding}] = newV;
177- mappedValues[old] = encoding;
180+ if (mappedValues.contains (old)) {
181+ mappedValues[old].push_back (encoding);
182+ } else {
183+ mappedValues[old] = {encoding};
184+ }
178185}
179186
180187// Remove unneeded values now that we are done with the rematMapping.
@@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping(
955962 for (auto [old, newV] : values) {
956963 auto it = mappedValues.find (old);
957964 if (it != mappedValues.end ()) {
958- Attribute encoding = it->second ;
959- auto rematIt = rematMapping.find ({old, it->second });
960- assert (rematIt != rematMapping.end ());
961- Value replacedValue = rematIt->second ;
962- rematMapping.erase (rematIt);
963- mappedValues.erase (it);
964- // Loop through the replacement value to find the new version of remat
965- // value. This should be okay as the number of values should be small.
966- for (auto [before, after] : values) {
967- if (before == replacedValue) {
968- replacedValue = after;
969- break ;
965+ SmallVector<Attribute> encodings = it->second ;
966+ for (auto encoding : encodings) {
967+ auto rematIt = rematMapping.find ({old, encoding});
968+ assert (rematIt != rematMapping.end ());
969+ Value replacedValue = rematIt->second ;
970+ rematMapping.erase (rematIt);
971+ // Loop through the replacement value to find the new version of remat
972+ // value. This should be okay as the number of values should be small.
973+ for (auto [before, after] : values) {
974+ if (before == replacedValue) {
975+ replacedValue = after;
976+ break ;
977+ }
970978 }
979+ rematMapping[{newV, encoding}] = replacedValue;
980+ }
981+ mappedValues.erase (it);
982+ if (mappedValues.contains (newV)) {
983+ mappedValues[newV].append (encodings);
984+ } else {
985+ mappedValues[newV] = std::move (encodings);
971986 }
972- rematMapping[{newV, encoding}] = replacedValue;
973- mappedValues[newV] = encoding;
974987 }
975988 }
976989}
@@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10451058 deadOps.push_back (forOp.getOperation ());
10461059 Block &loopBody = *newForOp.getBody ();
10471060 for (auto m : argMapping) {
1061+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10481062 mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10491063 int numIndVars = newForOp.getNumInductionVars ();
10501064 mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11611175 builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11621176 }
11631177
1164- for (Operation *op : deadOps)
1165- opToDelete.insert (op);
1178+ for (Operation *op : deadOps) {
1179+ if (!isa<scf::ForOp>(op))
1180+ opToDelete.insert (op);
1181+ else
1182+ op->erase ();
1183+ }
11661184}
11671185
11681186void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11751193LogicalResult LayoutRematerialization::getConvertBackwardSlice (
11761194 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
11771195 DenseMap<Value, Attribute> &layout,
1178- std::function<bool (Operation *)> stopPropagation) {
1196+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
11791197 // Allow re-using existing conversions for a value. Check dominance of any
11801198 // reusable materializations against the root value. This is sufficient
11811199 // because the conversions are processed in post-order.
@@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12041222 };
12051223
12061224 return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1207- stopPropagation, getExistingConversion);
1225+ stopPropagation, getExistingConversion,
1226+ includeForOp);
12081227}
12091228
12101229LogicalResult LayoutRematerialization::getRematerializableSlice (
12111230 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12121231 DenseMap<Value, Attribute> &layout,
1213- std::function<bool (Operation *)> stopPropagation) {
1214- LogicalResult result = getConvertBackwardSlice (
1215- root, rootEncoding, slice, layout, std::move (stopPropagation));
1232+ std::function<bool (Operation *)> stopPropagation, bool includeForOp) {
1233+
1234+ LogicalResult result =
1235+ getConvertBackwardSlice (root, rootEncoding, slice, layout,
1236+ std::move (stopPropagation), includeForOp);
12161237 if (result.failed () || slice.empty ())
12171238 return failure ();
12181239
@@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization(
13011322 // rematerialized.
13021323 SetVector<Value> slice;
13031324 DenseMap<Value, Attribute> layout;
1304- LogicalResult result = getRematerializableSlice (
1305- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1325+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1326+ targetType.getEncoding (),
1327+ slice, layout, nullptr , true );
13061328 if (result.failed ()) {
13071329 LDBG (" getRematerializableSlice failed" );
13081330 return ;
0 commit comments