From 4b7ace59a09c894036add7ed35031e494a1d5268 Mon Sep 17 00:00:00 2001 From: JiHwan Oh Date: Fri, 28 Nov 2025 21:14:12 +0900 Subject: [PATCH] refactor(tun): use context for packet handling and cleanup Signed-off-by: JiHwan Oh --- tun/netstack/tun.go | 62 ++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index a7aec9e82..9611dc4e8 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -40,14 +40,14 @@ import ( ) type netTun struct { - ep *channel.Endpoint - stack *stack.Stack - events chan tun.Event - notifyHandle *channel.NotificationHandle - incomingPacket chan *buffer.View - mtu int - dnsServers []netip.Addr - hasV4, hasV6 bool + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + ctx context.Context + cancel context.CancelFunc + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool } type Net netTun @@ -58,20 +58,23 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, HandleLocal: true, } + + ctx, cancel := context.WithCancel(context.Background()) + dev := &netTun{ - ep: channel.New(1024, uint32(mtu), ""), - stack: stack.New(opts), - events: make(chan tun.Event, 10), - incomingPacket: make(chan *buffer.View), - dnsServers: dnsServers, - mtu: mtu, + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 10), + ctx: ctx, + cancel: cancel, + dnsServers: dnsServers, + mtu: mtu, } sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - dev.notifyHandle = dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) @@ -121,11 +124,12 @@ func (tun *netTun) Events() <-chan tun.Event { } func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { - view, ok := <-tun.incomingPacket - if !ok { + pkb := tun.ep.ReadContext(tun.ctx) + if pkb == nil { return 0, os.ErrClosed } - + view := pkb.ToView() + pkb.DecRef() n, err := view.Read(buf[0][offset:]) if err != nil { return 0, err @@ -135,6 +139,10 @@ func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { } func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + if tun.ctx.Err() != nil { + return 0, os.ErrClosed + } + for _, buf := range buf { packet := buf[offset:] if len(packet) == 0 { @@ -154,32 +162,16 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { return len(buf), nil } -func (tun *netTun) WriteNotify() { - pkt := tun.ep.Read() - if pkt == nil { - return - } - - view := pkt.ToView() - pkt.DecRef() - - tun.incomingPacket <- view -} - func (tun *netTun) Close() error { + tun.cancel() tun.stack.RemoveNIC(1) tun.stack.Close() - tun.ep.RemoveNotify(tun.notifyHandle) tun.ep.Close() if tun.events != nil { close(tun.events) } - if tun.incomingPacket != nil { - close(tun.incomingPacket) - } - return nil }