@@ -157,19 +157,22 @@ class LayoutRematerialization {
157157 getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
158158 SetVector<Value> &slice,
159159 DenseMap<Value, Attribute> &layout,
160- std::function<bool (Operation *)> stopPropagation);
160+ std::function<bool (Operation *)> stopPropagation,
161+ bool includeForOp = false );
161162
162163 LogicalResult getRematerializableSlice (
163164 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
164165 DenseMap<Value, Attribute> &layout,
165- std::function<bool (Operation *)> stopPropagation = nullptr);
166+ std::function<bool (Operation *)> stopPropagation = nullptr,
167+ bool includeForOp = false);
166168
167169private:
168170 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
169171 // Existing tuples of (value, layout) that need to be updated when recreating
170172 // scf ops. This prevents keeping track of Values that have been deleted when
171- // rewriting slices.
172- DenseMap<Value, Attribute> mappedValues;
173+ // rewriting slices. A Value may be mapped to different attributes in remove
174+ // layout.
175+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
173176 // map of the values remat based on encoding.
174177 DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
175178 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -183,7 +186,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
183186 Value newV) {
184187 LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
185188 rematMapping[{old, encoding}] = newV;
186- mappedValues[old] = encoding;
189+ if (mappedValues.contains (old)) {
190+ mappedValues[old].push_back (encoding);
191+ } else {
192+ mappedValues[old] = {encoding};
193+ }
187194}
188195
189196// Remove unneeded values now that we are done with the rematMapping.
@@ -988,22 +995,28 @@ void LayoutRematerialization::updateRematMapping(
988995 for (auto [old, newV] : values) {
989996 auto it = mappedValues.find (old);
990997 if (it != mappedValues.end ()) {
991- Attribute encoding = it->second ;
992- auto rematIt = rematMapping.find ({old, it->second });
993- assert (rematIt != rematMapping.end ());
994- Value replacedValue = rematIt->second ;
995- rematMapping.erase (rematIt);
996- mappedValues.erase (it);
997- // Loop through the replacement value to find the new version of remat
998- // value. This should be okay as the number of values should be small.
999- for (auto [before, after] : values) {
1000- if (before == replacedValue) {
1001- replacedValue = after;
1002- break ;
998+ SmallVector<Attribute> encodings = it->second ;
999+ for (auto encoding : encodings) {
1000+ auto rematIt = rematMapping.find ({old, encoding});
1001+ assert (rematIt != rematMapping.end ());
1002+ Value replacedValue = rematIt->second ;
1003+ rematMapping.erase (rematIt);
1004+ // Loop through the replacement value to find the new version of remat
1005+ // value. This should be okay as the number of values should be small.
1006+ for (auto [before, after] : values) {
1007+ if (before == replacedValue) {
1008+ replacedValue = after;
1009+ break ;
1010+ }
10031011 }
1012+ rematMapping[{newV, encoding}] = replacedValue;
1013+ }
1014+ mappedValues.erase (it);
1015+ if (mappedValues.contains (newV)) {
1016+ mappedValues[newV].append (encodings);
1017+ } else {
1018+ mappedValues[newV] = std::move (encodings);
10041019 }
1005- rematMapping[{newV, encoding}] = replacedValue;
1006- mappedValues[newV] = encoding;
10071020 }
10081021 }
10091022}
@@ -1078,6 +1091,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10781091 deadOps.push_back (forOp.getOperation ());
10791092 Block &loopBody = *newForOp.getBody ();
10801093 for (auto m : argMapping) {
1094+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10811095 mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10821096 int numIndVars = newForOp.getNumInductionVars ();
10831097 mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1188,8 +1202,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11881202 builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11891203 }
11901204
1191- for (Operation *op : deadOps)
1192- opToDelete.insert (op);
1205+ for (Operation *op : deadOps) {
1206+ if (!isa<scf::ForOp>(op))
1207+ opToDelete.insert (op);
1208+ else
1209+ op->erase ();
1210+ }
11931211}
11941212
11951213void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1202,7 +1220,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
12021220LogicalResult LayoutRematerialization::getConvertBackwardSlice (
12031221 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12041222 DenseMap<Value, Attribute> &layout,
1205- std::function<bool (Operation *)> stopPropagation) {
1223+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
12061224 // Allow re-using existing conversions for a value. Check dominance of any
12071225 // reusable materializations against the root value. This is sufficient
12081226 // because the conversions are processed in post-order.
@@ -1231,15 +1249,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12311249 };
12321250
12331251 return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1234- stopPropagation, getExistingConversion);
1252+ stopPropagation, getExistingConversion,
1253+ includeForOp);
12351254}
12361255
12371256LogicalResult LayoutRematerialization::getRematerializableSlice (
12381257 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12391258 DenseMap<Value, Attribute> &layout,
1240- std::function<bool (Operation *)> stopPropagation) {
1241- LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1242- layout, stopPropagation);
1259+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1260+ LogicalResult result = getConvertBackwardSlice (
1261+ root, rootEncoding, slice, layout, stopPropagation, includeForOp );
12431262 if (result.failed () || slice.empty ())
12441263 return failure ();
12451264
@@ -1362,8 +1381,9 @@ void LayoutRematerialization::backwardRematerialization(
13621381 // rematerialized.
13631382 SetVector<Value> slice;
13641383 DenseMap<Value, Attribute> layout;
1365- LogicalResult result = getRematerializableSlice (
1366- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1384+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1385+ targetType.getEncoding (),
1386+ slice, layout, nullptr , true );
13671387 if (result.failed ()) {
13681388 LDBG (" getRematerializableSlice failed" );
13691389 return ;
0 commit comments