Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ class LayoutRematerialization {
void reduceLoopCarriedValues();
// Existing tuples of (value, layout) that needs to be updated when recreating
// scf ops. This prevents keeping track of Values that have been delete when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// rewriting slices. The Value maybe mapped to different attributes in remove
// layout.
DenseMap<Value, SmallVector<Attribute>> mappedValues;
// map of the values remat based on encoding.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
Expand All @@ -185,7 +186,10 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
if (mappedValues.contains(old))
mappedValues[old].push_back(encoding);
else
mappedValues[old] = {encoding};
}

// Remove unneeded values now that we are done with the rematMapping.
Expand Down Expand Up @@ -990,22 +994,27 @@ void LayoutRematerialization::updateRematMapping(
for (auto [old, newV] : values) {
auto it = mappedValues.find(old);
if (it != mappedValues.end()) {
Attribute encoding = it->second;
auto rematIt = rematMapping.find({old, it->second});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
mappedValues.erase(it);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
SmallVector<Attribute> encodings = it->second;
for (Attribute encoding : encodings) {
auto rematIt = rematMapping.find({old, encoding});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
}
}
rematMapping[{newV, encoding}] = replacedValue;
}
rematMapping[{newV, encoding}] = replacedValue;
mappedValues[newV] = encoding;
mappedValues.erase(it);
if (mappedValues.contains(newV))
mappedValues[newV].append(encodings);
else
mappedValues[newV] = std::move(encodings);
}
}
}
Expand Down Expand Up @@ -1309,8 +1318,9 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}

for (Operation *op : deadOps)
for (Operation *op : deadOps) {
opToDelete.insert(op);
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Expand Down
Loading