2727#include " absl/algorithm/container.h"
2828#include " absl/container/btree_set.h"
2929#include " absl/container/flat_hash_map.h"
30+ #include " absl/container/flat_hash_set.h"
3031#include " absl/log/check.h"
3132#include " absl/status/status.h"
3233#include " absl/status/statusor.h"
@@ -178,7 +179,7 @@ struct SelectorAndArm {
178179
179180struct EquivalenceSet {
180181 std::vector<PredicateState> equivalent_states;
181- InlineBitmap interesting_nodes;
182+ absl::flat_hash_set<Node*> interesting_nodes;
182183};
183184
184185// Helper to perform the actual analysis and hold together all data needed.
@@ -187,8 +188,8 @@ struct EquivalenceSet {
187188class Analysis {
188189 public:
189190 struct InterestingStatesAndNodeList {
190- absl::flat_hash_map <Node*, int64_t > node_indices;
191- std::vector<std::pair<PredicateState, InlineBitmap>> state_and_nodes;
191+ std::vector<std::pair<PredicateState, absl::flat_hash_set <Node*>>>
192+ state_and_nodes;
192193 };
193194 Analysis (
194195 RangeQueryEngine& base_range,
@@ -230,18 +231,11 @@ class Analysis {
230231 absl::flat_hash_map<SelectorAndArm, EquivalenceSet> equivalences;
231232 equivalences.reserve (interesting.state_and_nodes .size ());
232233 for (const auto & [state, interesting_nodes] : interesting.state_and_nodes ) {
233- EquivalenceSet& cur =
234- equivalences
235- .try_emplace (
236- SelectorAndArm{
237- .selector = state.node ()->As <Select>()->selector (),
238- .arm = state.arm ()},
239- EquivalenceSet{
240- .equivalent_states = {},
241- .interesting_nodes = InlineBitmap (f->node_count ())})
242- .first ->second ;
243- cur.equivalent_states .push_back (state);
244- cur.interesting_nodes .Union (interesting_nodes);
234+ equivalences.try_emplace (
235+ SelectorAndArm{.selector = state.node ()->As <Select>()->selector (),
236+ .arm = state.arm ()},
237+ EquivalenceSet{.equivalent_states = {state},
238+ .interesting_nodes = interesting_nodes});
245239 }
246240 // NB We don't care what order we examine each equivalence (since all are
247241 // disjoint).
@@ -253,8 +247,7 @@ class Analysis {
253247 // to be true at a time.
254248 XLS_ASSIGN_OR_RETURN (auto tmp,
255249 CalculateRangeGiven (states.equivalent_states .back (),
256- states.interesting_nodes ,
257- interesting.node_indices ));
250+ states.interesting_nodes ));
258251 auto result =
259252 arena_
260253 .emplace_back (std::make_unique<RangeQueryEngine>(std::move (tmp)))
@@ -286,44 +279,45 @@ class Analysis {
286279 absl::c_copy (interesting, std::back_inserter (select_nodes));
287280 selectee_nodes.push_back (s.value ());
288281 }
289- NodeDependencyAnalysis forward_interesting (
290- NodeDependencyAnalysis::ForwardDependents (f, select_nodes));
291- NodeDependencyAnalysis backwards_interesting (
292- NodeDependencyAnalysis::BackwardDependents (f, selectee_nodes));
293- std::vector<std::pair<PredicateState, InlineBitmap>> interesting_states;
282+ NodeForwardDependencyAnalysis forward_deps;
283+ NodeBackwardDependencyAnalysis backward_deps;
284+ std::vector<std::pair<PredicateState, absl::flat_hash_set<Node*>>>
285+ interesting_states;
294286 interesting_states.reserve (states.size ());
295287 for (const PredicateState& ps : states) {
296288 // If there's any node which is both an input into the select value and
297289 // affected by something the conditional specialization can discover we
298290 // consider it interesting.
299- InlineBitmap forward_bm (f-> node_count (), false ) ;
291+ absl::flat_hash_set<Node*> depending_on_selector ;
300292 // What nodes do we care about for this specific run. Since this basically
301293 // only depends on the input node no need to memoize it.
302294 XLS_ASSIGN_OR_RETURN (
303295 std::vector<Node*> interesting,
304296 InterestingNodeFinder::Execute (base_range_, ps.selector ()));
305- for (Node* n : interesting) {
306- XLS_ASSIGN_OR_RETURN (auto deps, forward_interesting.GetDependents (n));
307- forward_bm.Union (deps.bitmap ());
308- }
309- XLS_ASSIGN_OR_RETURN (auto backwards_bm,
310- backwards_interesting.GetDependents (ps.value ()));
311297 // Nodes that the selector affects & nodes the selected value is affected
312298 // by is the set of nodes with potentially changed conditional ranges.
313- InlineBitmap final_bm = forward_bm;
314- final_bm.Intersect (backwards_bm.bitmap ());
315- if (!final_bm.IsAllZeroes ()) {
299+ absl::flat_hash_set<Node*> depended_on_by_value =
300+ forward_deps.NodesDependedOnBy (ps.value ());
301+ absl::flat_hash_set<Node*> final_deps;
302+ for (Node* n : interesting) {
303+ auto deps = backward_deps.NodesDependingOn (n);
304+ depending_on_selector.insert (deps.begin (), deps.end ());
305+ for (Node* dep : deps) {
306+ if (depended_on_by_value.contains (dep)) {
307+ final_deps.insert (dep);
308+ }
309+ }
310+ }
311+ if (!final_deps.empty ()) {
316312 // nodes affected by the known data are the ones we need to recalculate.
317- interesting_states.push_back ({ps, std::move (forward_bm )});
313+ interesting_states.push_back ({ps, std::move (depending_on_selector )});
318314 }
319315 }
320- return InterestingStatesAndNodeList{
321- .node_indices = backwards_interesting.node_indices (),
322- .state_and_nodes = interesting_states};
316+ return InterestingStatesAndNodeList{.state_and_nodes = interesting_states};
323317 }
324318 absl::StatusOr<RangeQueryEngine> CalculateRangeGiven (
325- PredicateState s, const InlineBitmap& interesting_nodes,
326- const absl::flat_hash_map <Node*, int64_t >& node_ids ) const {
319+ PredicateState s,
320+ const absl::flat_hash_set <Node*>& interesting_nodes ) const {
327321 RangeQueryEngine result;
328322 XLS_ASSIGN_OR_RETURN (
329323 (std::optional<absl::flat_hash_map<Node*, RangeData>> known_data),
@@ -337,7 +331,7 @@ class Analysis {
337331 ContextGivens givens (
338332 topo_sort_, s.node (), *known_data,
339333 [&](Node* n) -> std::optional<RangeData> {
340- if (interesting_nodes.Get (node_ids. at (n) )) {
334+ if (interesting_nodes.contains (n )) {
341335 // Affected by known data.
342336 return std::nullopt ;
343337 }
0 commit comments