Skip to content

Commit 398da88

Browse files
committed
TUN-9473: Add --dns-resolver-addrs flag
To help support users with environments that don't work well with the DNS local resolver's automatic resolution process for local resolver addresses, we introduce a flag to provide them statically to the runtime. When providing the resolver addresses, cloudflared will no longer lookup the DNS resolver addresses and use the user input directly. When provided with a list of DNS resolvers larger than one, the resolver service will randomly select one at random for each incoming request. Closes TUN-9473
1 parent 70ed7ff commit 398da88

File tree

5 files changed

+148
-37
lines changed

5 files changed

+148
-37
lines changed

cmd/cloudflared/flags/flags.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,7 @@ const (
157157

158158
// ApiURL is the command line flag used to define the base URL of the API
159159
ApiURL = "api-url"
160+
161+
// Virtual DNS resolver service resolver addresses to use instead of dynamically fetching them from the OS.
162+
VirtualDNSServiceResolverAddresses = "dns-resolver-addrs"
160163
)

cmd/cloudflared/tunnel/configuration.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,17 @@ func prepareTunnelConfig(
227227
DefaultDialer: ingress.NewDialer(warpRoutingConfig),
228228
TCPWriteTimeout: c.Duration(flags.WriteStreamTimeout),
229229
}, log)
230+
231+
// Setup DNS Resolver Service
232+
dnsResolverAddrs := c.StringSlice(flags.VirtualDNSServiceResolverAddresses)
230233
dnsService := origins.NewDNSResolverService(origins.NewDNSDialer(), log)
234+
if len(dnsResolverAddrs) > 0 {
235+
addrs, err := parseResolverAddrPorts(dnsResolverAddrs)
236+
if err != nil {
237+
return nil, nil, fmt.Errorf("invalid %s provided: %w", flags.VirtualDNSServiceResolverAddresses, err)
238+
}
239+
dnsService = origins.NewStaticDNSResolverService(addrs, origins.NewDNSDialer(), log)
240+
}
231241
originDialerService.AddReservedService(dnsService, []netip.AddrPort{origins.VirtualDNSServiceAddr})
232242

233243
tunnelConfig := &supervisor.TunnelConfig{
@@ -507,3 +517,19 @@ func findLocalAddr(dst net.IP, port int) (netip.Addr, error) {
507517
localAddr := localAddrPort.Addr()
508518
return localAddr, nil
509519
}
520+
521+
func parseResolverAddrPorts(input []string) ([]netip.AddrPort, error) {
522+
// We don't allow more than 10 resolvers to be provided statically for the resolver service.
523+
if len(input) > 10 {
524+
return nil, errors.New("too many addresses provided, max: 10")
525+
}
526+
addrs := make([]netip.AddrPort, 0, len(input))
527+
for _, val := range input {
528+
addr, err := netip.ParseAddrPort(val)
529+
if err != nil {
530+
return nil, err
531+
}
532+
addrs = append(addrs, addr)
533+
}
534+
return addrs, nil
535+
}

cmd/cloudflared/tunnel/subcommands.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ var (
241241
Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports",
242242
EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"},
243243
}
244+
dnsResolverAddrsFlag = &cli.StringSliceFlag{
245+
Name: flags.VirtualDNSServiceResolverAddresses,
246+
Usage: "Overrides the dynamic DNS resolver resolution to use these address:port's instead.",
247+
EnvVars: []string{"TUNNEL_DNS_RESOLVER_ADDRS"},
248+
}
244249
)
245250

