4
4
// file that was distributed with this source code.
5
5
//spell-checker:ignore TAOCP indegree
6
6
use clap:: { Arg , Command } ;
7
- use std:: collections:: { HashMap , HashSet , VecDeque } ;
7
+ use std:: collections:: hash_map:: Entry ;
8
+ use std:: collections:: { HashMap , VecDeque } ;
8
9
use std:: ffi:: OsString ;
9
10
use std:: path:: Path ;
10
11
use thiserror:: Error ;
@@ -34,13 +35,15 @@ enum TsortError {
34
35
/// The graph contains a cycle.
35
36
#[ error( "{input}: {message}" , input = . 0 , message = translate!( "tsort-error-loop" ) ) ]
36
37
Loop ( String ) ,
37
-
38
- /// A particular node in a cycle. (This is mainly used for printing.)
39
- #[ error( "{0}" ) ]
40
- LoopNode ( String ) ,
41
38
}
42
39
40
+ // Auxiliary struct, just for printing loop nodes via show! macro
41
+ #[ derive( Debug , Error ) ]
42
+ #[ error( "{0}" ) ]
43
+ struct LoopNode < ' a > ( & ' a str ) ;
44
+
43
45
impl UError for TsortError { }
46
+ impl UError for LoopNode < ' _ > { }
44
47
45
48
#[ uucore:: main]
46
49
pub fn uumain ( args : impl uucore:: Args ) -> UResult < ( ) > {
@@ -131,6 +134,13 @@ struct Graph<'input> {
131
134
nodes : HashMap < & ' input str , Node < ' input > > ,
132
135
}
133
136
137
+ #[ derive( Clone , Copy , PartialEq , Eq ) ]
138
+ enum VisitedState {
139
+ JustAdded ,
140
+ Opened ,
141
+ Closed ,
142
+ }
143
+
134
144
impl < ' input > Graph < ' input > {
135
145
fn new ( name : String ) -> Graph < ' input > {
136
146
Self {
@@ -224,8 +234,8 @@ impl<'input> Graph<'input> {
224
234
fn find_and_break_cycle ( & mut self , frontier : & mut VecDeque < & ' input str > ) {
225
235
let cycle = self . detect_cycle ( ) ;
226
236
show ! ( TsortError :: Loop ( self . name. clone( ) ) ) ;
227
- for node in & cycle {
228
- show ! ( TsortError :: LoopNode ( ( * node) . to_string ( ) ) ) ;
237
+ for & node in & cycle {
238
+ show ! ( LoopNode ( node) ) ;
229
239
}
230
240
let u = cycle[ 0 ] ;
231
241
let v = cycle[ 1 ] ;
@@ -240,41 +250,68 @@ impl<'input> Graph<'input> {
240
250
let mut nodes: Vec < _ > = self . nodes . keys ( ) . collect ( ) ;
241
251
nodes. sort_unstable ( ) ;
242
252
243
- let mut visited = HashSet :: new ( ) ;
253
+ let mut visited = HashMap :: new ( ) ;
244
254
let mut stack = Vec :: with_capacity ( self . nodes . len ( ) ) ;
245
255
for node in nodes {
246
- if !visited . contains ( node ) && self . dfs ( node, & mut visited, & mut stack) {
256
+ if self . dfs ( node, & mut visited, & mut stack) {
247
257
return stack;
248
258
}
249
259
}
250
- unreachable ! ( ) ;
260
+ unreachable ! ( "detect_cycle is expected to be called only on graphs with cycles" ) ;
251
261
}
252
262
253
263
fn dfs (
254
264
& self ,
255
265
node : & ' input str ,
256
- visited : & mut HashSet < & ' input str > ,
266
+ visited : & mut HashMap < & ' input str , VisitedState > ,
257
267
stack : & mut Vec < & ' input str > ,
258
268
) -> bool {
259
- if stack. contains ( & node) {
260
- return true ;
261
- }
262
- if visited. contains ( & node) {
263
- return false ;
264
- }
265
-
266
- visited. insert ( node) ;
267
269
stack. push ( node) ;
268
270
269
- if let Some ( successor_names) = self . nodes . get ( node) . map ( |n| & n. successor_names ) {
270
- for & successor in successor_names {
271
- if self . dfs ( successor, visited, stack) {
272
- return true ;
271
+ while let Some ( node) = stack. pop ( ) {
272
+ let state = visited. entry ( node) . or_insert ( VisitedState :: JustAdded ) ;
273
+
274
+ match * state {
275
+ VisitedState :: Closed => continue ,
276
+ VisitedState :: Opened => {
277
+ // close the node if it was opened
278
+ * state = VisitedState :: Closed ;
279
+ continue ;
280
+ }
281
+ VisitedState :: JustAdded => {
282
+ // open the node, put it on the stack for tracking & to return
283
+ // to this node at the end of processing all its successors
284
+ * state = VisitedState :: Opened ;
285
+ stack. push ( node) ;
286
+ }
287
+ }
288
+
289
+ let Some ( successors) = self . nodes . get ( node) . map ( |n| & n. successor_names ) else {
290
+ continue ;
291
+ } ;
292
+
293
+ // to preserve same output as for recursive version:
294
+ // iterate over successors in reverse order -- because all of them are going to be
295
+ // inserted on the stack, and popped back in reverse order
296
+ for & successor in successors. iter ( ) . rev ( ) {
297
+ match visited. entry ( successor) {
298
+ Entry :: Vacant ( v) => {
299
+ v. insert ( VisitedState :: JustAdded ) ;
300
+ stack. push ( successor) ;
301
+ }
302
+ Entry :: Occupied ( o) => {
303
+ if * o. get ( ) == VisitedState :: Opened {
304
+ // we are entering the same node --
305
+ // we found loop & stack contains this loop
306
+ // in form of Opened nodes. Drop all other nodes
307
+ stack. retain ( |& n| visited. get ( n) == Some ( & VisitedState :: Opened ) ) ;
308
+ return true ;
309
+ }
310
+ }
273
311
}
274
312
}
275
313
}
276
314
277
- stack. pop ( ) ;
278
315
false
279
316
}
280
317
}
0 commit comments