Skip to content

Commit 9ca8b41

Browse files
committed
TUN-9472: Add virtual DNS service
Adds a new reserved service to route UDP requests towards the local DNS resolver. Closes TUN-9472
1 parent b4a98b1 commit 9ca8b41

File tree

6 files changed

+310
-5
lines changed

6 files changed

+310
-5
lines changed

cmd/cloudflared/tunnel/configuration.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
2626
"github.com/cloudflare/cloudflared/features"
2727
"github.com/cloudflare/cloudflared/ingress"
28+
"github.com/cloudflare/cloudflared/ingress/origins"
2829
"github.com/cloudflare/cloudflared/orchestration"
2930
"github.com/cloudflare/cloudflared/supervisor"
3031
"github.com/cloudflare/cloudflared/tlsconfig"
@@ -219,6 +220,8 @@ func prepareTunnelConfig(
219220
resolvedRegion = endpoint
220221
}
221222

223+
dnsService := origins.NewDNSResolver(log)
224+
222225
tunnelConfig := &supervisor.TunnelConfig{
223226
ClientConfig: clientConfig,
224227
GracePeriod: gracePeriod,
@@ -246,6 +249,7 @@ func prepareTunnelConfig(
246249
DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery),
247250
QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit),
248251
QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit),
252+
OriginDNSService: dnsService,
249253
}
250254
icmpRouter, err := newICMPRouter(c, log)
251255
if err != nil {

ingress/origin_udp_proxy.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type UDPOriginService struct {
2222

2323
// UDPOriginProxy provides a UDP dial operation to a requested addr.
2424
type UDPOriginProxy interface {
25-
DialUDP(addr netip.AddrPort) (*net.UDPConn, error)
25+
DialUDP(addr netip.AddrPort) (net.Conn, error)
2626
}
2727

2828
func NewUDPOriginService(reserved map[netip.AddrPort]UDPOriginProxy, logger *zerolog.Logger) *UDPOriginService {
@@ -40,7 +40,7 @@ func (s *UDPOriginService) SetDefaultDialer(dialer UDPOriginProxy) {
4040
}
4141

4242
// DialUDP will perform a dial UDP to the requested addr.
43-
func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (*net.UDPConn, error) {
43+
func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
4444
// Check to see if any reserved services are available for this addr and call their dialer instead.
4545
if dialer, ok := s.reservedServices[addr]; ok {
4646
return dialer.DialUDP(addr)
@@ -52,7 +52,7 @@ type defaultUDPDialer struct{}
5252

5353
var DefaultUDPDialer UDPOriginProxy = &defaultUDPDialer{}
5454

55-
func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (*net.UDPConn, error) {
55+
func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
5656
addr := net.UDPAddrFromAddrPort(dest)
5757

5858
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination

ingress/origins/dns.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package origins
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/netip"
7+
"sync"
8+
"time"
9+
10+
"github.com/rs/zerolog"
11+
12+
"github.com/cloudflare/cloudflared/ingress"
13+
)
14+
15+
const (
16+
// We need a DNS record:
17+
// 1. That will be around for as long as cloudflared is
18+
// 2. That Cloudflare controls: to allow us to make changes if needed
19+
// 3. That is an external record to a typical customer's network: enforcing that the DNS request go to the
20+
// local DNS resolver over any local /etc/host configurations setup.
21+
// 4. That cloudflared would normally query: ensuring that users with a positive security model for DNS queries
22+
// don't need to adjust anything.
23+
//
24+
// This hostname is one that used during the edge discovery process and as such satisfies the above constraints.
25+
defaultLookupHost = "region1.v2.argotunnel.com"
26+
defaultResolverPort uint16 = 53
27+
28+
// We want the refresh time to be short to accommodate DNS resolver changes locally, but not too frequent as to
29+
// shuffle the resolver if multiple are configured.
30+
refreshFreq = 5 * time.Minute
31+
refreshTimeout = 5 * time.Second
32+
)
33+
34+
var (
35+
// Virtual DNS service address
36+
VirtualDNSServiceAddr = netip.AddrPortFrom(netip.MustParseAddr("2606:4700:0cf1:2000:0000:0000:0000:0001"), 53)
37+
38+
defaultResolverAddr = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), defaultResolverPort)
39+
)
40+
41+
type netDial func(network string, address string) (net.Conn, error)
42+
43+
// DNSResolverService will make DNS requests to the local DNS resolver via the Dial method.
44+
type DNSResolverService struct {
45+
address netip.AddrPort
46+
addressM sync.RWMutex
47+
48+
dialer ingress.UDPOriginProxy
49+
resolver peekResolver
50+
logger *zerolog.Logger
51+
}
52+
53+
func NewDNSResolver(logger *zerolog.Logger) *DNSResolverService {
54+
return &DNSResolverService{
55+
address: defaultResolverAddr,
56+
dialer: ingress.DefaultUDPDialer,
57+
resolver: &resolver{dialFunc: net.Dial},
58+
logger: logger,
59+
}
60+
}
61+
62+
func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) {
63+
s.addressM.RLock()
64+
dest := s.address
65+
s.addressM.RUnlock()
66+
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
67+
return s.dialer.DialUDP(dest)
68+
}
69+
70+
// StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if
71+
// adjusted while the cloudflared process is running.
72+
func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) {
73+
// Call update first to load an address before handling traffic
74+
err := s.update(ctx)
75+
if err != nil {
76+
s.logger.Err(err).Msg("Failed to initialize DNS local resolver")
77+
}
78+
for {
79+
select {
80+
case <-ctx.Done():
81+
return
82+
case <-time.Tick(refreshFreq):
83+
err := s.update(ctx)
84+
if err != nil {
85+
s.logger.Err(err).Msg("Failed to refresh DNS local resolver")
86+
}
87+
}
88+
}
89+
}
90+
91+
func (s *DNSResolverService) update(ctx context.Context) error {
92+
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
93+
defer cancel()
94+
// Make a standard DNS request to a well-known DNS record that will last a long time
95+
_, err := s.resolver.lookupNetIP(ctx, defaultLookupHost)
96+
if err != nil {
97+
return err
98+
}
99+
100+
// Validate the address before updating internal reference
101+
_, address := s.resolver.addr()
102+
peekAddrPort, err := netip.ParseAddrPort(address)
103+
if err == nil {
104+
s.setAddress(peekAddrPort)
105+
return nil
106+
}
107+
// It's possible that the address didn't have an attached port, attempt to parse just the address and use
108+
// the default port 53
109+
peekAddr, err := netip.ParseAddr(address)
110+
if err != nil {
111+
return err
112+
}
113+
s.setAddress(netip.AddrPortFrom(peekAddr, defaultResolverPort))
114+
return nil
115+
}
116+
117+
// lock and update the address used for the local DNS resolver
118+
func (s *DNSResolverService) setAddress(addr netip.AddrPort) {
119+
s.addressM.Lock()
120+
defer s.addressM.Unlock()
121+
if s.address != addr {
122+
s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr)
123+
}
124+
s.address = addr
125+
}
126+
127+
type peekResolver interface {
128+
addr() (network string, address string)
129+
lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error)
130+
}
131+
132+
// resolver is a shim that inspects the go runtime's DNS resolution process to capture the DNS resolver
133+
// address used to complete a DNS request.
134+
type resolver struct {
135+
network string
136+
address string
137+
dialFunc netDial
138+
}
139+
140+
func (r *resolver) addr() (network string, address string) {
141+
return r.network, r.address
142+
}
143+
144+
func (r *resolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) {
145+
resolver := &net.Resolver{
146+
PreferGo: true,
147+
// Use the peekDial to inspect the results of the DNS resolver used during the LookupIPAddr call.
148+
Dial: r.peekDial,
149+
}
150+
return resolver.LookupNetIP(ctx, "ip", host)
151+
}
152+
153+
func (r *resolver) peekDial(ctx context.Context, network, address string) (net.Conn, error) {
154+
r.network = network
155+
r.address = address
156+
return r.dialFunc(network, address)
157+
}

