@@ -262,6 +262,8 @@ struct
262262 clock : Clock .t ;
263263 mutable pending : Tcp.Id.Set .t ;
264264 mutable last_active_time : float ;
265+ (* Tasks that will be signalled if the endpoint is destroyed *)
266+ mutable on_destroy : unit Lwt .u Tcp.Id.Map .t ;
265267 }
266268 (* * A generic TCP/IP endpoint *)
267269
@@ -284,12 +286,17 @@ struct
284286
285287 let pending = Tcp.Id.Set. empty in
286288 let last_active_time = Unix. gettimeofday () in
289+ let on_destroy = Tcp.Id.Map. empty in
287290 let tcp_stack =
288291 { recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
289- last_active_time; clock }
292+ last_active_time; clock; on_destroy }
290293 in
291294 Lwt. return tcp_stack
292295
296+ let destroy t =
297+ Tcp.Id.Map. iter (fun _ u -> Lwt. wakeup_later u () ) t.on_destroy;
298+ t.on_destroy < - Tcp.Id.Map. empty
299+
293300 let intercept_tcp_syn t ~id ~syn on_syn_callback (buf : Cstruct.t ) =
294301 if syn then begin
295302 if Tcp.Id.Set. mem id t.pending then begin
@@ -300,9 +307,14 @@ struct
300307 Lwt. return_unit
301308 end else begin
302309 t.pending < - Tcp.Id.Set. add id t.pending;
310+ (* Add a task to the "on_destroy" list which will be signalled if
311+ the Endpoint is disconnected from the switch and we should close
312+ connections. *)
313+ let close, close_request = Lwt. task () in
314+ t.on_destroy < - Tcp.Id.Map. add id close_request t.on_destroy;
303315 Lwt. finalize
304316 (fun () ->
305- on_syn_callback ()
317+ on_syn_callback close
306318 >> = fun listeners ->
307319 let src = Stack_tcp_wire. dst id in
308320 let dst = Stack_tcp_wire. src id in
@@ -324,7 +336,7 @@ struct
324336 Mirage_flow_lwt. Proxy (Clock )(Stack_tcp )(Host.Sockets.Stream. Tcp )
325337
326338 let input_tcp t ~id ~syn (ip , port ) (buf : Cstruct.t ) =
327- intercept_tcp_syn t ~id ~syn (fun () ->
339+ intercept_tcp_syn t ~id ~syn (fun close ->
328340 Host.Sockets.Stream.Tcp. connect (ip, port)
329341 >> = function
330342 | Error (`Msg m ) ->
@@ -346,9 +358,21 @@ struct
346358 Lwt. return_unit
347359 | Some socket ->
348360 Lwt. finalize (fun () ->
349- Proxy. proxy t.clock flow socket
361+ Lwt. pick [
362+ Lwt. map
363+ (function Error e -> Error (`Proxy e) | Ok x -> Ok x)
364+ (Proxy. proxy t.clock flow socket);
365+ Lwt. map
366+ (fun () -> Error `Close )
367+ close
368+ ]
350369 >> = function
351- | Error e ->
370+ | Error (`Close) ->
371+ Log. info (fun f ->
372+ f " %s proxy closed due to switch port disconnection"
373+ (Tcp.Flow. to_string tcp));
374+ Lwt. return_unit
375+ | Error (`Proxy e ) ->
352376 Log. debug (fun f ->
353377 f " %s proxy failed with %a"
354378 (Tcp.Flow. to_string tcp) Proxy. pp_error e);
@@ -359,6 +383,7 @@ struct
359383 Log. debug (fun f ->
360384 f " closing flow %s" (string_of_id tcp.Tcp.Flow. id));
361385 tcp.Tcp.Flow. socket < - None ;
386+ t.on_destroy < - Tcp.Id.Map. remove id t.on_destroy;
362387 Tcp.Flow. remove tcp.Tcp.Flow. id;
363388 Host.Sockets.Stream.Tcp. close socket
364389 )
@@ -484,9 +509,9 @@ struct
484509 let id =
485510 Stack_tcp_wire. v ~src_port: 53 ~dst: src ~src: dst ~dst_port: src_port
486511 in
487- Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
512+ Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
488513 ! dns >> = fun t ->
489- Dns_forwarder. handle_tcp ~t
514+ Dns_forwarder. handle_tcp ~t ~close
490515 ) raw
491516 > |= ok
492517
@@ -808,10 +833,11 @@ struct
808833 let now = Unix. gettimeofday () in
809834 let old_ips = IPMap. fold (fun ip endpoint acc ->
810835 let age = now -. endpoint.Endpoint. last_active_time in
811- if age > (float_of_int port_max_idle_time) then ip :: acc else acc
836+ if age > (float_of_int port_max_idle_time) then (ip, endpoint) :: acc else acc
812837 ) t.endpoints [] in
813- List. iter (fun ip ->
838+ List. iter (fun ( ip , endpoint ) ->
814839 Switch. remove t.switch ip;
840+ Endpoint. destroy endpoint;
815841 t.endpoints < - IPMap. remove ip t.endpoints
816842 ) old_ips;
817843 Lwt. return_unit
0 commit comments