Skip to content

Commit 42bb785

Browse files
allight authored and copybara-github committed
Optimize node dependency intermediates.
We aggressively prune the intermediate values as we go. This leads to less memory usage.

PiperOrigin-RevId: 811977383
1 parent 10a0e57 commit 42bb785

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

xls/passes/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,8 @@ cc_library(
13421342
"@com_google_absl//absl/base:core_headers",
13431343
"@com_google_absl//absl/container:flat_hash_map",
13441344
"@com_google_absl//absl/container:flat_hash_set",
1345+
"@com_google_absl//absl/log",
1346+
"@com_google_absl//absl/log:check",
13451347
"@com_google_absl//absl/status",
13461348
"@com_google_absl//absl/status:statusor",
13471349
"@com_google_absl//absl/types:span",

xls/passes/node_dependency_analysis.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include "absl/container/flat_hash_map.h"
2222
#include "absl/container/flat_hash_set.h"
23+
#include "absl/log/check.h"
24+
#include "absl/log/log.h"
2325
#include "absl/status/status.h"
2426
#include "absl/status/statusor.h"
2527
#include "absl/types/span.h"
@@ -35,14 +37,14 @@ namespace {
3537
// Perform actual analysis.
3638
// f is the function to analyze. We only care about getting results for
3739
// 'interesting_nodes' if the set is non-empty (otherwise all nodes are
38-
// searched). Preds returns the nodes the argument depends on. iter is the
40+
// searched). Succs returns the nodes that depend on the argument. topo_sort is
41
// the order in which to walk the function; each node is expected to precede its succs.
40-
template <typename Predecessors>
42+
template <typename Successors>
4143
std::tuple<absl::flat_hash_map<Node*, InlineBitmap>,
4244
absl::flat_hash_map<Node*, int64_t>>
4345
AnalyzeDependents(FunctionBase* f,
4446
const absl::flat_hash_set<Node*>& interesting_nodes,
45-
Predecessors preds, absl::Span<Node* const> topo_sort) {
47+
Successors succs, absl::Span<Node* const> topo_sort) {
4648
absl::flat_hash_map<Node*, int64_t> node_ids;
4749
node_ids.reserve(f->node_count());
4850
int64_t cnt = 0;
@@ -62,16 +64,24 @@ AnalyzeDependents(FunctionBase* f,
6264
}
6365
return seen_interesting_nodes_count == interesting_nodes.size();
6466
};
67+
VLOG(3) << "Analyzing dependents of " << f->node_count() << " nodes with "
68+
<< interesting_nodes.size() << " interesting.";
6569
int64_t bitmap_size = f->node_count();
6670
absl::flat_hash_map<Node*, InlineBitmap> results;
6771
results.reserve(f->node_count());
6872
for (Node* n : topo_sort) {
69-
InlineBitmap& bm = results.emplace(n, bitmap_size).first->second;
73+
auto [it, inserted] = results.try_emplace(n, bitmap_size);
74+
InlineBitmap& bm = it->second;
7075
bm.Set(node_ids[n]);
71-
for (Node* pred : preds(n)) {
72-
bm.Union(results.at(pred));
76+
for (Node* succ : succs(n)) {
77+
auto [s_it, s_new] = results.try_emplace(succ, bm);
78+
if (!s_new) {
79+
s_it->second.Union(bm);
80+
}
7381
}
74-
if (is_last_interesting_node(n)) {
82+
if (!is_interesting(n)) {
83+
results.erase(n);
84+
} else if (is_last_interesting_node(n)) {
7585
break;
7686
}
7787
}
@@ -95,7 +105,7 @@ NodeDependencyAnalysis NodeDependencyAnalysis::BackwardDependents(
95105
FunctionBase* fb, absl::Span<Node* const> nodes) {
96106
absl::flat_hash_set<Node*> interesting(nodes.begin(), nodes.end());
97107
auto [dependents, node_ids] = AnalyzeDependents(
98-
fb, interesting, [](Node* node) { return node->operands(); },
108+
fb, interesting, /*succs=*/[](Node* node) { return node->users(); },
99109
TopoSort(fb));
100110
return NodeDependencyAnalysis(/*is_forwards=*/false, dependents, node_ids);
101111
}
@@ -104,7 +114,7 @@ NodeDependencyAnalysis NodeDependencyAnalysis::ForwardDependents(
104114
FunctionBase* fb, absl::Span<Node* const> nodes) {
105115
absl::flat_hash_set<Node*> interesting(nodes.begin(), nodes.end());
106116
auto [dependents, node_ids] = AnalyzeDependents(
107-
fb, interesting, [](Node* node) { return node->users(); },
117+
fb, interesting, /*succs=*/[](Node* node) { return node->operands(); },
108118
ReverseTopoSort(fb));
109119
return NodeDependencyAnalysis(/*is_forwards=*/true, dependents, node_ids);
110120
}

0 commit comments

Comments
 (0)