Skip to content

Commit d8789db

Browse files
angelomatni1copybara-github
authored andcommitted
Make NodeDependencyAnalysis lazily computed; use in resource sharing pass
PiperOrigin-RevId: 829051198
1 parent 0e0a32e commit d8789db

11 files changed

+335
-703
lines changed

xls/dev_tools/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ cc_library(
207207
"//xls/passes:node_dependency_analysis",
208208
"@com_google_absl//absl/algorithm:container",
209209
"@com_google_absl//absl/container:flat_hash_map",
210+
"@com_google_absl//absl/container:flat_hash_set",
210211
"@com_google_absl//absl/status",
211212
"@com_google_absl//absl/status:statusor",
212213
"@com_google_absl//absl/strings:str_format",

xls/dev_tools/extract_segment.cc

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
#include "xls/dev_tools/extract_segment.h"
1616

1717
#include <memory>
18-
#include <optional>
1918
#include <string_view>
2019
#include <vector>
2120

2221
#include "absl/algorithm/container.h"
2322
#include "absl/container/flat_hash_map.h"
23+
#include "absl/container/flat_hash_set.h"
2424
#include "absl/status/status.h"
2525
#include "absl/status/statusor.h"
2626
#include "absl/strings/str_format.h"
@@ -46,37 +46,27 @@ absl::StatusOr<BValue> ExtractSegmentInto(
4646
absl::flat_hash_map<Node*, Node*>* old_to_new_map,
4747
bool next_nodes_are_tuples) {
4848
// Get node dependency information.
49-
std::optional<NodeDependencyAnalysis> forward_analysis;
50-
std::optional<NodeDependencyAnalysis> backward_analysis;
51-
std::vector<DependencyBitmap> forward_bitmaps;
52-
std::vector<DependencyBitmap> backward_bitmaps;
49+
absl::flat_hash_set<Node*> forward_deps;
50+
absl::flat_hash_set<Node*> backward_deps;
5351
if (!source_nodes.empty()) {
54-
forward_analysis =
55-
NodeDependencyAnalysis::ForwardDependents(full, source_nodes);
52+
NodeBackwardDependencyAnalysis backward_analysis;
5653
for (auto n : source_nodes) {
57-
XLS_ASSIGN_OR_RETURN(auto dep, forward_analysis->GetDependents(n));
58-
forward_bitmaps.push_back(dep);
54+
auto tos = backward_analysis.NodesDependingOn(n);
55+
backward_deps.insert(tos.begin(), tos.end());
5956
}
6057
}
6158
if (!sink_nodes.empty()) {
62-
backward_analysis =
63-
NodeDependencyAnalysis::BackwardDependents(full, sink_nodes);
59+
NodeForwardDependencyAnalysis forward_analysis;
6460
for (auto n : sink_nodes) {
65-
XLS_ASSIGN_OR_RETURN(auto dep, backward_analysis->GetDependents(n));
66-
backward_bitmaps.push_back(dep);
61+
auto froms = forward_analysis.NodesDependedOnBy(n);
62+
forward_deps.insert(froms.begin(), froms.end());
6763
}
6864
}
6965
auto is_forward_dep = [&](Node* n) -> bool {
70-
return forward_bitmaps.empty() ||
71-
absl::c_any_of(forward_bitmaps, [&](DependencyBitmap db) {
72-
return *db.IsDependent(n);
73-
});
66+
return forward_deps.empty() || forward_deps.contains(n);
7467
};
7568
auto is_backward_dep = [&](Node* n) -> bool {
76-
return backward_bitmaps.empty() ||
77-
absl::c_any_of(backward_bitmaps, [&](DependencyBitmap db) {
78-
return *db.IsDependent(n);
79-
});
69+
return backward_deps.empty() || backward_deps.contains(n);
8070
};
8171
auto is_dep = [&](Node* n) -> bool {
8272
return is_forward_dep(n) && is_backward_dep(n);

xls/passes/BUILD

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,7 @@ cc_library(
953953
"@com_google_absl//absl/algorithm:container",
954954
"@com_google_absl//absl/container:btree",
955955
"@com_google_absl//absl/container:flat_hash_map",
956+
"@com_google_absl//absl/container:flat_hash_set",
956957
"@com_google_absl//absl/log:check",
957958
"@com_google_absl//absl/status",
958959
"@com_google_absl//absl/status:statusor",
@@ -1061,6 +1062,7 @@ xls_pass(
10611062
deps = [
10621063
":bdd_query_engine",
10631064
":folding_graph",
1065+
":node_dependency_analysis",
10641066
":optimization_pass",
10651067
":pass_base",
10661068
":post_dominator_analysis",
@@ -1083,8 +1085,6 @@ xls_pass(
10831085
"@com_google_absl//absl/random",
10841086
"@com_google_absl//absl/status:statusor",
10851087
"@com_google_absl//absl/types:span",
1086-
"@com_google_ortools//ortools/graph",
1087-
"@com_google_ortools//ortools/graph:cliques",
10881088
"@cppitertools",
10891089
],
10901090
)
@@ -1354,16 +1354,10 @@ cc_library(
13541354
srcs = ["node_dependency_analysis.cc"],
13551355
hdrs = ["node_dependency_analysis.h"],
13561356
deps = [
1357-
"//xls/common/status:status_macros",
1358-
"//xls/data_structures:inline_bitmap",
1357+
":lazy_node_data",
13591358
"//xls/ir",
1360-
"@com_google_absl//absl/base:core_headers",
1361-
"@com_google_absl//absl/container:flat_hash_map",
13621359
"@com_google_absl//absl/container:flat_hash_set",
1363-
"@com_google_absl//absl/log",
1364-
"@com_google_absl//absl/log:check",
13651360
"@com_google_absl//absl/status",
1366-
"@com_google_absl//absl/status:statusor",
13671361
"@com_google_absl//absl/types:span",
13681362
],
13691363
)
@@ -4342,6 +4336,7 @@ cc_test(
43424336
srcs = ["resource_sharing_pass_test.cc"],
43434337
deps = [
43444338
":bdd_query_engine",
4339+
":node_dependency_analysis",
43454340
":optimization_pass",
43464341
":pass_base",
43474342
":query_engine",

xls/passes/context_sensitive_range_query_engine.cc

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/algorithm/container.h"
2828
#include "absl/container/btree_set.h"
2929
#include "absl/container/flat_hash_map.h"
30+
#include "absl/container/flat_hash_set.h"
3031
#include "absl/log/check.h"
3132
#include "absl/status/status.h"
3233
#include "absl/status/statusor.h"
@@ -178,7 +179,7 @@ struct SelectorAndArm {
178179

179180
struct EquivalenceSet {
180181
std::vector<PredicateState> equivalent_states;
181-
InlineBitmap interesting_nodes;
182+
absl::flat_hash_set<Node*> interesting_nodes;
182183
};
183184

184185
// Helper to perform the actual analysis and hold together all data needed.
@@ -187,8 +188,8 @@ struct EquivalenceSet {
187188
class Analysis {
188189
public:
189190
struct InterestingStatesAndNodeList {
190-
absl::flat_hash_map<Node*, int64_t> node_indices;
191-
std::vector<std::pair<PredicateState, InlineBitmap>> state_and_nodes;
191+
std::vector<std::pair<PredicateState, absl::flat_hash_set<Node*>>>
192+
state_and_nodes;
192193
};
193194
Analysis(
194195
RangeQueryEngine& base_range,
@@ -230,18 +231,11 @@ class Analysis {
230231
absl::flat_hash_map<SelectorAndArm, EquivalenceSet> equivalences;
231232
equivalences.reserve(interesting.state_and_nodes.size());
232233
for (const auto& [state, interesting_nodes] : interesting.state_and_nodes) {
233-
EquivalenceSet& cur =
234-
equivalences
235-
.try_emplace(
236-
SelectorAndArm{
237-
.selector = state.node()->As<Select>()->selector(),
238-
.arm = state.arm()},
239-
EquivalenceSet{
240-
.equivalent_states = {},
241-
.interesting_nodes = InlineBitmap(f->node_count())})
242-
.first->second;
243-
cur.equivalent_states.push_back(state);
244-
cur.interesting_nodes.Union(interesting_nodes);
234+
equivalences.try_emplace(
235+
SelectorAndArm{.selector = state.node()->As<Select>()->selector(),
236+
.arm = state.arm()},
237+
EquivalenceSet{.equivalent_states = {state},
238+
.interesting_nodes = interesting_nodes});
245239
}
246240
// NB We don't care what order we examine each equivalence (since all are
247241
// disjoint).
@@ -253,8 +247,7 @@ class Analysis {
253247
// to be true at a time.
254248
XLS_ASSIGN_OR_RETURN(auto tmp,
255249
CalculateRangeGiven(states.equivalent_states.back(),
256-
states.interesting_nodes,
257-
interesting.node_indices));
250+
states.interesting_nodes));
258251
auto result =
259252
arena_
260253
.emplace_back(std::make_unique<RangeQueryEngine>(std::move(tmp)))
@@ -286,44 +279,45 @@ class Analysis {
286279
absl::c_copy(interesting, std::back_inserter(select_nodes));
287280
selectee_nodes.push_back(s.value());
288281
}
289-
NodeDependencyAnalysis forward_interesting(
290-
NodeDependencyAnalysis::ForwardDependents(f, select_nodes));
291-
NodeDependencyAnalysis backwards_interesting(
292-
NodeDependencyAnalysis::BackwardDependents(f, selectee_nodes));
293-
std::vector<std::pair<PredicateState, InlineBitmap>> interesting_states;
282+
NodeForwardDependencyAnalysis forward_deps;
283+
NodeBackwardDependencyAnalysis backward_deps;
284+
std::vector<std::pair<PredicateState, absl::flat_hash_set<Node*>>>
285+
interesting_states;
294286
interesting_states.reserve(states.size());
295287
for (const PredicateState& ps : states) {
296288
// If there's any node which is both an input into the select value and
297289
// affected by something the conditional specialization can discover we
298290
// consider it interesting.
299-
InlineBitmap forward_bm(f->node_count(), false);
291+
absl::flat_hash_set<Node*> depending_on_selector;
300292
// What nodes do we care about for this specific run. Since this basically
301293
// only depends on the input node no need to memoize it.
302294
XLS_ASSIGN_OR_RETURN(
303295
std::vector<Node*> interesting,
304296
InterestingNodeFinder::Execute(base_range_, ps.selector()));
305-
for (Node* n : interesting) {
306-
XLS_ASSIGN_OR_RETURN(auto deps, forward_interesting.GetDependents(n));
307-
forward_bm.Union(deps.bitmap());
308-
}
309-
XLS_ASSIGN_OR_RETURN(auto backwards_bm,
310-
backwards_interesting.GetDependents(ps.value()));
311297
// Nodes that the selector affects & nodes the selected value is affected
312298
// by is the set of nodes with potentially changed conditional ranges.
313-
InlineBitmap final_bm = forward_bm;
314-
final_bm.Intersect(backwards_bm.bitmap());
315-
if (!final_bm.IsAllZeroes()) {
299+
absl::flat_hash_set<Node*> depended_on_by_value =
300+
forward_deps.NodesDependedOnBy(ps.value());
301+
absl::flat_hash_set<Node*> final_deps;
302+
for (Node* n : interesting) {
303+
auto deps = backward_deps.NodesDependingOn(n);
304+
depending_on_selector.insert(deps.begin(), deps.end());
305+
for (Node* dep : deps) {
306+
if (depended_on_by_value.contains(dep)) {
307+
final_deps.insert(dep);
308+
}
309+
}
310+
}
311+
if (!final_deps.empty()) {
316312
// nodes affected by the known data are the ones we need to recalculate.
317-
interesting_states.push_back({ps, std::move(forward_bm)});
313+
interesting_states.push_back({ps, std::move(depending_on_selector)});
318314
}
319315
}
320-
return InterestingStatesAndNodeList{
321-
.node_indices = backwards_interesting.node_indices(),
322-
.state_and_nodes = interesting_states};
316+
return InterestingStatesAndNodeList{.state_and_nodes = interesting_states};
323317
}
324318
absl::StatusOr<RangeQueryEngine> CalculateRangeGiven(
325-
PredicateState s, const InlineBitmap& interesting_nodes,
326-
const absl::flat_hash_map<Node*, int64_t>& node_ids) const {
319+
PredicateState s,
320+
const absl::flat_hash_set<Node*>& interesting_nodes) const {
327321
RangeQueryEngine result;
328322
XLS_ASSIGN_OR_RETURN(
329323
(std::optional<absl::flat_hash_map<Node*, RangeData>> known_data),
@@ -337,7 +331,7 @@ class Analysis {
337331
ContextGivens givens(
338332
topo_sort_, s.node(), *known_data,
339333
[&](Node* n) -> std::optional<RangeData> {
340-
if (interesting_nodes.Get(node_ids.at(n))) {
334+
if (interesting_nodes.contains(n)) {
341335
// Affected by known data.
342336
return std::nullopt;
343337
}

0 commit comments

Comments
 (0)