Skip to content

Commit 36b23db

Browse files
committed
Add a Endpoint.destroy which closes active connections
Previously there was no way to locate the connections associated with an endpoint to shut them down. This patch adds a map of TCP `id` to `unit Lwt.u` and a function `Endpoint.destroy` which triggers the disconnection of all the active connections. Related to #260 Signed-off-by: David Scott <[email protected]>
1 parent 47a530f commit 36b23db

File tree

3 files changed

+41
-12
lines changed

3 files changed

+41
-12
lines changed

src/hostnet/hostnet_dns.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ struct
356356
| Ok buffer ->
357357
Udp.write ~src_port:53 ~dst:src ~dst_port:src_port udp buffer
358358

359-
let handle_tcp ~t =
359+
let handle_tcp ~t ~close =
360360
(* FIXME: need to record the upstream request *)
361361
let listeners _ =
362362
Log.debug (fun f -> f "DNS TCP handshake complete");
@@ -384,7 +384,10 @@ struct
384384
Lwt.async queries;
385385
loop ()
386386
in
387-
loop ()
387+
Lwt.pick [
388+
loop ();
389+
close
390+
]
388391
in
389392
Some f
390393
in

src/hostnet/hostnet_dns.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ sig
4040
t:t -> udp:Udp.t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> src_port:int ->
4141
Cstruct.t -> (unit, Udp.error) result Lwt.t
4242

43-
val handle_tcp: t:t -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
43+
val handle_tcp: t:t -> close:(unit Lwt.t) -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
4444

4545
val destroy: t -> unit Lwt.t
4646
end

src/hostnet/slirp.ml

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ struct
262262
clock: Clock.t;
263263
mutable pending: Tcp.Id.Set.t;
264264
mutable last_active_time: float;
265+
(* Tasks that will be signalled if the endpoint is destroyed *)
266+
mutable on_destroy: unit Lwt.u Tcp.Id.Map.t;
265267
}
266268
(** A generic TCP/IP endpoint *)
267269

@@ -284,12 +286,17 @@ struct
284286

285287
let pending = Tcp.Id.Set.empty in
286288
let last_active_time = Unix.gettimeofday () in
289+
let on_destroy = Tcp.Id.Map.empty in
287290
let tcp_stack =
288291
{ recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
289-
last_active_time; clock }
292+
last_active_time; clock; on_destroy }
290293
in
291294
Lwt.return tcp_stack
292295

296+
let destroy t =
297+
Tcp.Id.Map.iter (fun _ u -> Lwt.wakeup_later u ()) t.on_destroy;
298+
t.on_destroy <- Tcp.Id.Map.empty
299+
293300
let intercept_tcp_syn t ~id ~syn on_syn_callback (buf: Cstruct.t) =
294301
if syn then begin
295302
if Tcp.Id.Set.mem id t.pending then begin
@@ -300,9 +307,14 @@ struct
300307
Lwt.return_unit
301308
end else begin
302309
t.pending <- Tcp.Id.Set.add id t.pending;
310+
(* Add a task to the "on_destroy" list which will be signalled if
311+
the Endpoint is disconnected from the switch and we should close
312+
connections. *)
313+
let close, close_request = Lwt.task () in
314+
t.on_destroy <- Tcp.Id.Map.add id close_request t.on_destroy;
303315
Lwt.finalize
304316
(fun () ->
305-
on_syn_callback ()
317+
on_syn_callback close
306318
>>= fun listeners ->
307319
let src = Stack_tcp_wire.dst id in
308320
let dst = Stack_tcp_wire.src id in
@@ -324,7 +336,7 @@ struct
324336
Mirage_flow_lwt.Proxy(Clock)(Stack_tcp)(Host.Sockets.Stream.Tcp)
325337

326338
let input_tcp t ~id ~syn (ip, port) (buf: Cstruct.t) =
327-
intercept_tcp_syn t ~id ~syn (fun () ->
339+
intercept_tcp_syn t ~id ~syn (fun close ->
328340
Host.Sockets.Stream.Tcp.connect (ip, port)
329341
>>= function
330342
| Error (`Msg m) ->
@@ -346,9 +358,21 @@ struct
346358
Lwt.return_unit
347359
| Some socket ->
348360
Lwt.finalize (fun () ->
349-
Proxy.proxy t.clock flow socket
361+
Lwt.pick [
362+
Lwt.map
363+
(function Error e -> Error (`Proxy e) | Ok x -> Ok x)
364+
(Proxy.proxy t.clock flow socket);
365+
Lwt.map
366+
(fun () -> Error `Close)
367+
close
368+
]
350369
>>= function
351-
| Error e ->
370+
| Error (`Close) ->
371+
Log.info (fun f ->
372+
f "%s proxy closed due to switch port disconnection"
373+
(Tcp.Flow.to_string tcp));
374+
Lwt.return_unit
375+
| Error (`Proxy e) ->
352376
Log.debug (fun f ->
353377
f "%s proxy failed with %a"
354378
(Tcp.Flow.to_string tcp) Proxy.pp_error e);
@@ -359,6 +383,7 @@ struct
359383
Log.debug (fun f ->
360384
f "closing flow %s" (string_of_id tcp.Tcp.Flow.id));
361385
tcp.Tcp.Flow.socket <- None;
386+
t.on_destroy <- Tcp.Id.Map.remove id t.on_destroy;
362387
Tcp.Flow.remove tcp.Tcp.Flow.id;
363388
Host.Sockets.Stream.Tcp.close socket
364389
)
@@ -484,9 +509,9 @@ struct
484509
let id =
485510
Stack_tcp_wire.v ~src_port:53 ~dst:src ~src:dst ~dst_port:src_port
486511
in
487-
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
512+
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
488513
!dns >>= fun t ->
489-
Dns_forwarder.handle_tcp ~t
514+
Dns_forwarder.handle_tcp ~t ~close
490515
) raw
491516
>|= ok
492517

@@ -808,10 +833,11 @@ struct
808833
let now = Unix.gettimeofday () in
809834
let old_ips = IPMap.fold (fun ip endpoint acc ->
810835
let age = now -. endpoint.Endpoint.last_active_time in
811-
if age > (float_of_int port_max_idle_time) then ip :: acc else acc
836+
if age > (float_of_int port_max_idle_time) then (ip, endpoint) :: acc else acc
812837
) t.endpoints [] in
813-
List.iter (fun ip ->
838+
List.iter (fun (ip, endpoint) ->
814839
Switch.remove t.switch ip;
840+
Endpoint.destroy endpoint;
815841
t.endpoints <- IPMap.remove ip t.endpoints
816842
) old_ips;
817843
Lwt.return_unit

0 commit comments

Comments
 (0)