Skip to content

Commit 0d6ba6d

Browse files
committed
tsort: use iterative dfs to avoid crashes on stack overflow
1 parent ac7b145 commit 0d6ba6d

File tree

2 files changed

+75
-24
lines changed

2 files changed

+75
-24
lines changed

src/uu/tsort/src/tsort.rs

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
// file that was distributed with this source code.
55
//spell-checker:ignore TAOCP indegree
66
use clap::{Arg, Command};
7-
use std::collections::{HashMap, HashSet, VecDeque};
7+
use std::collections::hash_map::Entry;
8+
use std::collections::{HashMap, VecDeque};
89
use std::ffi::OsString;
910
use std::path::Path;
1011
use thiserror::Error;
@@ -34,13 +35,15 @@ enum TsortError {
3435
/// The graph contains a cycle.
3536
#[error("{input}: {message}", input = .0, message = translate!("tsort-error-loop"))]
3637
Loop(String),
37-
38-
/// A particular node in a cycle. (This is mainly used for printing.)
39-
#[error("{0}")]
40-
LoopNode(String),
4138
}
4239

40+
// Auxiliary struct, just for printing loop nodes via show! macro
41+
#[derive(Debug, Error)]
42+
#[error("{0}")]
43+
struct LoopNode<'a>(&'a str);
44+
4345
impl UError for TsortError {}
46+
impl UError for LoopNode<'_> {}
4447

4548
#[uucore::main]
4649
pub fn uumain(args: impl uucore::Args) -> UResult<()> {
@@ -131,6 +134,13 @@ struct Graph<'input> {
131134
nodes: HashMap<&'input str, Node<'input>>,
132135
}
133136

137+
#[derive(Clone, Copy, PartialEq, Eq)]
138+
enum VisitedState {
139+
JustAdded,
140+
Opened,
141+
Closed,
142+
}
143+
134144
impl<'input> Graph<'input> {
135145
fn new(name: String) -> Graph<'input> {
136146
Self {
@@ -224,8 +234,8 @@ impl<'input> Graph<'input> {
224234
fn find_and_break_cycle(&mut self, frontier: &mut VecDeque<&'input str>) {
225235
let cycle = self.detect_cycle();
226236
show!(TsortError::Loop(self.name.clone()));
227-
for node in &cycle {
228-
show!(TsortError::LoopNode((*node).to_string()));
237+
for &node in &cycle {
238+
show!(LoopNode(node));
229239
}
230240
let u = cycle[0];
231241
let v = cycle[1];
@@ -240,41 +250,68 @@ impl<'input> Graph<'input> {
240250
let mut nodes: Vec<_> = self.nodes.keys().collect();
241251
nodes.sort_unstable();
242252

243-
let mut visited = HashSet::new();
253+
let mut visited = HashMap::new();
244254
let mut stack = Vec::with_capacity(self.nodes.len());
245255
for node in nodes {
246-
if !visited.contains(node) && self.dfs(node, &mut visited, &mut stack) {
256+
if self.dfs(node, &mut visited, &mut stack) {
247257
return stack;
248258
}
249259
}
250-
unreachable!();
260+
unreachable!("detect_cycle is expected to be called only on graphs with cycles");
251261
}
252262

253263
fn dfs(
254264
&self,
255265
node: &'input str,
256-
visited: &mut HashSet<&'input str>,
266+
visited: &mut HashMap<&'input str, VisitedState>,
257267
stack: &mut Vec<&'input str>,
258268
) -> bool {
259-
if stack.contains(&node) {
260-
return true;
261-
}
262-
if visited.contains(&node) {
263-
return false;
264-
}
265-
266-
visited.insert(node);
267269
stack.push(node);
268270

269-
if let Some(successor_names) = self.nodes.get(node).map(|n| &n.successor_names) {
270-
for &successor in successor_names {
271-
if self.dfs(successor, visited, stack) {
272-
return true;
271+
while let Some(node) = stack.pop() {
272+
let state = visited.entry(node).or_insert(VisitedState::JustAdded);
273+
274+
match *state {
275+
VisitedState::Closed => continue,
276+
VisitedState::Opened => {
277+
// close the node if it was opened
278+
*state = VisitedState::Closed;
279+
continue;
280+
}
281+
VisitedState::JustAdded => {
282+
// open the node and push it back on the stack, both to track it and
283+
// to return to it once all of its successors have been processed
284+
*state = VisitedState::Opened;
285+
stack.push(node);
286+
}
287+
}
288+
289+
let Some(successors) = self.nodes.get(node).map(|n| &n.successor_names) else {
290+
continue;
291+
};
292+
293+
// to preserve the same output as the recursive version,
294+
// iterate over successors in reverse order -- all of them are going to be
295+
// pushed onto the stack and popped back in the reverse of the push order, i.e. in their original order
296+
for &successor in successors.iter().rev() {
297+
match visited.entry(successor) {
298+
Entry::Vacant(v) => {
299+
v.insert(VisitedState::JustAdded);
300+
stack.push(successor);
301+
}
302+
Entry::Occupied(o) => {
303+
if *o.get() == VisitedState::Opened {
304+
// we are re-entering a node that is still Opened --
305+
// we found a loop, and the stack contains this loop
306+
// in the form of Opened nodes. Drop all other nodes.
307+
stack.retain(|&n| visited.get(n) == Some(&VisitedState::Opened));
308+
return true;
309+
}
310+
}
273311
}
274312
}
275313
}
276314

277-
stack.pop();
278315
false
279316
}
280317
}

tests/by-util/test_tsort.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,17 @@ fn test_two_cycles() {
122122
.stdout_is("a\nc\nd\nb\n")
123123
.stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\ntsort: -: input contains a loop:\ntsort: b\ntsort: d\n");
124124
}
125+
126+
#[test]
127+
fn test_long_loop_no_stack_overflow() {
128+
let mut input = String::new();
129+
const N: usize = 100000;
130+
for v in 0..N {
131+
let next = (v + 1) % N;
132+
input.push_str(&format!("{v} {next} "));
133+
}
134+
new_ucmd!()
135+
.pipe_in(input)
136+
.fails_with_code(1)
137+
.stderr_contains("tsort: -: input contains a loop");
138+
}

0 commit comments

Comments
 (0)