ingress/origins/dns_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package origins
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"net/netip"
8+
"testing"
9+
10+
"github.com/rs/zerolog"
11+
)
12+
13+
func TestDNSResolver_DefaultResolver(t *testing.T) {
14+
log := zerolog.Nop()
15+
service := NewDNSResolver(&log)
16+
mockResolver := &mockPeekResolver{
17+
address: "127.0.0.2:53",
18+
}
19+
service.resolver = mockResolver
20+
if service.address != defaultResolverAddr {
21+
t.Errorf("resolver address should be the default: %s, was: %s", defaultResolverAddr, service.address)
22+
}
23+
}
24+
25+
func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
26+
log := zerolog.Nop()
27+
service := NewDNSResolver(&log)
28+
29+
mockResolver := &mockPeekResolver{}
30+
service.resolver = mockResolver
31+
32+
expectedAddr := netip.MustParseAddrPort("127.0.0.2:53")
33+
addresses := []string{
34+
"127.0.0.2:53",
35+
"127.0.0.2", // missing port should be added (even though this is unlikely to happen)
36+
}
37+
38+
for _, addr := range addresses {
39+
mockResolver.address = addr
40+
// Update the resolver address
41+
err := service.update(t.Context())
42+
if err != nil {
43+
t.Error(err)
44+
}
45+
// Validate expected
46+
if service.address != expectedAddr {
47+
t.Errorf("resolver address should be: %s, was: %s", expectedAddr, service.address)
48+
}
49+
}
50+
}
51+
52+
func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
53+
log := zerolog.Nop()
54+
service := NewDNSResolver(&log)
55+
mockResolver := &mockPeekResolver{}
56+
service.resolver = mockResolver
57+
58+
invalidAddresses := []string{
59+
"999.999.999.999",
60+
"localhost",
61+
"255.255.255",
62+
}
63+
64+
for _, addr := range invalidAddresses {
65+
mockResolver.address = addr
66+
// Update the resolver address should not update for these invalid addresses
67+
err := service.update(t.Context())
68+
if err == nil {
69+
t.Error("service update should throw an error")
70+
}
71+
// Validate expected
72+
if service.address != defaultResolverAddr {
73+
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
74+
}
75+
}
76+
}
77+
78+
func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
79+
log := zerolog.Nop()
80+
service := NewDNSResolver(&log)
81+
resolverErr := errors.New("test resolver error")
82+
mockResolver := &mockPeekResolver{err: resolverErr}
83+
service.resolver = mockResolver
84+
85+
// Update the resolver address should not update when the resolver cannot complete the lookup
86+
err := service.update(t.Context())
87+
if err != resolverErr {
88+
t.Error("service update should throw an error")
89+
}
90+
// Validate expected
91+
if service.address != defaultResolverAddr {
92+
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
93+
}
94+
}
95+
96+
func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) {
97+
log := zerolog.Nop()
98+
service := NewDNSResolver(&log)
99+
mockResolver := &mockPeekResolver{}
100+
service.resolver = mockResolver
101+
mockDialer := &mockDialer{expected: defaultResolverAddr}
102+
service.dialer = mockDialer
103+
104+
// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
105+
_, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53"))
106+
if err != nil {
107+
t.Error(err)
108+
}
109+
}
110+
111+
type mockPeekResolver struct {
112+
err error
113+
address string
114+
}
115+
116+
func (r *mockPeekResolver) addr() (network, address string) {
117+
return "udp", r.address
118+
}
119+
120+
func (r *mockPeekResolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) {
121+
// We can return an empty result as it doesn't matter as long as the lookup doesn't fail
122+
return []netip.Addr{}, r.err
123+
}
124+
125+
type mockDialer struct {
126+
expected netip.AddrPort
127+
}
128+
129+
func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
130+
if d.expected != addr {
131+
return nil, errors.New("unexpected address dialed")
132+
}
133+
return nil, nil
134+
}

