Skip to content

Commit a9c5528

Browse files
committed
tsort: use iterative dfs to avoid crashes on stack overflow
1 parent 2db9f0b commit a9c5528

File tree

2 files changed

+90
-27
lines changed

2 files changed

+90
-27
lines changed

src/uu/tsort/src/tsort.rs

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
// file that was distributed with this source code.
55
//spell-checker:ignore TAOCP indegree
66
use clap::{Arg, Command};
7-
use std::collections::{HashMap, HashSet, VecDeque};
7+
use std::collections::hash_map::Entry;
8+
use std::collections::{HashMap, VecDeque};
89
use std::ffi::OsString;
910
use std::path::Path;
1011
use thiserror::Error;
@@ -34,13 +35,15 @@ enum TsortError {
3435
/// The graph contains a cycle.
3536
#[error("{input}: {message}", input = .0, message = translate!("tsort-error-loop"))]
3637
Loop(String),
37-
38-
/// A particular node in a cycle. (This is mainly used for printing.)
39-
#[error("{0}")]
40-
LoopNode(String),
4138
}
4239

40+
// Auxiliary struct, used only for printing loop nodes via the `show!` macro
41+
#[derive(Debug, Error)]
42+
#[error("{0}")]
43+
struct LoopNode<'a>(&'a str);
44+
4345
impl UError for TsortError {}
46+
impl UError for LoopNode<'_> {}
4447

4548
#[uucore::main]
4649
pub fn uumain(args: impl uucore::Args) -> UResult<()> {
@@ -131,6 +134,12 @@ struct Graph<'input> {
131134
nodes: HashMap<&'input str, Node<'input>>,
132135
}
133136

137+
#[derive(Clone, Copy, PartialEq, Eq)]
138+
enum VisitedState {
139+
Opened,
140+
Closed,
141+
}
142+
134143
impl<'input> Graph<'input> {
135144
fn new(name: String) -> Graph<'input> {
136145
Self {
@@ -224,8 +233,8 @@ impl<'input> Graph<'input> {
224233
fn find_and_break_cycle(&mut self, frontier: &mut VecDeque<&'input str>) {
225234
let cycle = self.detect_cycle();
226235
show!(TsortError::Loop(self.name.clone()));
227-
for node in &cycle {
228-
show!(TsortError::LoopNode((*node).to_string()));
236+
for &node in &cycle {
237+
show!(LoopNode(node));
229238
}
230239
let u = cycle[0];
231240
let v = cycle[1];
@@ -240,41 +249,63 @@ impl<'input> Graph<'input> {
240249
let mut nodes: Vec<_> = self.nodes.keys().collect();
241250
nodes.sort_unstable();
242251

243-
let mut visited = HashSet::new();
252+
let mut visited = HashMap::new();
244253
let mut stack = Vec::with_capacity(self.nodes.len());
245254
for node in nodes {
246-
if !visited.contains(node) && self.dfs(node, &mut visited, &mut stack) {
247-
return stack;
255+
if self.dfs(node, &mut visited, &mut stack) {
256+
return stack.into_iter().map(|(node, _)| node).collect();
248257
}
249258
}
250-
unreachable!();
259+
unreachable!("detect_cycle is expected to be called only on graphs with cycles");
251260
}
252261

253-
fn dfs(
254-
&self,
262+
fn dfs<'a>(
263+
&'a self,
255264
node: &'input str,
256-
visited: &mut HashSet<&'input str>,
257-
stack: &mut Vec<&'input str>,
265+
visited: &mut HashMap<&'input str, VisitedState>,
266+
stack: &mut Vec<(&'input str, &'a [&'input str])>,
258267
) -> bool {
259-
if stack.contains(&node) {
260-
return true;
261-
}
262-
if visited.contains(&node) {
268+
stack.push((
269+
node,
270+
self.nodes.get(node).map_or(&[], |n| &n.successor_names),
271+
));
272+
let state = *visited.entry(node).or_insert(VisitedState::Opened);
273+
274+
if state == VisitedState::Closed {
263275
return false;
264276
}
265277

266-
visited.insert(node);
267-
stack.push(node);
268-
269-
if let Some(successor_names) = self.nodes.get(node).map(|n| &n.successor_names) {
270-
for &successor in successor_names {
271-
if self.dfs(successor, visited, stack) {
272-
return true;
278+
while let Some((node, pending_successors)) = stack.pop() {
279+
let Some((&next_node, pending)) = pending_successors.split_first() else {
280+
// no more pending successors in the list -> close the node
281+
visited.insert(node, VisitedState::Closed);
282+
continue;
283+
};
284+
285+
// schedule processing of the remaining successors of this node
286+
stack.push((node, pending));
287+
288+
match visited.entry(next_node) {
289+
Entry::Vacant(v) => {
290+
// This is the first time we enter this node
291+
v.insert(VisitedState::Opened);
292+
stack.push((
293+
next_node,
294+
self.nodes
295+
.get(next_node)
296+
.map_or(&[], |n| &n.successor_names),
297+
));
298+
}
299+
Entry::Occupied(o) => {
300+
if *o.get() == VisitedState::Opened {
301+
// we are entering the same opened node again -> loop found
302+
// stack contains it
303+
return true;
304+
}
273305
}
274306
}
275307
}
276308

277-
stack.pop();
278309
false
279310
}
280311
}

tests/by-util/test_tsort.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,35 @@ fn test_two_cycles() {
122122
.stdout_is("a\nc\nd\nb\n")
123123
.stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\ntsort: -: input contains a loop:\ntsort: b\ntsort: d\n");
124124
}
125+
126+
#[test]
127+
fn test_long_loop_no_stack_overflow() {
128+
use std::fmt::Write;
129+
const N: usize = 100_000;
130+
let mut input = String::new();
131+
for v in 0..N {
132+
let next = (v + 1) % N;
133+
let _ = write!(input, "{v} {next} ");
134+
}
135+
new_ucmd!()
136+
.pipe_in(input)
137+
.fails_with_code(1)
138+
.stderr_contains("tsort: -: input contains a loop");
139+
}
140+
141+
#[test]
142+
fn test_loop_for_iterative_dfs_correctness() {
143+
let input = r#"
144+
A B
145+
A C
146+
B D
147+
D C
148+
C D
149+
"#;
150+
// If dfs iteration marks nodes as visited too early (node C), node D will be marked as
151+
// closed, and cycle detection will fail
152+
new_ucmd!()
153+
.pipe_in(input)
154+
.fails_with_code(1)
155+
.stderr_contains("tsort: -: input contains a loop");
156+
}

0 commit comments

Comments
 (0)