@@ -149,19 +149,22 @@ class LayoutRematerialization {
149149 getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
150150 SetVector<Value> &slice,
151151 DenseMap<Value, Attribute> &layout,
152- std::function<bool (Operation *)> stopPropagation);
152+ std::function<bool (Operation *)> stopPropagation,
153+ bool includeForOp = false );
153154
154155 LogicalResult getRematerializableSlice (
155156 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
156157 DenseMap<Value, Attribute> &layout,
157- std::function<bool (Operation *)> stopPropagation = nullptr);
158+ std::function<bool (Operation *)> stopPropagation = nullptr,
159+ bool includeForOp = false);
158160
159161private:
160162 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
161163 // Existing tuples of (value, layout) that needs to be updated when recreating
162164 // scf ops. This prevents keeping track of Values that have been delete when
163- // rewriting slices.
164- DenseMap<Value, Attribute> mappedValues;
165+ // rewriting slices. A Value may be mapped to several different encodings
166+ // during layout removal, so each Value keeps a list of them.
167+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
165168 // map of the values remat based on encoding.
166169 DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
167170 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
174177 Value newV) {
175178 LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
176179 rematMapping[{old, encoding}] = newV;
177- mappedValues[old] = encoding;
180+ if (mappedValues.contains (old)) {
181+ mappedValues[old].push_back (encoding);
182+ } else {
183+ mappedValues[old] = {encoding};
184+ }
178185}
179186
180187// Remove unneeded values now that we are done with the rematMapping.
@@ -955,22 +962,30 @@ void LayoutRematerialization::updateRematMapping(
955962 for (auto [old, newV] : values) {
956963 auto it = mappedValues.find (old);
957964 if (it != mappedValues.end ()) {
958- Attribute encoding = it->second ;
959- auto rematIt = rematMapping.find ({old, it->second });
960- assert (rematIt != rematMapping.end ());
961- Value replacedValue = rematIt->second ;
962- rematMapping.erase (rematIt);
963- mappedValues.erase (it);
964- // Loop through the replacement value to find the new version of remat
965- // value. This should be okay as the number of values should be small.
966- for (auto [before, after] : values) {
967- if (before == replacedValue) {
968- replacedValue = after;
969- break ;
965+ SmallVector<Attribute> encodings = it->second ;
966+ for (auto encoding : encodings) {
967+ Attribute newEncoding =
968+ cast<RankedTensorType>(newV.getType ()).getEncoding ();
969+ auto rematIt = rematMapping.find ({old, encoding});
970+ assert (rematIt != rematMapping.end ());
971+ Value replacedValue = rematIt->second ;
972+ rematMapping.erase (rematIt);
973+ // Loop through the replacement value to find the new version of remat
974+ // value. This should be okay as the number of values should be small.
975+ for (auto [before, after] : values) {
976+ if (before == replacedValue) {
977+ replacedValue = after;
978+ break ;
979+ }
970980 }
981+ rematMapping[{newV, encoding}] = replacedValue;
982+ }
983+ mappedValues.erase (it);
984+ if (mappedValues.contains (newV)) {
985+ mappedValues[newV].append (encodings);
986+ } else {
987+ mappedValues[newV] = std::move (encodings);
971988 }
972- rematMapping[{newV, encoding}] = replacedValue;
973- mappedValues[newV] = encoding;
974989 }
975990 }
976991}
@@ -1045,6 +1060,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10451060 deadOps.push_back (forOp.getOperation ());
10461061 Block &loopBody = *newForOp.getBody ();
10471062 for (auto m : argMapping) {
1063+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10481064 mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10491065 int numIndVars = newForOp.getNumInductionVars ();
10501066 mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1161,8 +1177,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11611177 builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11621178 }
11631179
1164- for (Operation *op : deadOps)
1165- opToDelete.insert (op);
1180+ for (Operation *op : deadOps) {
1181+ if (!isa<scf::ForOp>(op))
1182+ opToDelete.insert (op);
1183+ else
1184+ op->erase ();
1185+ }
11661186}
11671187
11681188void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1175,7 +1195,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11751195LogicalResult LayoutRematerialization::getConvertBackwardSlice (
11761196 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
11771197 DenseMap<Value, Attribute> &layout,
1178- std::function<bool (Operation *)> stopPropagation) {
1198+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
11791199 // Allow re-using existing conversions for a value. Check dominance of any
11801200 // reusable materializations against the root value. This is sufficient
11811201 // because the conversions are processed in post-order.
@@ -1204,15 +1224,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12041224 };
12051225
12061226 return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1207- stopPropagation, getExistingConversion);
1227+ stopPropagation, getExistingConversion,
1228+ includeForOp);
12081229}
12091230
12101231LogicalResult LayoutRematerialization::getRematerializableSlice (
12111232 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12121233 DenseMap<Value, Attribute> &layout,
1213- std::function<bool (Operation *)> stopPropagation) {
1214- LogicalResult result = getConvertBackwardSlice (
1215- root, rootEncoding, slice, layout, std::move (stopPropagation));
1234+ std::function<bool (Operation *)> stopPropagation, bool includeForOp) {
1235+
1236+ LogicalResult result =
1237+ getConvertBackwardSlice (root, rootEncoding, slice, layout,
1238+ std::move (stopPropagation), includeForOp);
12161239 if (result.failed () || slice.empty ())
12171240 return failure ();
12181241
@@ -1301,8 +1324,9 @@ void LayoutRematerialization::backwardRematerialization(
13011324 // rematerialized.
13021325 SetVector<Value> slice;
13031326 DenseMap<Value, Attribute> layout;
1304- LogicalResult result = getRematerializableSlice (
1305- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1327+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1328+ targetType.getEncoding (),
1329+ slice, layout, nullptr , true );
13061330 if (result.failed ()) {
13071331 LDBG (" getRematerializableSlice failed" );
13081332 return ;
0 commit comments