supervisor/supervisor.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"net"
7+
"net/netip"
78
"strings"
89
"time"
910

@@ -14,6 +15,7 @@ import (
1415
"github.com/cloudflare/cloudflared/connection"
1516
"github.com/cloudflare/cloudflared/edgediscovery"
1617
"github.com/cloudflare/cloudflared/ingress"
18+
"github.com/cloudflare/cloudflared/ingress/origins"
1719
"github.com/cloudflare/cloudflared/orchestration"
1820
v3 "github.com/cloudflare/cloudflared/quic/v3"
1921
"github.com/cloudflare/cloudflared/retry"
@@ -78,8 +80,11 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
7880
edgeBindAddr := config.EdgeBindAddr
7981

8082
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
81-
// No reserved ingress services for now, hence the nil
82-
ingressUDPService := ingress.NewUDPOriginService(nil, config.Log)
83+
84+
// Setup the reserved virtual origins
85+
reservedServices := map[netip.AddrPort]ingress.UDPOriginProxy{}
86+
reservedServices[origins.VirtualDNSServiceAddr] = config.OriginDNSService
87+
ingressUDPService := ingress.NewUDPOriginService(reservedServices, config.Log)
8388
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingressUDPService, orchestrator.GetFlowLimiter())
8489

8590
edgeTunnelServer := EdgeTunnelServer{
@@ -128,6 +133,9 @@ func (s *Supervisor) Run(
128133
}()
129134
}
130135

136+
// Setup DNS Resolver refresh
137+
go s.config.OriginDNSService.StartRefreshLoop(ctx)
138+
131139
if err := s.initialize(ctx, connectedSignal); err != nil {
132140
if err == errEarlyShutdown {
133141
return nil

0 commit comments

Comments
 (0)