Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 27 additions & 35 deletions tun/netstack/tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ import (
)

type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
notifyHandle *channel.NotificationHandle
incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
ctx context.Context
cancel context.CancelFunc
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
}

type Net netTun
Expand All @@ -58,20 +58,23 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true,
}

ctx, cancel := context.WithCancel(context.Background())

dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
ctx: ctx,
cancel: cancel,
dnsServers: dnsServers,
mtu: mtu,
}
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
}
dev.notifyHandle = dev.ep.AddNotify(dev)
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
Expand Down Expand Up @@ -121,11 +124,12 @@ func (tun *netTun) Events() <-chan tun.Event {
}

func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
pkb := tun.ep.ReadContext(tun.ctx)
if pkb == nil {
return 0, os.ErrClosed
}

view := pkb.ToView()
pkb.DecRef()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to add view.Release

n, err := view.Read(buf[0][offset:])
if err != nil {
return 0, err
Expand All @@ -135,6 +139,10 @@ func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
}

func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
if tun.ctx.Err() != nil {
return 0, os.ErrClosed
}

for _, buf := range buf {
packet := buf[offset:]
if len(packet) == 0 {
Expand All @@ -154,32 +162,16 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
return len(buf), nil
}

func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
if pkt == nil {
return
}

view := pkt.ToView()
pkt.DecRef()

tun.incomingPacket <- view
}

func (tun *netTun) Close() error {
tun.cancel()
tun.stack.RemoveNIC(1)
tun.stack.Close()
tun.ep.RemoveNotify(tun.notifyHandle)
tun.ep.Close()

if tun.events != nil {
close(tun.events)
}

if tun.incomingPacket != nil {
close(tun.incomingPacket)
}

return nil
}

Expand Down