246251
func buildCreateCommand() *cli.Command {
@@ -718,6 +723,7 @@ func buildRunCommand() *cli.Command {
718723
icmpv4SrcFlag,
719724
icmpv6SrcFlag,
720725
maxActiveFlowsFlag,
726+
dnsResolverAddrsFlag,
721727
}
722728
flags = append(flags, configureProxyFlags(false)...)
723729
return &cli.Command{

ingress/origins/dns.go

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package origins
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"math/big"
57
"net"
68
"net/netip"
9+
"slices"
710
"sync"
811
"time"
912

@@ -42,42 +45,50 @@ type netDial func(network string, address string) (net.Conn, error)
4245

4346
// DNSResolverService will make DNS requests to the local DNS resolver via the Dial method.
4447
type DNSResolverService struct {
45-
address netip.AddrPort
46-
addressM sync.RWMutex
47-
48-
dialer ingress.OriginDialer
49-
resolver peekResolver
50-
logger *zerolog.Logger
48+
addresses []netip.AddrPort
49+
addressesM sync.RWMutex
50+
static bool
51+
dialer ingress.OriginDialer
52+
resolver peekResolver
53+
logger *zerolog.Logger
5154
}
5255

5356
func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService {
5457
return &DNSResolverService{
55-
address: defaultResolverAddr,
56-
dialer: dialer,
57-
resolver: &resolver{dialFunc: net.Dial},
58-
logger: logger,
58+
addresses: []netip.AddrPort{defaultResolverAddr},
59+
dialer: dialer,
60+
resolver: &resolver{dialFunc: net.Dial},
61+
logger: logger,
5962
}
6063
}
6164

65+
func NewStaticDNSResolverService(resolverAddrs []netip.AddrPort, dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService {
66+
s := NewDNSResolverService(dialer, logger)
67+
s.addresses = resolverAddrs
68+
s.static = true
69+
return s
70+
}
71+
6272
func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) {
63-
s.addressM.RLock()
64-
dest := s.address
65-
s.addressM.RUnlock()
73+
dest := s.getAddress()
6674
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
6775
return s.dialer.DialTCP(ctx, dest)
6876
}
6977

7078
func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) {
71-
s.addressM.RLock()
72-
dest := s.address
73-
s.addressM.RUnlock()
79+
dest := s.getAddress()
7480
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
7581
return s.dialer.DialUDP(dest)
7682
}
7783

7884
// StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if
7985
// adjusted while the cloudflared process is running.
86+
// Does not run when the resolver was provided with external resolver addresses via CLI.
8087
func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) {
88+
if s.static {
89+
s.logger.Debug().Msgf("Canceled DNS local resolver refresh loop because static resolver addresses were provided: %s", s.addresses)
90+
return
91+
}
8192
// Call update first to load an address before handling traffic
8293
err := s.update(ctx)
8394
if err != nil {
@@ -122,14 +133,38 @@ func (s *DNSResolverService) update(ctx context.Context) error {
122133
return nil
123134
}
124135

136+
// returns the address from the peekResolver or from the static addresses if provided.
137+
// If multiple addresses are provided in the static addresses pick one randomly.
138+
func (s *DNSResolverService) getAddress() netip.AddrPort {
139+
s.addressesM.RLock()
140+
defer s.addressesM.RUnlock()
141+
l := len(s.addresses)
142+
if l <= 0 {
143+
return defaultResolverAddr
144+
}
145+
if l == 1 {
146+
return s.addresses[0]
147+
}
148+
// Only initialize the random selection if there is more than one element in the list.
149+
var i int64 = 0
150+
r, err := rand.Int(rand.Reader, big.NewInt(int64(l)))
151+
// We ignore errors from crypto rand and use index 0; this should be extremely unlikely and the
152+
// list index doesn't need to be cryptographically secure, but linters insist.
153+
if err == nil {
154+
i = r.Int64()
155+
}
156+
return s.addresses[i]
157+
}
158+
125159
// lock and update the address used for the local DNS resolver
126160
func (s *DNSResolverService) setAddress(addr netip.AddrPort) {
127-
s.addressM.Lock()
128-
defer s.addressM.Unlock()
129-
if s.address != addr {
161+
s.addressesM.Lock()
162+
defer s.addressesM.Unlock()
163+
if !slices.Contains(s.addresses, addr) {
130164
s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr)
131165
}
132-
s.address = addr
166+
// We only store one address when reading the peekResolver, so we just replace the whole list.
167+
s.addresses = []netip.AddrPort{addr}
133168
}
134169

135170
type peekResolver interface {

ingress/origins/dns_test.go

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"errors"
66
"net"
77
"net/netip"
8+
"slices"
89
"testing"
10+
"time"
911

1012
"github.com/rs/zerolog"
1113
)
@@ -17,9 +19,18 @@ func TestDNSResolver_DefaultResolver(t *testing.T) {
1719
address: "127.0.0.2:53",
1820
}
1921
service.resolver = mockResolver
20-
if service.address != defaultResolverAddr {
21-
t.Errorf("resolver address should be the default: %s, was: %s", defaultResolverAddr, service.address)
22+
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
23+
}
24+
25+
func TestStaticDNSResolver_DefaultResolver(t *testing.T) {
26+
log := zerolog.Nop()
27+
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
28+
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
29+
mockResolver := &mockPeekResolver{
30+
address: "127.0.0.2:53",
2231
}
32+
service.resolver = mockResolver
33+
validateAddrs(t, addresses, service.addresses)
2334
}
2435

2536
func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
@@ -29,24 +40,47 @@ func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
2940
mockResolver := &mockPeekResolver{}
3041
service.resolver = mockResolver
3142

32-
expectedAddr := netip.MustParseAddrPort("127.0.0.2:53")
33-
addresses := []string{
34-
"127.0.0.2:53",
35-
"127.0.0.2", // missing port should be added (even though this is unlikely to happen)
43+
tests := []struct {
44+
addr string
45+
expected netip.AddrPort
46+
}{
47+
{"127.0.0.2:53", netip.MustParseAddrPort("127.0.0.2:53")},
48+
// missing port should be added (even though this is unlikely to happen)
49+
{"127.0.0.3", netip.MustParseAddrPort("127.0.0.3:53")},
3650
}
3751

38-
for _, addr := range addresses {
39-
mockResolver.address = addr
52+
for _, test := range tests {
53+
mockResolver.address = test.addr
4054
// Update the resolver address
4155
err := service.update(t.Context())
4256
if err != nil {
4357
t.Error(err)
4458
}
4559
// Validate expected
46-
if service.address != expectedAddr {
47-
t.Errorf("resolver address should be: %s, was: %s", expectedAddr, service.address)
48-
}
60+
validateAddrs(t, []netip.AddrPort{test.expected}, service.addresses)
61+
}
62+
}
63+
64+
func TestStaticDNSResolver_RefreshLoopExits(t *testing.T) {
65+
log := zerolog.Nop()
66+
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
67+
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
68+
69+
mockResolver := &mockPeekResolver{
70+
address: "127.0.0.2:53",
4971
}
72+
service.resolver = mockResolver
73+
74+
ctx, cancel := context.WithCancel(t.Context())
75+
defer cancel()
76+
77+
go service.StartRefreshLoop(ctx)
78+
79+
// Wait for the refresh loop to end _and_ not update the addresses
80+
time.Sleep(10 * time.Millisecond)
81+
82+
// Validate expected
83+
validateAddrs(t, addresses, service.addresses)
5084
}
5185

5286
func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
@@ -69,9 +103,7 @@ func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
69103
t.Error("service update should throw an error")
70104
}
71105
// Validate expected
72-
if service.address != defaultResolverAddr {
73-
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
74-
}
106+
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
75107
}
76108
}
77109

@@ -88,9 +120,7 @@ func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
88120
t.Error("service update should throw an error")
89121
}
90122
// Validate expected
91-
if service.address != defaultResolverAddr {
92-
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
93-
}
123+
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
94124
}
95125

96126
func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) {
@@ -152,3 +182,14 @@ func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
152182
}
153183
return nil, nil
154184
}
185+
186+
func validateAddrs(t *testing.T, expected []netip.AddrPort, actual []netip.AddrPort) {
187+
if len(actual) != len(expected) {
188+
t.Errorf("addresses should only contain one element: %s", actual)
189+
}
190+
for _, e := range expected {
191+
if !slices.Contains(actual, e) {
192+
t.Errorf("missing address: %s in %s", e, actual)
193+
}
194+
}
195+
}

0 commit comments

Comments
 (0)