1
1
use std:: io:: Cursor ;
2
+ use std:: io:: Read ;
3
+ use std:: io:: Write ;
2
4
use std:: io:: { self } ;
5
+ use std:: net:: SocketAddr ;
6
+ use std:: net:: TcpStream ;
3
7
use std:: path:: Path ;
4
8
use std:: path:: PathBuf ;
5
9
use std:: sync:: Arc ;
6
10
use std:: thread;
11
+ use std:: time:: Duration ;
7
12
8
13
use crate :: pkce:: PkceCodes ;
9
14
use crate :: pkce:: generate_pkce;
@@ -85,7 +90,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
85
90
let pkce = generate_pkce ( ) ;
86
91
let state = opts. force_state . clone ( ) . unwrap_or_else ( generate_state) ;
87
92
88
- let server = Server :: http ( format ! ( "127.0.0.1:{}" , opts. port) ) . map_err ( io :: Error :: other ) ?;
93
+ let server = bind_server ( opts. port ) ?;
89
94
let actual_port = match server. server_addr ( ) . to_ip ( ) {
90
95
Some ( addr) => addr. port ( ) ,
91
96
None => {
@@ -145,19 +150,24 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
145
150
let response =
146
151
process_request( & url_raw, & opts, & redirect_uri, & pkce, actual_port, & state) . await ;
147
152
148
- let is_login_complete = matches!( response, HandledRequest :: ResponseAndExit ( _) ) ;
149
- match response {
150
- HandledRequest :: Response ( r) | HandledRequest :: ResponseAndExit ( r) => {
151
- let _ = tokio:: task:: spawn_blocking( move || req. respond( r) ) . await ;
153
+ let exit_result = match response {
154
+ HandledRequest :: Response ( response) => {
155
+ let _ = tokio:: task:: spawn_blocking( move || req. respond( response) ) . await ;
156
+ None
157
+ }
158
+ HandledRequest :: ResponseAndExit { response, result } => {
159
+ let _ = tokio:: task:: spawn_blocking( move || req. respond( response) ) . await ;
160
+ Some ( result)
152
161
}
153
162
HandledRequest :: RedirectWithHeader ( header) => {
154
163
let redirect = Response :: empty( 302 ) . with_header( header) ;
155
164
let _ = tokio:: task:: spawn_blocking( move || req. respond( redirect) ) . await ;
165
+ None
156
166
}
157
- }
167
+ } ;
158
168
159
- if is_login_complete {
160
- break Ok ( ( ) ) ;
169
+ if let Some ( result ) = exit_result {
170
+ break result ;
161
171
}
162
172
}
163
173
}
@@ -181,7 +191,10 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
181
191
enum HandledRequest {
182
192
Response ( Response < Cursor < Vec < u8 > > > ) ,
183
193
RedirectWithHeader ( Header ) ,
184
- ResponseAndExit ( Response < Cursor < Vec < u8 > > > ) ,
194
+ ResponseAndExit {
195
+ response : Response < Cursor < Vec < u8 > > > ,
196
+ result : io:: Result < ( ) > ,
197
+ } ,
185
198
}
186
199
187
200
async fn process_request (
@@ -276,8 +289,18 @@ async fn process_request(
276
289
) {
277
290
resp. add_header ( h) ;
278
291
}
279
- HandledRequest :: ResponseAndExit ( resp)
292
+ HandledRequest :: ResponseAndExit {
293
+ response : resp,
294
+ result : Ok ( ( ) ) ,
295
+ }
280
296
}
297
+ "/cancel" => HandledRequest :: ResponseAndExit {
298
+ response : Response :: from_string ( "Login cancelled" ) ,
299
+ result : Err ( io:: Error :: new (
300
+ io:: ErrorKind :: Interrupted ,
301
+ "Login cancelled" ,
302
+ ) ) ,
303
+ } ,
281
304
_ => HandledRequest :: Response ( Response :: from_string ( "Not Found" ) . with_status_code ( 404 ) ) ,
282
305
}
283
306
}
@@ -316,6 +339,68 @@ fn generate_state() -> String {
316
339
base64:: engine:: general_purpose:: URL_SAFE_NO_PAD . encode ( bytes)
317
340
}
318
341
342
+ fn send_cancel_request ( port : u16 ) -> io:: Result < ( ) > {
343
+ let addr: SocketAddr = format ! ( "127.0.0.1:{port}" )
344
+ . parse ( )
345
+ . map_err ( |err| io:: Error :: new ( io:: ErrorKind :: InvalidInput , err) ) ?;
346
+ let mut stream = TcpStream :: connect_timeout ( & addr, Duration :: from_secs ( 2 ) ) ?;
347
+ stream. set_read_timeout ( Some ( Duration :: from_secs ( 2 ) ) ) ?;
348
+ stream. set_write_timeout ( Some ( Duration :: from_secs ( 2 ) ) ) ?;
349
+
350
+ stream. write_all ( b"GET /cancel HTTP/1.1\r \n " ) ?;
351
+ stream. write_all ( format ! ( "Host: 127.0.0.1:{port}\r \n " ) . as_bytes ( ) ) ?;
352
+ stream. write_all ( b"Connection: close\r \n \r \n " ) ?;
353
+
354
+ let mut buf = [ 0u8 ; 64 ] ;
355
+ let _ = stream. read ( & mut buf) ;
356
+ Ok ( ( ) )
357
+ }
358
+
359
+ fn bind_server ( port : u16 ) -> io:: Result < Server > {
360
+ let bind_address = format ! ( "127.0.0.1:{port}" ) ;
361
+ let mut cancel_attempted = false ;
362
+ let mut attempts = 0 ;
363
+ const MAX_ATTEMPTS : u32 = 10 ;
364
+ const RETRY_DELAY : Duration = Duration :: from_millis ( 200 ) ;
365
+
366
+ loop {
367
+ match Server :: http ( & bind_address) {
368
+ Ok ( server) => return Ok ( server) ,
369
+ Err ( err) => {
370
+ attempts += 1 ;
371
+ let is_addr_in_use = err
372
+ . downcast_ref :: < io:: Error > ( )
373
+ . map ( |io_err| io_err. kind ( ) == io:: ErrorKind :: AddrInUse )
374
+ . unwrap_or ( false ) ;
375
+
376
+ // If the address is in use, there is probably another instance of the login server
377
+ // running. Attempt to cancel it and retry.
378
+ if is_addr_in_use {
379
+ if !cancel_attempted {
380
+ cancel_attempted = true ;
381
+ if let Err ( cancel_err) = send_cancel_request ( port) {
382
+ eprintln ! ( "Failed to cancel previous login server: {cancel_err}" ) ;
383
+ }
384
+ }
385
+
386
+ thread:: sleep ( RETRY_DELAY ) ;
387
+
388
+ if attempts >= MAX_ATTEMPTS {
389
+ return Err ( io:: Error :: new (
390
+ io:: ErrorKind :: AddrInUse ,
391
+ format ! ( "Port {bind_address} is already in use" ) ,
392
+ ) ) ;
393
+ }
394
+
395
+ continue ;
396
+ }
397
+
398
+ return Err ( io:: Error :: other ( err) ) ;
399
+ }
400
+ }
401
+ }
402
+ }
403
+
319
404
struct ExchangedTokens {
320
405
id_token : String ,
321
406
access_token : String ,
0 commit comments