From 157e54f3d9c00f13c29b245723903b087e4f9fcd Mon Sep 17 00:00:00 2001 From: kunsonxs Date: Sun, 26 Nov 2023 15:59:28 +0800 Subject: [PATCH 1/3] feat : optimize wireguard client kernel mode impls --- .github/docker/Dockerfile | 2 +- go.mod | 5 + go.sum | 15 +- infra/conf/wireguard.go | 26 +- proxy/wireguard/client.go | 9 +- proxy/wireguard/config.go | 2 +- proxy/wireguard/iptables/README.md | 12 + proxy/wireguard/iptables/errors/doc.go | 18 + proxy/wireguard/iptables/errors/errors.go | 249 +++++ .../wireguard/iptables/errors/errors_test.go | 530 +++++++++++ proxy/wireguard/iptables/exec/README.md | 5 + proxy/wireguard/iptables/exec/exec.go | 256 +++++ proxy/wireguard/iptables/exec/fixup_go118.go | 32 + proxy/wireguard/iptables/exec/fixup_go119.go | 40 + proxy/wireguard/iptables/iptables.go | 874 ++++++++++++++++++ proxy/wireguard/iptables/iptables_linux.go | 101 ++ .../iptables/iptables_unsupported.go | 33 + proxy/wireguard/iptables/save_restore.go | 52 ++ proxy/wireguard/iptables/sets/byte.go | 137 +++ proxy/wireguard/iptables/sets/doc.go | 19 + proxy/wireguard/iptables/sets/empty.go | 21 + proxy/wireguard/iptables/sets/int.go | 137 +++ proxy/wireguard/iptables/sets/int32.go | 137 +++ proxy/wireguard/iptables/sets/int64.go | 137 +++ proxy/wireguard/iptables/sets/ordered.go | 53 ++ proxy/wireguard/iptables/sets/set.go | 241 +++++ proxy/wireguard/iptables/sets/string.go | 137 +++ proxy/wireguard/iptables/version/doc.go | 18 + proxy/wireguard/iptables/version/version.go | 371 ++++++++ .../iptables/version/version_test.go | 453 +++++++++ proxy/wireguard/iptables/wait/backoff.go | 500 ++++++++++ proxy/wireguard/iptables/wait/clock/clock.go | 168 ++++ proxy/wireguard/iptables/wait/delay.go | 51 + proxy/wireguard/iptables/wait/error.go | 96 ++ proxy/wireguard/iptables/wait/loop.go | 94 ++ proxy/wireguard/iptables/wait/poll.go | 315 +++++++ proxy/wireguard/iptables/wait/timer.go | 121 +++ proxy/wireguard/iptables/wait/wait.go | 222 +++++ .../{tun_default.go => tun_kernel_default.go} | 8 +- .../{tun_linux.go => tun_kernel_linux.go} | 96 +- proxy/wireguard/wireguard_linux.go | 61 ++ proxy/wireguard/wireguard_others.go | 27 + 42 files changed, 5808 insertions(+), 73 deletions(-) create mode 100644 proxy/wireguard/iptables/README.md create mode 100644 proxy/wireguard/iptables/errors/doc.go create mode 100644 proxy/wireguard/iptables/errors/errors.go create mode 100644 proxy/wireguard/iptables/errors/errors_test.go create mode 100644 proxy/wireguard/iptables/exec/README.md create mode 100644 proxy/wireguard/iptables/exec/exec.go create mode 100644 proxy/wireguard/iptables/exec/fixup_go118.go create mode 100644 proxy/wireguard/iptables/exec/fixup_go119.go create mode 100644 proxy/wireguard/iptables/iptables.go create mode 100644 proxy/wireguard/iptables/iptables_linux.go create mode 100644 proxy/wireguard/iptables/iptables_unsupported.go create mode 100644 proxy/wireguard/iptables/save_restore.go create mode 100644 proxy/wireguard/iptables/sets/byte.go create mode 100644 proxy/wireguard/iptables/sets/doc.go create mode 100644 proxy/wireguard/iptables/sets/empty.go create mode 100644 proxy/wireguard/iptables/sets/int.go create mode 100644 proxy/wireguard/iptables/sets/int32.go create mode 100644 proxy/wireguard/iptables/sets/int64.go create mode 100644 proxy/wireguard/iptables/sets/ordered.go create mode 100644 proxy/wireguard/iptables/sets/set.go create mode 100644 proxy/wireguard/iptables/sets/string.go create mode 100644 proxy/wireguard/iptables/version/doc.go create mode 100644 proxy/wireguard/iptables/version/version.go create mode 100644 proxy/wireguard/iptables/version/version_test.go create mode 100644 proxy/wireguard/iptables/wait/backoff.go create mode 100644 proxy/wireguard/iptables/wait/clock/clock.go create mode 100644 proxy/wireguard/iptables/wait/delay.go create mode 100644 proxy/wireguard/iptables/wait/error.go create mode 100644 proxy/wireguard/iptables/wait/loop.go create mode 100644 proxy/wireguard/iptables/wait/poll.go create mode 100644 proxy/wireguard/iptables/wait/timer.go create mode 100644 proxy/wireguard/iptables/wait/wait.go rename proxy/wireguard/{tun_default.go => tun_kernel_default.go} (58%) rename proxy/wireguard/{tun_linux.go => tun_kernel_linux.go} (73%) create mode 100644 proxy/wireguard/wireguard_linux.go create mode 100644 proxy/wireguard/wireguard_others.go diff --git a/.github/docker/Dockerfile b/.github/docker/Dockerfile index ad1e8c3dc214..8ebc6f0f64a2 100644 --- a/.github/docker/Dockerfile +++ b/.github/docker/Dockerfile @@ -10,7 +10,7 @@ WORKDIR /root COPY .github/docker/files/config.json /etc/xray/config.json COPY --from=build /src/xray /usr/bin/xray RUN set -ex \ - && apk add --no-cache tzdata ca-certificates \ + && apk add --no-cache tzdata ca-certificates iptables \ && mkdir -p /var/log/xray /usr/share/xray \ && chmod +x /usr/bin/xray \ && wget -O /usr/share/xray/geosite.dat https://github.com/Loyalsoldier/v2ray-rules-dat/releases/latest/download/geosite.dat \ diff --git a/go.mod b/go.mod index c0e7bfbef15f..14acc6c1b0be 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( google.golang.org/protobuf v1.31.0 gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b h12.io/socks v1.0.3 + kernel.org/pub/linux/libs/security/libcap/cap v1.2.69 lukechampine.com/blake3 v1.2.1 ) @@ -45,6 +46,8 @@ require ( github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a // indirect github.com/klauspost/compress v1.17.2 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/onsi/ginkgo/v2 v2.13.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect @@ -58,6 +61,8 @@ require ( golang.org/x/tools v0.15.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + kernel.org/pub/linux/libs/security/libcap/psx v1.2.69 // indirect ) diff --git a/go.sum b/go.sum index 1239d8f00e20..ba51a6acc619 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cloudflare/circl v1.3.6 h1:/xbKIqSHbZXHwkhbrhrt2YOHIwYJlXH94E3tI/gDlUg= github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -82,12 +83,12 @@ github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -98,6 +99,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo/v2 v2.13.1 h1:LNGfMbR2OVGBfXjvRZIZ2YCTQdGKtPLvuI1rMCCj3OU= github.com/onsi/ginkgo/v2 v2.13.1/go.mod h1:XStQ8QcGwLyF4HdfcZB8SFOS/MWCgDuXMSBe6zrvLgM= github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= @@ -287,8 +290,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -306,6 +309,10 @@ h12.io/socks v1.0.3/go.mod h1:AIhxy1jOId/XCz9BO+EIgNL2rQiPTBNnOfnVnQ+3Eck= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +kernel.org/pub/linux/libs/security/libcap/cap v1.2.69 h1:N0m3tKYbkRMmDobh/47ngz+AWeV7PcfXMDi8xu3Vrag= +kernel.org/pub/linux/libs/security/libcap/cap v1.2.69/go.mod h1:Tk5Ip2TuxaWGpccL7//rAsLRH6RQ/jfqTGxuN/+i/FQ= +kernel.org/pub/linux/libs/security/libcap/psx v1.2.69 h1:IdrOs1ZgwGw5CI+BH6GgVVlOt+LAXoPyh7enr8lfaXs= +kernel.org/pub/linux/libs/security/libcap/psx v1.2.69/go.mod h1:+l6Ee2F59XiJ2I6WR5ObpC1utCQJZ/VLsEbQCD8RG24= lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI= lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index a4f0eda6e2c2..4a189b54e93d 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -116,19 +116,25 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { return nil, newError("unsupported domain strategy: ", c.DomainStrategy) } + // check device exist for wireguard setup + // module "golang.zx2c4.com/wireguard" only support linux and require /dev/net/tun + if wireguard.IsLinux() && !wireguard.CheckUnixKernelTunDeviceEnabled() { + return nil, newError("wireguard module require device /dev/net/tun") + } + config.IsClient = c.IsClient - if c.KernelMode != nil { - config.KernelMode = *c.KernelMode - if config.KernelMode && !wireguard.KernelTunSupported() { - newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog() - } - } else { - config.KernelMode = wireguard.KernelTunSupported() - if config.KernelMode { - newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog() + if c.IsClient { + if support := wireguard.CheckUnixKernelTunSupported(); c.KernelMode == nil { + config.KernelMode = support + } else if *c.KernelMode && support { + config.KernelMode = true + } else { + config.KernelMode = false } } - + if !c.IsClient { + config.KernelMode = false + } return config, nil } diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index def078783523..e007a3fae19a 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -79,7 +79,6 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { h.wgLock.Lock() defer h.wgLock.Unlock() - if h.bind != nil && h.bind.dialer == dialer && h.net != nil { return nil } @@ -127,6 +126,10 @@ func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + if err := h.processWireGuard(dialer); err != nil { + return err + } + outbound := session.OutboundFromContext(ctx) if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified") @@ -137,10 +140,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte inbound.SetCanSpliceCopy(3) } - if err := h.processWireGuard(dialer); err != nil { - return err - } - // Destination of the inner request. destination := outbound.Target command := protocol.RequestCommandTCP diff --git a/proxy/wireguard/config.go b/proxy/wireguard/config.go index 2a316cdd7df8..9ddf36e55a7a 100644 --- a/proxy/wireguard/config.go +++ b/proxy/wireguard/config.go @@ -25,7 +25,7 @@ func (c *DeviceConfig) fallbackIP6() bool { } func (c *DeviceConfig) createTun() tunCreator { - if c.KernelMode { + if c.IsClient && c.KernelMode { return createKernelTun } return createGVisorTun diff --git a/proxy/wireguard/iptables/README.md b/proxy/wireguard/iptables/README.md new file mode 100644 index 000000000000..3bc26ba17f81 --- /dev/null +++ b/proxy/wireguard/iptables/README.md @@ -0,0 +1,12 @@ +# kubernetes iptables + +source code from: + +| package | from | repo | +|------------------|-------------------------------------|---------------------| +| iptables/errors | k8s.io/apimachinery/pkg/util/errors | k8s.io/apimachinery | +| iptables/exec | k8s.io/utils/exec | k8s.io/utils | +| iptables/sets | k8s.io/apimachinery/pkg/util/sets | k8s.io/apimachinery | +| iptables/version | k8s.io/apimachinery/pkg/version | k8s.io/apimachinery | +| iptables/wait | k8s.io/apimachinery/pkg/util/wait | k8s.io/apimachinery | +| iptables | k8s.io/kubernetes/pkg/util/iptables | k8s.io/kubernetes | \ No newline at end of file diff --git a/proxy/wireguard/iptables/errors/doc.go b/proxy/wireguard/iptables/errors/doc.go new file mode 100644 index 000000000000..5d4d6250a316 --- /dev/null +++ b/proxy/wireguard/iptables/errors/doc.go @@ -0,0 +1,18 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package errors implements various utility functions and types around errors. +package errors // import "k8s.io/apimachinery/pkg/util/errors" diff --git a/proxy/wireguard/iptables/errors/errors.go b/proxy/wireguard/iptables/errors/errors.go new file mode 100644 index 000000000000..b5d3aefadcae --- /dev/null +++ b/proxy/wireguard/iptables/errors/errors.go @@ -0,0 +1,249 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package errors + +import ( + "errors" + "fmt" + + "github.com/xtls/xray-core/proxy/wireguard/iptables/sets" +) + +// MessageCountMap contains occurrence for each error message. +type MessageCountMap map[string]int + +// Aggregate represents an object that contains multiple errors, but does not +// necessarily have singular semantic meaning. +// The aggregate can be used with `errors.Is()` to check for the occurrence of +// a specific error type. +// Errors.As() is not supported, because the caller presumably cares about a +// specific error of potentially multiple that match the given type. +type Aggregate interface { + error + Errors() []error + Is(error) bool +} + +// NewAggregate converts a slice of errors into an Aggregate interface, which +// is itself an implementation of the error interface. If the slice is empty, +// this returns nil. +// It will check if any of the element of input error list is nil, to avoid +// nil pointer panic when call Error(). +func NewAggregate(errlist []error) Aggregate { + if len(errlist) == 0 { + return nil + } + // In case of input error list contains nil + var errs []error + for _, e := range errlist { + if e != nil { + errs = append(errs, e) + } + } + if len(errs) == 0 { + return nil + } + return aggregate(errs) +} + +// This helper implements the error and Errors interfaces. Keeping it private +// prevents people from making an aggregate of 0 errors, which is not +// an error, but does satisfy the error interface. +type aggregate []error + +// Error is part of the error interface. +func (agg aggregate) Error() string { + if len(agg) == 0 { + // This should never happen, really. + return "" + } + if len(agg) == 1 { + return agg[0].Error() + } + seenerrs := sets.NewString() + result := "" + agg.visit(func(err error) bool { + msg := err.Error() + if seenerrs.Has(msg) { + return false + } + seenerrs.Insert(msg) + if len(seenerrs) > 1 { + result += ", " + } + result += msg + return false + }) + if len(seenerrs) == 1 { + return result + } + return "[" + result + "]" +} + +func (agg aggregate) Is(target error) bool { + return agg.visit(func(err error) bool { + return errors.Is(err, target) + }) +} + +func (agg aggregate) visit(f func(err error) bool) bool { + for _, err := range agg { + switch err := err.(type) { + case aggregate: + if match := err.visit(f); match { + return match + } + case Aggregate: + for _, nestedErr := range err.Errors() { + if match := f(nestedErr); match { + return match + } + } + default: + if match := f(err); match { + return match + } + } + } + + return false +} + +// Errors is part of the Aggregate interface. +func (agg aggregate) Errors() []error { + return []error(agg) +} + +// Matcher is used to match errors. Returns true if the error matches. +type Matcher func(error) bool + +// FilterOut removes all errors that match any of the matchers from the input +// error. If the input is a singular error, only that error is tested. If the +// input implements the Aggregate interface, the list of errors will be +// processed recursively. +// +// This can be used, for example, to remove known-OK errors (such as io.EOF or +// os.PathNotFound) from a list of errors. +func FilterOut(err error, fns ...Matcher) error { + if err == nil { + return nil + } + if agg, ok := err.(Aggregate); ok { + return NewAggregate(filterErrors(agg.Errors(), fns...)) + } + if !matchesError(err, fns...) { + return err + } + return nil +} + +// matchesError returns true if any Matcher returns true +func matchesError(err error, fns ...Matcher) bool { + for _, fn := range fns { + if fn(err) { + return true + } + } + return false +} + +// filterErrors returns any errors (or nested errors, if the list contains +// nested Errors) for which all fns return false. If no errors +// remain a nil list is returned. The resulting slice will have all +// nested slices flattened as a side effect. +func filterErrors(list []error, fns ...Matcher) []error { + result := []error{} + for _, err := range list { + r := FilterOut(err, fns...) + if r != nil { + result = append(result, r) + } + } + return result +} + +// Flatten takes an Aggregate, which may hold other Aggregates in arbitrary +// nesting, and flattens them all into a single Aggregate, recursively. +func Flatten(agg Aggregate) Aggregate { + result := []error{} + if agg == nil { + return nil + } + for _, err := range agg.Errors() { + if a, ok := err.(Aggregate); ok { + r := Flatten(a) + if r != nil { + result = append(result, r.Errors()...) + } + } else { + if err != nil { + result = append(result, err) + } + } + } + return NewAggregate(result) +} + +// CreateAggregateFromMessageCountMap converts MessageCountMap Aggregate +func CreateAggregateFromMessageCountMap(m MessageCountMap) Aggregate { + if m == nil { + return nil + } + result := make([]error, 0, len(m)) + for errStr, count := range m { + var countStr string + if count > 1 { + countStr = fmt.Sprintf(" (repeated %v times)", count) + } + result = append(result, fmt.Errorf("%v%v", errStr, countStr)) + } + return NewAggregate(result) +} + +// Reduce will return err or nil, if err is an Aggregate and only has one item, +// the first item in the aggregate. +func Reduce(err error) error { + if agg, ok := err.(Aggregate); ok && err != nil { + switch len(agg.Errors()) { + case 1: + return agg.Errors()[0] + case 0: + return nil + } + } + return err +} + +// AggregateGoroutines runs the provided functions in parallel, stuffing all +// non-nil errors into the returned Aggregate. +// Returns nil if all the functions complete successfully. +func AggregateGoroutines(funcs ...func() error) Aggregate { + errChan := make(chan error, len(funcs)) + for _, f := range funcs { + go func(f func() error) { errChan <- f() }(f) + } + errs := make([]error, 0) + for i := 0; i < cap(errChan); i++ { + if err := <-errChan; err != nil { + errs = append(errs, err) + } + } + return NewAggregate(errs) +} + +// ErrPreconditionViolated is returned when the precondition is violated +var ErrPreconditionViolated = errors.New("precondition is violated") diff --git a/proxy/wireguard/iptables/errors/errors_test.go b/proxy/wireguard/iptables/errors/errors_test.go new file mode 100644 index 000000000000..6659fbc6b359 --- /dev/null +++ b/proxy/wireguard/iptables/errors/errors_test.go @@ -0,0 +1,530 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package errors + +import ( + "errors" + "fmt" + "reflect" + "sort" + "testing" +) + +func TestEmptyAggregate(t *testing.T) { + var slice []error + var agg Aggregate + var err error + + agg = NewAggregate(slice) + if agg != nil { + t.Errorf("expected nil, got %#v", agg) + } + err = NewAggregate(slice) + if err != nil { + t.Errorf("expected nil, got %#v", err) + } + + // This is not normally possible, but pedantry demands I test it. + agg = aggregate(slice) // empty aggregate + if s := agg.Error(); s != "" { + t.Errorf("expected empty string, got %q", s) + } + if s := agg.Errors(); len(s) != 0 { + t.Errorf("expected empty slice, got %#v", s) + } + err = agg.(error) + if s := err.Error(); s != "" { + t.Errorf("expected empty string, got %q", s) + } +} + +func TestAggregateWithNil(t *testing.T) { + var slice []error + slice = []error{nil} + var agg Aggregate + var err error + + agg = NewAggregate(slice) + if agg != nil { + t.Errorf("expected nil, got %#v", agg) + } + err = NewAggregate(slice) + if err != nil { + t.Errorf("expected nil, got %#v", err) + } + + // Append a non-nil error + slice = append(slice, fmt.Errorf("err")) + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } + if s := agg.Errors(); len(s) != 1 { + t.Errorf("expected one-element slice, got %#v", s) + } + if s := agg.Errors()[0].Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } + + err = agg.(error) + if err == nil { + t.Errorf("expected non-nil") + } + if s := err.Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } +} + +func TestSingularAggregate(t *testing.T) { + slice := []error{fmt.Errorf("err")} + var agg Aggregate + var err error + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } + if s := agg.Errors(); len(s) != 1 { + t.Errorf("expected one-element slice, got %#v", s) + } + if s := agg.Errors()[0].Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } + + err = agg.(error) + if err == nil { + t.Errorf("expected non-nil") + } + if s := err.Error(); s != "err" { + t.Errorf("expected 'err', got %q", s) + } +} + +func TestPluralAggregate(t *testing.T) { + slice := []error{fmt.Errorf("abc"), fmt.Errorf("123")} + var agg Aggregate + var err error + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "[abc, 123]" { + t.Errorf("expected '[abc, 123]', got %q", s) + } + if s := agg.Errors(); len(s) != 2 { + t.Errorf("expected two-elements slice, got %#v", s) + } + if s := agg.Errors()[0].Error(); s != "abc" { + t.Errorf("expected '[abc, 123]', got %q", s) + } + + err = agg.(error) + if err == nil { + t.Errorf("expected non-nil") + } + if s := err.Error(); s != "[abc, 123]" { + t.Errorf("expected '[abc, 123]', got %q", s) + } +} + +func TestDedupeAggregate(t *testing.T) { + slice := []error{fmt.Errorf("abc"), fmt.Errorf("abc")} + var agg Aggregate + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "abc" { + t.Errorf("expected 'abc', got %q", s) + } + if s := agg.Errors(); len(s) != 2 { + t.Errorf("expected two-elements slice, got %#v", s) + } +} + +func TestDedupePluralAggregate(t *testing.T) { + slice := []error{fmt.Errorf("abc"), fmt.Errorf("abc"), fmt.Errorf("123")} + var agg Aggregate + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "[abc, 123]" { + t.Errorf("expected '[abc, 123]', got %q", s) + } + if s := agg.Errors(); len(s) != 3 { + t.Errorf("expected three-elements slice, got %#v", s) + } +} + +func TestFlattenAndDedupeAggregate(t *testing.T) { + slice := []error{fmt.Errorf("abc"), fmt.Errorf("abc"), NewAggregate([]error{fmt.Errorf("abc")})} + var agg Aggregate + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "abc" { + t.Errorf("expected 'abc', got %q", s) + } + if s := agg.Errors(); len(s) != 3 { + t.Errorf("expected three-elements slice, got %#v", s) + } +} + +func TestFlattenAggregate(t *testing.T) { + slice := []error{fmt.Errorf("abc"), fmt.Errorf("abc"), NewAggregate([]error{fmt.Errorf("abc"), fmt.Errorf("def"), NewAggregate([]error{fmt.Errorf("def"), fmt.Errorf("ghi")})})} + var agg Aggregate + + agg = NewAggregate(slice) + if agg == nil { + t.Errorf("expected non-nil") + } + if s := agg.Error(); s != "[abc, def, ghi]" { + t.Errorf("expected '[abc, def, ghi]', got %q", s) + } + if s := agg.Errors(); len(s) != 3 { + t.Errorf("expected three-elements slice, got %#v", s) + } +} + +func TestFilterOut(t *testing.T) { + testCases := []struct { + err error + filter []Matcher + expected error + }{ + { + nil, + []Matcher{}, + nil, + }, + { + aggregate{}, + []Matcher{}, + nil, + }, + { + aggregate{fmt.Errorf("abc")}, + []Matcher{}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{fmt.Errorf("abc")}, + []Matcher{func(err error) bool { return false }}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{fmt.Errorf("abc")}, + []Matcher{func(err error) bool { return true }}, + nil, + }, + { + aggregate{fmt.Errorf("abc")}, + []Matcher{func(err error) bool { return false }, func(err error) bool { return false }}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{fmt.Errorf("abc")}, + []Matcher{func(err error) bool { return false }, func(err error) bool { return true }}, + nil, + }, + { + aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")}, + []Matcher{func(err error) bool { return err.Error() == "def" }}, + aggregate{fmt.Errorf("abc"), fmt.Errorf("ghi")}, + }, + { + aggregate{aggregate{fmt.Errorf("abc")}}, + []Matcher{}, + aggregate{aggregate{fmt.Errorf("abc")}}, + }, + { + aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}}, + []Matcher{}, + aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}}, + }, + { + aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}}, + []Matcher{func(err error) bool { return err.Error() == "def" }}, + aggregate{aggregate{fmt.Errorf("abc")}}, + }, + } + for i, testCase := range testCases { + err := FilterOut(testCase.err, testCase.filter...) + if !reflect.DeepEqual(testCase.expected, err) { + t.Errorf("%d: expected %v, got %v", i, testCase.expected, err) + } + } +} + +func TestFlatten(t *testing.T) { + testCases := []struct { + agg Aggregate + expected Aggregate + }{ + { + nil, + nil, + }, + { + aggregate{}, + nil, + }, + { + aggregate{fmt.Errorf("abc")}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")}, + aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")}, + }, + { + aggregate{aggregate{fmt.Errorf("abc")}}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{aggregate{aggregate{fmt.Errorf("abc")}}}, + aggregate{fmt.Errorf("abc")}, + }, + { + aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}}, + aggregate{fmt.Errorf("abc"), fmt.Errorf("def")}, + }, + { + aggregate{aggregate{aggregate{fmt.Errorf("abc")}, fmt.Errorf("def"), aggregate{fmt.Errorf("ghi")}}}, + aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")}, + }, + } + for i, testCase := range testCases { + agg := Flatten(testCase.agg) + if !reflect.DeepEqual(testCase.expected, agg) { + t.Errorf("%d: expected %v, got %v", i, testCase.expected, agg) + } + } +} + +func TestCreateAggregateFromMessageCountMap(t *testing.T) { + testCases := []struct { + name string + mcm MessageCountMap + expected Aggregate + }{ + { + "input has single instance of one message", + MessageCountMap{"abc": 1}, + aggregate{fmt.Errorf("abc")}, + }, + { + "input has multiple messages", + MessageCountMap{"abc": 2, "ghi": 1}, + aggregate{fmt.Errorf("abc (repeated 2 times)"), fmt.Errorf("ghi")}, + }, + { + "input has multiple messages", + MessageCountMap{"ghi": 1, "abc": 2}, + aggregate{fmt.Errorf("abc (repeated 2 times)"), fmt.Errorf("ghi")}, + }, + } + + var expected, agg []error + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + if testCase.expected != nil { + expected = testCase.expected.Errors() + sort.Slice(expected, func(i, j int) bool { return expected[i].Error() < expected[j].Error() }) + } + if testCase.mcm != nil { + agg = CreateAggregateFromMessageCountMap(testCase.mcm).Errors() + sort.Slice(agg, func(i, j int) bool { return agg[i].Error() < agg[j].Error() }) + } + if !reflect.DeepEqual(expected, agg) { + t.Errorf("expected %v, got %v", expected, agg) + } + }) + } +} + +func TestAggregateGoroutines(t *testing.T) { + testCases := []struct { + errs []error + expected map[string]bool // can't compare directly to Aggregate due to non-deterministic ordering + }{ + { + []error{}, + nil, + }, + { + []error{nil}, + nil, + }, + { + []error{nil, nil}, + nil, + }, + { + []error{fmt.Errorf("1")}, + map[string]bool{"1": true}, + }, + { + []error{fmt.Errorf("1"), nil}, + map[string]bool{"1": true}, + }, + { + []error{fmt.Errorf("1"), fmt.Errorf("267")}, + map[string]bool{"1": true, "267": true}, + }, + { + []error{fmt.Errorf("1"), nil, fmt.Errorf("1234")}, + map[string]bool{"1": true, "1234": true}, + }, + { + []error{nil, fmt.Errorf("1"), nil, fmt.Errorf("1234"), fmt.Errorf("22")}, + map[string]bool{"1": true, "1234": true, "22": true}, + }, + } + for i, testCase := range testCases { + funcs := make([]func() error, len(testCase.errs)) + for i := range testCase.errs { + err := testCase.errs[i] + funcs[i] = func() error { return err } + } + agg := AggregateGoroutines(funcs...) + if agg == nil { + if len(testCase.expected) > 0 { + t.Errorf("%d: expected %v, got nil", i, testCase.expected) + } + continue + } + if len(agg.Errors()) != len(testCase.expected) { + t.Errorf("%d: expected %d errors in aggregate, got %v", i, len(testCase.expected), agg) + continue + } + for _, err := range agg.Errors() { + if !testCase.expected[err.Error()] { + t.Errorf("%d: expected %v, got aggregate containing %v", i, testCase.expected, err) + } + } + } +} + +type alwaysMatchingError struct{} + +func (_ alwaysMatchingError) Error() string { + return "error" +} + +func (_ alwaysMatchingError) Is(_ error) bool { + return true +} + +type someError struct{ msg string } + +func (se someError) Error() string { + if se.msg != "" { + return se.msg + } + return "err" +} + +func TestAggregateWithErrorsIs(t *testing.T) { + testCases := []struct { + name string + err error + matchAgainst error + expectMatch bool + }{ + { + name: "no match", + err: aggregate{errors.New("my-error"), errors.New("my-other-error")}, + matchAgainst: fmt.Errorf("no entry %s", "here"), + }, + { + name: "match via .Is()", + err: aggregate{errors.New("forbidden"), alwaysMatchingError{}}, + matchAgainst: errors.New("unauthorized"), + expectMatch: true, + }, + { + name: "match via equality", + err: aggregate{errors.New("err"), someError{}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via nested aggregate", + err: aggregate{errors.New("closed today"), aggregate{aggregate{someError{}}}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via wrapped aggregate", + err: fmt.Errorf("wrap: %w", aggregate{errors.New("err"), someError{}}), + matchAgainst: someError{}, + expectMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := errors.Is(tc.err, tc.matchAgainst) + if result != tc.expectMatch { + t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result) + } + }) + } +} + +type accessTrackingError struct { + wasAccessed bool +} + +func (accessTrackingError) Error() string { + return "err" +} + +func (ate *accessTrackingError) Is(_ error) bool { + ate.wasAccessed = true + return true +} + +var _ error = &accessTrackingError{} + +func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) { + errC := aggregate{&accessTrackingError{}, &accessTrackingError{}} + _ = errors.Is(errC, &accessTrackingError{}) + + var numAccessed int + for _, err := range errC { + if ate := err.(*accessTrackingError); ate.wasAccessed { + numAccessed++ + } + } + if numAccessed != 1 { + t.Errorf("expected exactly one error to get accessed, got %d", numAccessed) + } +} diff --git a/proxy/wireguard/iptables/exec/README.md b/proxy/wireguard/iptables/exec/README.md new file mode 100644 index 000000000000..7944e8dd3be4 --- /dev/null +++ b/proxy/wireguard/iptables/exec/README.md @@ -0,0 +1,5 @@ +# Exec + +This package provides an interface for `os/exec`. It makes it easier to mock +and replace in tests, especially with the [FakeExec](testing/fake_exec.go) +struct. diff --git a/proxy/wireguard/iptables/exec/exec.go b/proxy/wireguard/iptables/exec/exec.go new file mode 100644 index 000000000000..d9c91e3ca3c6 --- /dev/null +++ b/proxy/wireguard/iptables/exec/exec.go @@ -0,0 +1,256 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec + +import ( + "context" + "io" + "io/fs" + osexec "os/exec" + "syscall" + "time" +) + +// ErrExecutableNotFound is returned if the executable is not found. +var ErrExecutableNotFound = osexec.ErrNotFound + +// Interface is an interface that presents a subset of the os/exec API. Use this +// when you want to inject fakeable/mockable exec behavior. +type Interface interface { + // Command returns a Cmd instance which can be used to run a single command. + // This follows the pattern of package os/exec. + Command(cmd string, args ...string) Cmd + + // CommandContext returns a Cmd instance which can be used to run a single command. + // + // The provided context is used to kill the process if the context becomes done + // before the command completes on its own. For example, a timeout can be set in + // the context. + CommandContext(ctx context.Context, cmd string, args ...string) Cmd + + // LookPath wraps os/exec.LookPath + LookPath(file string) (string, error) +} + +// Cmd is an interface that presents an API that is very similar to Cmd from os/exec. +// As more functionality is needed, this can grow. Since Cmd is a struct, we will have +// to replace fields with get/set method pairs. +type Cmd interface { + // Run runs the command to the completion. + Run() error + // CombinedOutput runs the command and returns its combined standard output + // and standard error. This follows the pattern of package os/exec. + CombinedOutput() ([]byte, error) + // Output runs the command and returns standard output, but not standard err + Output() ([]byte, error) + SetDir(dir string) + SetStdin(in io.Reader) + SetStdout(out io.Writer) + SetStderr(out io.Writer) + SetEnv(env []string) + + // StdoutPipe and StderrPipe for getting the process' Stdout and Stderr as + // Readers + StdoutPipe() (io.ReadCloser, error) + StderrPipe() (io.ReadCloser, error) + + // Start and Wait are for running a process non-blocking + Start() error + Wait() error + + // Stops the command by sending SIGTERM. It is not guaranteed the + // process will stop before this function returns. If the process is not + // responding, an internal timer function will send a SIGKILL to force + // terminate after 10 seconds. + Stop() +} + +// ExitError is an interface that presents an API similar to os.ProcessState, which is +// what ExitError from os/exec is. This is designed to make testing a bit easier and +// probably loses some of the cross-platform properties of the underlying library. +type ExitError interface { + String() string + Error() string + Exited() bool + ExitStatus() int +} + +// Implements Interface in terms of really exec()ing. +type executor struct{} + +// New returns a new Interface which will os/exec to run commands. +func New() Interface { + return &executor{} +} + +// Command is part of the Interface interface. +func (executor *executor) Command(cmd string, args ...string) Cmd { + return (*cmdWrapper)(maskErrDotCmd(osexec.Command(cmd, args...))) +} + +// CommandContext is part of the Interface interface. +func (executor *executor) CommandContext(ctx context.Context, cmd string, args ...string) Cmd { + return (*cmdWrapper)(maskErrDotCmd(osexec.CommandContext(ctx, cmd, args...))) +} + +// LookPath is part of the Interface interface +func (executor *executor) LookPath(file string) (string, error) { + path, err := osexec.LookPath(file) + return path, handleError(maskErrDot(err)) +} + +// Wraps exec.Cmd so we can capture errors. +type cmdWrapper osexec.Cmd + +var _ Cmd = &cmdWrapper{} + +func (cmd *cmdWrapper) SetDir(dir string) { + cmd.Dir = dir +} + +func (cmd *cmdWrapper) SetStdin(in io.Reader) { + cmd.Stdin = in +} + +func (cmd *cmdWrapper) SetStdout(out io.Writer) { + cmd.Stdout = out +} + +func (cmd *cmdWrapper) SetStderr(out io.Writer) { + cmd.Stderr = out +} + +func (cmd *cmdWrapper) SetEnv(env []string) { + cmd.Env = env +} + +func (cmd *cmdWrapper) StdoutPipe() (io.ReadCloser, error) { + r, err := (*osexec.Cmd)(cmd).StdoutPipe() + return r, handleError(err) +} + +func (cmd *cmdWrapper) StderrPipe() (io.ReadCloser, error) { + r, err := (*osexec.Cmd)(cmd).StderrPipe() + return r, handleError(err) +} + +func (cmd *cmdWrapper) Start() error { + err := (*osexec.Cmd)(cmd).Start() + return handleError(err) +} + +func (cmd *cmdWrapper) Wait() error { + err := (*osexec.Cmd)(cmd).Wait() + return handleError(err) +} + +// Run is part of the Cmd interface. +func (cmd *cmdWrapper) Run() error { + err := (*osexec.Cmd)(cmd).Run() + return handleError(err) +} + +// CombinedOutput is part of the Cmd interface. +func (cmd *cmdWrapper) CombinedOutput() ([]byte, error) { + out, err := (*osexec.Cmd)(cmd).CombinedOutput() + return out, handleError(err) +} + +func (cmd *cmdWrapper) Output() ([]byte, error) { + out, err := (*osexec.Cmd)(cmd).Output() + return out, handleError(err) +} + +// Stop is part of the Cmd interface. +func (cmd *cmdWrapper) Stop() { + c := (*osexec.Cmd)(cmd) + + if c.Process == nil { + return + } + + c.Process.Signal(syscall.SIGTERM) + + time.AfterFunc(10*time.Second, func() { + if !c.ProcessState.Exited() { + c.Process.Signal(syscall.SIGKILL) + } + }) +} + +func handleError(err error) error { + if err == nil { + return nil + } + + switch e := err.(type) { + case *osexec.ExitError: + return &ExitErrorWrapper{e} + case *fs.PathError: + return ErrExecutableNotFound + case *osexec.Error: + if e.Err == osexec.ErrNotFound { + return ErrExecutableNotFound + } + } + + return err +} + +// ExitErrorWrapper is an implementation of ExitError in terms of os/exec ExitError. +// Note: standard exec.ExitError is type *os.ProcessState, which already implements Exited(). +type ExitErrorWrapper struct { + *osexec.ExitError +} + +var _ ExitError = &ExitErrorWrapper{} + +// ExitStatus is part of the ExitError interface. +func (eew ExitErrorWrapper) ExitStatus() int { + ws, ok := eew.Sys().(syscall.WaitStatus) + if !ok { + panic("can't call ExitStatus() on a non-WaitStatus exitErrorWrapper") + } + return ws.ExitStatus() +} + +// CodeExitError is an implementation of ExitError consisting of an error object +// and an exit code (the upper bits of os.exec.ExitStatus). +type CodeExitError struct { + Err error + Code int +} + +var _ ExitError = CodeExitError{} + +func (e CodeExitError) Error() string { + return e.Err.Error() +} + +func (e CodeExitError) String() string { + return e.Err.Error() +} + +// Exited is to check if the process has finished +func (e CodeExitError) Exited() bool { + return true +} + +// ExitStatus is for checking the error code +func (e CodeExitError) ExitStatus() int { + return e.Code +} diff --git a/proxy/wireguard/iptables/exec/fixup_go118.go b/proxy/wireguard/iptables/exec/fixup_go118.go new file mode 100644 index 000000000000..acf45f1cd5b4 --- /dev/null +++ b/proxy/wireguard/iptables/exec/fixup_go118.go @@ -0,0 +1,32 @@ +//go:build !go1.19 +// +build !go1.19 + +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec + +import ( + osexec "os/exec" +) + +func maskErrDotCmd(cmd *osexec.Cmd) *osexec.Cmd { + return cmd +} + +func maskErrDot(err error) error { + return err +} diff --git a/proxy/wireguard/iptables/exec/fixup_go119.go b/proxy/wireguard/iptables/exec/fixup_go119.go new file mode 100644 index 000000000000..55874c9297e3 --- /dev/null +++ b/proxy/wireguard/iptables/exec/fixup_go119.go @@ -0,0 +1,40 @@ +//go:build go1.19 +// +build go1.19 + +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec + +import ( + "errors" + osexec "os/exec" +) + +// maskErrDotCmd reverts the behavior of osexec.Cmd to what it was before go1.19 +// specifically set the Err field to nil (LookPath returns a new error when the file +// is resolved to the current directory. +func maskErrDotCmd(cmd *osexec.Cmd) *osexec.Cmd { + cmd.Err = maskErrDot(cmd.Err) + return cmd +} + +func maskErrDot(err error) error { + if err != nil && errors.Is(err, osexec.ErrDot) { + return nil + } + return err +} diff --git a/proxy/wireguard/iptables/iptables.go b/proxy/wireguard/iptables/iptables.go new file mode 100644 index 000000000000..26e4c7f2e837 --- /dev/null +++ b/proxy/wireguard/iptables/iptables.go @@ -0,0 +1,874 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptables + +import ( + "bufio" + "bytes" + "context" + "fmt" + "regexp" + "strconv" + "strings" + "sync" + "time" + + utilexec "github.com/xtls/xray-core/proxy/wireguard/iptables/exec" + "github.com/xtls/xray-core/proxy/wireguard/iptables/sets" + utilversion "github.com/xtls/xray-core/proxy/wireguard/iptables/version" + utilwait "github.com/xtls/xray-core/proxy/wireguard/iptables/wait" + // "k8s.io/klog/v2" + // utiltrace "k8s.io/utils/trace" +) + +// RulePosition holds the -I/-A flags for iptable +type RulePosition string + +const ( + // Prepend is the insert flag for iptable + Prepend RulePosition = "-I" + // Append is the append flag for iptable + Append RulePosition = "-A" +) + +// Interface is an injectable interface for running iptables commands. Implementations must be goroutine-safe. +type Interface interface { + // EnsureChain checks if the specified chain exists and, if not, creates it. If the chain existed, return true. + EnsureChain(table Table, chain Chain) (bool, error) + // FlushChain clears the specified chain. If the chain did not exist, return error. + FlushChain(table Table, chain Chain) error + // DeleteChain deletes the specified chain. If the chain did not exist, return error. + DeleteChain(table Table, chain Chain) error + // ChainExists tests whether the specified chain exists, returning an error if it + // does not, or if it is unable to check. + ChainExists(table Table, chain Chain) (bool, error) + // EnsureRule checks if the specified rule is present and, if not, creates it. If the rule existed, return true. + EnsureRule(position RulePosition, table Table, chain Chain, args ...string) (bool, error) + // DeleteRule checks if the specified rule is present and, if so, deletes it. + DeleteRule(table Table, chain Chain, args ...string) error + // IsIPv6 returns true if this is managing ipv6 tables. + IsIPv6() bool + // Protocol returns the IP family this instance is managing, + Protocol() Protocol + // SaveInto calls `iptables-save` for table and stores result in a given buffer. + SaveInto(table Table, buffer *bytes.Buffer) error + // Restore runs `iptables-restore` passing data through []byte. + // table is the Table to restore + // data should be formatted like the output of SaveInto() + // flush sets the presence of the "--noflush" flag. see: FlushFlag + // counters sets the "--counters" flag. see: RestoreCountersFlag + Restore(table Table, data []byte, flush FlushFlag, counters RestoreCountersFlag) error + // RestoreAll is the same as Restore except that no table is specified. + RestoreAll(data []byte, flush FlushFlag, counters RestoreCountersFlag) error + // Monitor detects when the given iptables tables have been flushed by an external + // tool (e.g. a firewall reload) by creating canary chains and polling to see if + // they have been deleted. (Specifically, it polls tables[0] every interval until + // the canary has been deleted from there, then waits a short additional time for + // the canaries to be deleted from the remaining tables as well. You can optimize + // the polling by listing a relatively empty table in tables[0]). When a flush is + // detected, this calls the reloadFunc so the caller can reload their own iptables + // rules. If it is unable to create the canary chains (either initially or after + // a reload) it will log an error and stop monitoring. + // (This function should be called from a goroutine.) + Monitor(canary Chain, tables []Table, reloadFunc func(), interval time.Duration, stopCh <-chan struct{}) + // HasRandomFully reveals whether `-j MASQUERADE` takes the + // `--random-fully` option. This is helpful to work around a + // Linux kernel bug that sometimes causes multiple flows to get + // mapped to the same IP:PORT and consequently some suffer packet + // drops. + HasRandomFully() bool + + // Present checks if the kernel supports the iptable interface + Present() bool +} + +// Protocol defines the ip protocol either ipv4 or ipv6 +type Protocol string + +const ( + // ProtocolIPv4 represents ipv4 protocol in iptables + ProtocolIPv4 Protocol = "IPv4" + // ProtocolIPv6 represents ipv6 protocol in iptables + ProtocolIPv6 Protocol = "IPv6" +) + +// Table represents different iptable like filter,nat, mangle and raw +type Table string + +const ( + // TableNAT represents the built-in nat table + TableNAT Table = "nat" + // TableFilter represents the built-in filter table + TableFilter Table = "filter" + // TableMangle represents the built-in mangle table + TableMangle Table = "mangle" +) + +// Chain represents the different rules +type Chain string + +const ( + // ChainPostrouting used for source NAT in nat table + ChainPostrouting Chain = "POSTROUTING" + // ChainPrerouting used for DNAT (destination NAT) in nat table + ChainPrerouting Chain = "PREROUTING" + // ChainOutput used for the packets going out from local + ChainOutput Chain = "OUTPUT" + // ChainInput used for incoming packets + ChainInput Chain = "INPUT" + // ChainForward used for the packets for another NIC + ChainForward Chain = "FORWARD" +) + +const ( + cmdIPTablesSave string = "iptables-save" + cmdIPTablesRestore string = "iptables-restore" + cmdIPTables string = "iptables" + cmdIP6TablesRestore string = "ip6tables-restore" + cmdIP6TablesSave string = "ip6tables-save" + cmdIP6Tables string = "ip6tables" +) + +// RestoreCountersFlag is an option flag for Restore +type RestoreCountersFlag bool + +// RestoreCounters a boolean true constant for the option flag RestoreCountersFlag +const RestoreCounters RestoreCountersFlag = true + +// NoRestoreCounters a boolean false constant for the option flag RestoreCountersFlag +const NoRestoreCounters RestoreCountersFlag = false + +// FlushFlag an option flag for Flush +type FlushFlag bool + +// FlushTables a boolean true constant for option flag FlushFlag +const FlushTables FlushFlag = true + +// NoFlushTables a boolean false constant for option flag FlushFlag +const NoFlushTables FlushFlag = false + +// MinCheckVersion minimum version to be checked +// Versions of iptables less than this do not support the -C / --check flag +// (test whether a rule exists). +var MinCheckVersion = utilversion.MustParseGeneric("1.4.11") + +// RandomFullyMinVersion is the minimum version from which the --random-fully flag is supported, +// used for port mapping to be fully randomized +var RandomFullyMinVersion = utilversion.MustParseGeneric("1.6.2") + +// WaitMinVersion a minimum iptables versions supporting the -w and -w flags +var WaitMinVersion = utilversion.MustParseGeneric("1.4.20") + +// WaitIntervalMinVersion a minimum iptables versions supporting the wait interval useconds +var WaitIntervalMinVersion = utilversion.MustParseGeneric("1.6.1") + +// WaitSecondsMinVersion a minimum iptables versions supporting the wait seconds +var WaitSecondsMinVersion = utilversion.MustParseGeneric("1.4.22") + +// WaitRestoreMinVersion a minimum iptables versions supporting the wait restore seconds +var WaitRestoreMinVersion = utilversion.MustParseGeneric("1.6.2") + +// WaitString a constant for specifying the wait flag +const WaitString = "-w" + +// WaitSecondsValue a constant for specifying the default wait seconds +const WaitSecondsValue = "5" + +// WaitIntervalString a constant for specifying the wait interval flag +const WaitIntervalString = "-W" + +// WaitIntervalUsecondsValue a constant for specifying the default wait interval useconds +const WaitIntervalUsecondsValue = "100000" + +// LockfilePath16x is the iptables 1.6.x lock file acquired by any process that's making any change in the iptable rule +const LockfilePath16x = "/run/xtables.lock" + +// LockfilePath14x is the iptables 1.4.x lock file acquired by any process that's making any change in the iptable rule +const LockfilePath14x = "@xtables" + +// runner implements Interface in terms of exec("iptables"). +type runner struct { + mu sync.Mutex + exec utilexec.Interface + protocol Protocol + hasCheck bool + hasRandomFully bool + waitFlag []string + restoreWaitFlag []string + lockfilePath14x string + lockfilePath16x string +} + +// newInternal returns a new Interface which will exec iptables, and allows the +// caller to change the iptables-restore lockfile path +func newInternal(exec utilexec.Interface, protocol Protocol, lockfilePath14x, lockfilePath16x string) Interface { + version, err := getIPTablesVersion(exec, protocol) + if err != nil { + // klog.InfoS("Error checking iptables version, assuming version at least", "version", MinCheckVersion, "err", err) + version = MinCheckVersion + } + + if lockfilePath16x == "" { + lockfilePath16x = LockfilePath16x + } + if lockfilePath14x == "" { + lockfilePath14x = LockfilePath14x + } + + runner := &runner{ + exec: exec, + protocol: protocol, + hasCheck: version.AtLeast(MinCheckVersion), + hasRandomFully: version.AtLeast(RandomFullyMinVersion), + waitFlag: getIPTablesWaitFlag(version), + restoreWaitFlag: getIPTablesRestoreWaitFlag(version, exec, protocol), + lockfilePath14x: lockfilePath14x, + lockfilePath16x: lockfilePath16x, + } + return runner +} + +// New returns a new Interface which will exec iptables. +func New(exec utilexec.Interface, protocol Protocol) Interface { + return newInternal(exec, protocol, "", "") +} + +// EnsureChain is part of Interface. +func (runner *runner) EnsureChain(table Table, chain Chain) (bool, error) { + fullArgs := makeFullArgs(table, chain) + + runner.mu.Lock() + defer runner.mu.Unlock() + + out, err := runner.run(opCreateChain, fullArgs) + if err != nil { + if ee, ok := err.(utilexec.ExitError); ok { + if ee.Exited() && ee.ExitStatus() == 1 { + return true, nil + } + } + return false, fmt.Errorf("error creating chain %q: %v: %s", chain, err, out) + } + return false, nil +} + +// FlushChain is part of Interface. +func (runner *runner) FlushChain(table Table, chain Chain) error { + fullArgs := makeFullArgs(table, chain) + + runner.mu.Lock() + defer runner.mu.Unlock() + + out, err := runner.run(opFlushChain, fullArgs) + if err != nil { + return fmt.Errorf("error flushing chain %q: %v: %s", chain, err, out) + } + return nil +} + +// DeleteChain is part of Interface. +func (runner *runner) DeleteChain(table Table, chain Chain) error { + fullArgs := makeFullArgs(table, chain) + + runner.mu.Lock() + defer runner.mu.Unlock() + + out, err := runner.run(opDeleteChain, fullArgs) + if err != nil { + return fmt.Errorf("error deleting chain %q: %v: %s", chain, err, out) + } + return nil +} + +// EnsureRule is part of Interface. +func (runner *runner) EnsureRule(position RulePosition, table Table, chain Chain, args ...string) (bool, error) { + fullArgs := makeFullArgs(table, chain, args...) + + runner.mu.Lock() + defer runner.mu.Unlock() + + exists, err := runner.checkRule(table, chain, args...) + if err != nil { + return false, err + } + if exists { + return true, nil + } + out, err := runner.run(operation(position), fullArgs) + if err != nil { + return false, fmt.Errorf("error appending rule: %v: %s", err, out) + } + return false, nil +} + +// DeleteRule is part of Interface. +func (runner *runner) DeleteRule(table Table, chain Chain, args ...string) error { + fullArgs := makeFullArgs(table, chain, args...) + + runner.mu.Lock() + defer runner.mu.Unlock() + + exists, err := runner.checkRule(table, chain, args...) + if err != nil { + return err + } + if !exists { + return nil + } + out, err := runner.run(opDeleteRule, fullArgs) + if err != nil { + return fmt.Errorf("error deleting rule: %v: %s", err, out) + } + return nil +} + +func (runner *runner) IsIPv6() bool { + return runner.protocol == ProtocolIPv6 +} + +func (runner *runner) Protocol() Protocol { + return runner.protocol +} + +// SaveInto is part of Interface. +func (runner *runner) SaveInto(table Table, buffer *bytes.Buffer) error { + runner.mu.Lock() + defer runner.mu.Unlock() + + // trace := utiltrace.New("iptables save") + // defer trace.LogIfLong(2 * time.Second) + + // run and return + iptablesSaveCmd := iptablesSaveCommand(runner.protocol) + args := []string{"-t", string(table)} + // klog.V(4).InfoS("Running", "command", iptablesSaveCmd, "arguments", args) + cmd := runner.exec.Command(iptablesSaveCmd, args...) + cmd.SetStdout(buffer) + stderrBuffer := bytes.NewBuffer(nil) + cmd.SetStderr(stderrBuffer) + + err := cmd.Run() + if err != nil { + stderrBuffer.WriteTo(buffer) // ignore error, since we need to return the original error + } + return err +} + +// Restore is part of Interface. +func (runner *runner) Restore(table Table, data []byte, flush FlushFlag, counters RestoreCountersFlag) error { + // setup args + args := []string{"-T", string(table)} + return runner.restoreInternal(args, data, flush, counters) +} + +// RestoreAll is part of Interface. +func (runner *runner) RestoreAll(data []byte, flush FlushFlag, counters RestoreCountersFlag) error { + // setup args + args := make([]string, 0) + return runner.restoreInternal(args, data, flush, counters) +} + +type iptablesLocker interface { + Close() error +} + +// restoreInternal is the shared part of Restore/RestoreAll +func (runner *runner) restoreInternal(args []string, data []byte, flush FlushFlag, counters RestoreCountersFlag) error { + runner.mu.Lock() + defer runner.mu.Unlock() + + // trace := utiltrace.New("iptables restore") + // defer trace.LogIfLong(2 * time.Second) + + if !flush { + args = append(args, "--noflush") + } + if counters { + args = append(args, "--counters") + } + + // Grab the iptables lock to prevent iptables-restore and iptables + // from stepping on each other. iptables-restore 1.6.2 will have + // a --wait option like iptables itself, but that's not widely deployed. + if len(runner.restoreWaitFlag) == 0 { + locker, err := grabIptablesLocks(runner.lockfilePath14x, runner.lockfilePath16x) + if err != nil { + return err + } + // trace.Step("Locks grabbed") + defer func(locker iptablesLocker) { + if err := locker.Close(); err != nil { + // klog.ErrorS(err, "Failed to close iptables locks") + } + }(locker) + } + + // run the command and return the output or an error including the output and error + fullArgs := append(runner.restoreWaitFlag, args...) + iptablesRestoreCmd := iptablesRestoreCommand(runner.protocol) + // klog.V(4).InfoS("Running", "command", iptablesRestoreCmd, "arguments", fullArgs) + cmd := runner.exec.Command(iptablesRestoreCmd, fullArgs...) + cmd.SetStdin(bytes.NewBuffer(data)) + b, err := cmd.CombinedOutput() + if err != nil { + pErr, ok := parseRestoreError(string(b)) + if ok { + return pErr + } + return fmt.Errorf("%w: %s", err, b) + } + return nil +} + +func iptablesSaveCommand(protocol Protocol) string { + if protocol == ProtocolIPv6 { + return cmdIP6TablesSave + } + return cmdIPTablesSave +} + +func iptablesRestoreCommand(protocol Protocol) string { + if protocol == ProtocolIPv6 { + return cmdIP6TablesRestore + } + return cmdIPTablesRestore +} + +func iptablesCommand(protocol Protocol) string { + if protocol == ProtocolIPv6 { + return cmdIP6Tables + } + return cmdIPTables +} + +func (runner *runner) run(op operation, args []string) ([]byte, error) { + return runner.runContext(context.TODO(), op, args) +} + +func (runner *runner) runContext(ctx context.Context, op operation, args []string) ([]byte, error) { + iptablesCmd := iptablesCommand(runner.protocol) + fullArgs := append(runner.waitFlag, string(op)) + fullArgs = append(fullArgs, args...) + // klog.V(5).InfoS("Running", "command", iptablesCmd, "arguments", fullArgs) + if ctx == nil { + return runner.exec.Command(iptablesCmd, fullArgs...).CombinedOutput() + } + return runner.exec.CommandContext(ctx, iptablesCmd, fullArgs...).CombinedOutput() + // Don't log err here - callers might not think it is an error. +} + +// Returns (bool, nil) if it was able to check the existence of the rule, or +// (, error) if the process of checking failed. +func (runner *runner) checkRule(table Table, chain Chain, args ...string) (bool, error) { + if runner.hasCheck { + return runner.checkRuleUsingCheck(makeFullArgs(table, chain, args...)) + } + return runner.checkRuleWithoutCheck(table, chain, args...) +} + +var hexnumRE = regexp.MustCompile("0x0+([0-9])") + +func trimhex(s string) string { + return hexnumRE.ReplaceAllString(s, "0x$1") +} + +// Executes the rule check without using the "-C" flag, instead parsing iptables-save. +// Present for compatibility with <1.4.11 versions of iptables. This is full +// of hack and half-measures. We should nix this ASAP. +func (runner *runner) checkRuleWithoutCheck(table Table, chain Chain, args ...string) (bool, error) { + iptablesSaveCmd := iptablesSaveCommand(runner.protocol) + // klog.V(1).InfoS("Running", "command", iptablesSaveCmd, "table", string(table)) + out, err := runner.exec.Command(iptablesSaveCmd, "-t", string(table)).CombinedOutput() + if err != nil { + return false, fmt.Errorf("error checking rule: %v", err) + } + + // Sadly, iptables has inconsistent quoting rules for comments. Just remove all quotes. + // Also, quoted multi-word comments (which are counted as a single arg) + // will be unpacked into multiple args, + // in order to compare against iptables-save output (which will be split at whitespace boundary) + // e.g. a single arg('"this must be before the NodePort rules"') will be unquoted and unpacked into 7 args. + var argsCopy []string + for i := range args { + tmpField := strings.Trim(args[i], "\"") + tmpField = trimhex(tmpField) + argsCopy = append(argsCopy, strings.Fields(tmpField)...) + } + argset := sets.NewString(argsCopy...) + + for _, line := range strings.Split(string(out), "\n") { + fields := strings.Fields(line) + + // Check that this is a rule for the correct chain, and that it has + // the correct number of argument (+2 for "-A ") + if !strings.HasPrefix(line, fmt.Sprintf("-A %s", string(chain))) || len(fields) != len(argsCopy)+2 { + continue + } + + // Sadly, iptables has inconsistent quoting rules for comments. + // Just remove all quotes. + for i := range fields { + fields[i] = strings.Trim(fields[i], "\"") + fields[i] = trimhex(fields[i]) + } + + // TODO: This misses reorderings e.g. "-x foo ! -y bar" will match "! -x foo -y bar" + if sets.NewString(fields...).IsSuperset(argset) { + return true, nil + } + // klog.V(5).InfoS("DBG: fields is not a superset of args", "fields", fields, "arguments", args) + } + + return false, nil +} + +// Executes the rule check using the "-C" flag +func (runner *runner) checkRuleUsingCheck(args []string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + out, err := runner.runContext(ctx, opCheckRule, args) + if ctx.Err() == context.DeadlineExceeded { + return false, fmt.Errorf("timed out while checking rules") + } + if err == nil { + return true, nil + } + if ee, ok := err.(utilexec.ExitError); ok { + // iptables uses exit(1) to indicate a failure of the operation, + // as compared to a malformed commandline, for example. + if ee.Exited() && ee.ExitStatus() == 1 { + return false, nil + } + } + return false, fmt.Errorf("error checking rule: %v: %s", err, out) +} + +const ( + // Max time we wait for an iptables flush to complete after we notice it has started + iptablesFlushTimeout = 5 * time.Second + // How often we poll while waiting for an iptables flush to complete + iptablesFlushPollTime = 100 * time.Millisecond +) + +// Monitor is part of Interface +func (runner *runner) Monitor(canary Chain, tables []Table, reloadFunc func(), interval time.Duration, stopCh <-chan struct{}) { + for { + _ = utilwait.PollImmediateUntil(interval, func() (bool, error) { + for _, table := range tables { + if _, err := runner.EnsureChain(table, canary); err != nil { + // klog.ErrorS(err, "Could not set up iptables canary", "table", table, "chain", canary) + return false, nil + } + } + return true, nil + }, stopCh) + + // Poll until stopCh is closed or iptables is flushed + err := utilwait.PollUntil(interval, func() (bool, error) { + if exists, err := runner.ChainExists(tables[0], canary); exists { + return false, nil + } else if isResourceError(err) { + // klog.ErrorS(err, "Could not check for iptables canary", "table", tables[0], "chain", canary) + return false, nil + } + // klog.V(2).InfoS("IPTables canary deleted", "table", tables[0], "chain", canary) + // Wait for the other canaries to be deleted too before returning + // so we don't start reloading too soon. + err := utilwait.PollImmediate(iptablesFlushPollTime, iptablesFlushTimeout, func() (bool, error) { + for i := 1; i < len(tables); i++ { + if exists, err := runner.ChainExists(tables[i], canary); exists || isResourceError(err) { + return false, nil + } + } + return true, nil + }) + if err != nil { + // klog.InfoS("Inconsistent iptables state detected") + } + return true, nil + }, stopCh) + if err != nil { + // stopCh was closed + for _, table := range tables { + _ = runner.DeleteChain(table, canary) + } + return + } + + // klog.V(2).InfoS("Reloading after iptables flush") + reloadFunc() + } +} + +// ChainExists is part of Interface +func (runner *runner) ChainExists(table Table, chain Chain) (bool, error) { + fullArgs := makeFullArgs(table, chain) + + runner.mu.Lock() + defer runner.mu.Unlock() + + // trace := utiltrace.New("iptables ChainExists") + // defer trace.LogIfLong(2 * time.Second) + + _, err := runner.run(opListChain, fullArgs) + return err == nil, err +} + +type operation string + +const ( + opCreateChain operation = "-N" + opFlushChain operation = "-F" + opDeleteChain operation = "-X" + opListChain operation = "-S" + opCheckRule operation = "-C" + opDeleteRule operation = "-D" +) + +func makeFullArgs(table Table, chain Chain, args ...string) []string { + return append([]string{string(chain), "-t", string(table)}, args...) +} + +const iptablesVersionPattern = `v([0-9]+(\.[0-9]+)+)` + +// getIPTablesVersion runs "iptables --version" and parses the returned version +func getIPTablesVersion(exec utilexec.Interface, protocol Protocol) (*utilversion.Version, error) { + // this doesn't access mutable state so we don't need to use the interface / runner + iptablesCmd := iptablesCommand(protocol) + bytes, err := exec.Command(iptablesCmd, "--version").CombinedOutput() + if err != nil { + return nil, err + } + versionMatcher := regexp.MustCompile(iptablesVersionPattern) + match := versionMatcher.FindStringSubmatch(string(bytes)) + if match == nil { + return nil, fmt.Errorf("no iptables version found in string: %s", bytes) + } + version, err := utilversion.ParseGeneric(match[1]) + if err != nil { + return nil, fmt.Errorf("iptables version %q is not a valid version string: %v", match[1], err) + } + + return version, nil +} + +// Checks if iptables version has a "wait" flag +func getIPTablesWaitFlag(version *utilversion.Version) []string { + switch { + case version.AtLeast(WaitIntervalMinVersion): + return []string{WaitString, WaitSecondsValue, WaitIntervalString, WaitIntervalUsecondsValue} + case version.AtLeast(WaitSecondsMinVersion): + return []string{WaitString, WaitSecondsValue} + case version.AtLeast(WaitMinVersion): + return []string{WaitString} + default: + return nil + } +} + +// Checks if iptables-restore has a "wait" flag +func getIPTablesRestoreWaitFlag(version *utilversion.Version, exec utilexec.Interface, protocol Protocol) []string { + if version.AtLeast(WaitRestoreMinVersion) { + return []string{WaitString, WaitSecondsValue, WaitIntervalString, WaitIntervalUsecondsValue} + } + + // Older versions may have backported features; if iptables-restore supports + // --version, assume it also supports --wait + vstring, err := getIPTablesRestoreVersionString(exec, protocol) + if err != nil || vstring == "" { + // klog.V(3).InfoS("Couldn't get iptables-restore version; assuming it doesn't support --wait") + return nil + } + if _, err := utilversion.ParseGeneric(vstring); err != nil { + // klog.V(3).InfoS("Couldn't parse iptables-restore version; assuming it doesn't support --wait") + return nil + } + return []string{WaitString} +} + +// getIPTablesRestoreVersionString runs "iptables-restore --version" to get the version string +// in the form "X.X.X" +func getIPTablesRestoreVersionString(exec utilexec.Interface, protocol Protocol) (string, error) { + // this doesn't access mutable state so we don't need to use the interface / runner + + // iptables-restore hasn't always had --version, and worse complains + // about unrecognized commands but doesn't exit when it gets them. + // Work around that by setting stdin to nothing so it exits immediately. + iptablesRestoreCmd := iptablesRestoreCommand(protocol) + cmd := exec.Command(iptablesRestoreCmd, "--version") + cmd.SetStdin(bytes.NewReader([]byte{})) + bytes, err := cmd.CombinedOutput() + if err != nil { + return "", err + } + versionMatcher := regexp.MustCompile(iptablesVersionPattern) + match := versionMatcher.FindStringSubmatch(string(bytes)) + if match == nil { + return "", fmt.Errorf("no iptables version found in string: %s", bytes) + } + return match[1], nil +} + +func (runner *runner) HasRandomFully() bool { + return runner.hasRandomFully +} + +// Present tests if iptable is supported on current kernel by checking the existence +// of default table and chain +func (runner *runner) Present() bool { + if _, err := runner.ChainExists(TableNAT, ChainPostrouting); err != nil { + return false + } + + return true +} + +var iptablesNotFoundStrings = []string{ + // iptables-legacy [-A|-I] BAD-CHAIN [...] + // iptables-legacy [-C|-D] GOOD-CHAIN [...non-matching rule...] + // iptables-legacy [-X|-F|-Z] BAD-CHAIN + // iptables-nft -X BAD-CHAIN + // NB: iptables-nft [-F|-Z] BAD-CHAIN exits with no error + "No chain/target/match by that name", + + // iptables-legacy [...] -j BAD-CHAIN + // iptables-nft-1.8.0 [-A|-I] BAD-CHAIN [...] + // iptables-nft-1.8.0 [-A|-I] GOOD-CHAIN -j BAD-CHAIN + // NB: also matches some other things like "-m BAD-MODULE" + "No such file or directory", + + // iptables-legacy [-C|-D] BAD-CHAIN [...] + // iptables-nft [-C|-D] GOOD-CHAIN [...non-matching rule...] + "does a matching rule exist", + + // iptables-nft-1.8.2 [-A|-C|-D|-I] BAD-CHAIN [...] + // iptables-nft-1.8.2 [...] -j BAD-CHAIN + "does not exist", +} + +// IsNotFoundError returns true if the error indicates "not found". It parses +// the error string looking for known values, which is imperfect; beware using +// this function for anything beyond deciding between logging or ignoring an +// error. +func IsNotFoundError(err error) bool { + es := err.Error() + for _, str := range iptablesNotFoundStrings { + if strings.Contains(es, str) { + return true + } + } + return false +} + +const iptablesStatusResourceProblem = 4 + +// isResourceError returns true if the error indicates that iptables ran into a "resource +// problem" and was unable to attempt the request. In particular, this will be true if it +// times out trying to get the iptables lock. +func isResourceError(err error) bool { + if ee, isExitError := err.(utilexec.ExitError); isExitError { + return ee.ExitStatus() == iptablesStatusResourceProblem + } + return false +} + +// ParseError records the payload when iptables reports an error parsing its input. +type ParseError interface { + // Line returns the line number on which the parse error was reported. + // NOTE: First line is 1. + Line() int + // Error returns the error message of the parse error, including line number. + Error() string +} + +type parseError struct { + cmd string + line int +} + +func (e parseError) Line() int { + return e.line +} + +func (e parseError) Error() string { + return fmt.Sprintf("%s: input error on line %d: ", e.cmd, e.line) +} + +// LineData represents a single numbered line of data. +type LineData struct { + // Line holds the line number (the first line is 1). + Line int + // The data of the line. + Data string +} + +var regexpParseError = regexp.MustCompile("line ([1-9][0-9]*) failed$") + +// parseRestoreError extracts the line from the error, if it matches returns parseError +// for example: +// input: iptables-restore: line 51 failed +// output: parseError: cmd = iptables-restore, line = 51 +// NOTE: parseRestoreError depends on the error format of iptables, if it ever changes +// we need to update this function +func parseRestoreError(str string) (ParseError, bool) { + errors := strings.Split(str, ":") + if len(errors) != 2 { + return nil, false + } + cmd := errors[0] + matches := regexpParseError.FindStringSubmatch(errors[1]) + if len(matches) != 2 { + return nil, false + } + line, errMsg := strconv.Atoi(matches[1]) + if errMsg != nil { + return nil, false + } + return parseError{cmd: cmd, line: line}, true +} + +// ExtractLines extracts the -count and +count data from the lineNum row of lines and return +// NOTE: lines start from line 1 +func ExtractLines(lines []byte, line, count int) []LineData { + // first line is line 1, so line can't be smaller than 1 + if line < 1 { + return nil + } + start := line - count + if start <= 0 { + start = 1 + } + end := line + count + 1 + + offset := 1 + scanner := bufio.NewScanner(bytes.NewBuffer(lines)) + extractLines := make([]LineData, 0, count*2) + for scanner.Scan() { + if offset >= start && offset < end { + extractLines = append(extractLines, LineData{ + Line: offset, + Data: scanner.Text(), + }) + } + if offset == end { + break + } + offset++ + } + return extractLines +} diff --git a/proxy/wireguard/iptables/iptables_linux.go b/proxy/wireguard/iptables/iptables_linux.go new file mode 100644 index 000000000000..a4a2b25aa42d --- /dev/null +++ b/proxy/wireguard/iptables/iptables_linux.go @@ -0,0 +1,101 @@ +//go:build linux +// +build linux + +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptables + +import ( + "fmt" + "net" + "os" + "time" + + utilerrors "github.com/xtls/xray-core/proxy/wireguard/iptables/errors" + "github.com/xtls/xray-core/proxy/wireguard/iptables/wait" + "golang.org/x/sys/unix" +) + +type locker struct { + lock16 *os.File + lock14 *net.UnixListener +} + +func (l *locker) Close() error { + errList := []error{} + if l.lock16 != nil { + if err := l.lock16.Close(); err != nil { + errList = append(errList, err) + } + } + if l.lock14 != nil { + if err := l.lock14.Close(); err != nil { + errList = append(errList, err) + } + } + return utilerrors.NewAggregate(errList) +} + +func grabIptablesLocks(lockfilePath14x, lockfilePath16x string) (iptablesLocker, error) { + var err error + var success bool + + l := &locker{} + defer func(l *locker) { + // Clean up immediately on failure + if !success { + l.Close() + } + }(l) + + // Grab both 1.6.x and 1.4.x-style locks; we don't know what the + // iptables-restore version is if it doesn't support --wait, so we + // can't assume which lock method it'll use. + + // Roughly duplicate iptables 1.6.x xtables_lock() function. + l.lock16, err = os.OpenFile(lockfilePath16x, os.O_CREATE, 0o600) + if err != nil { + return nil, fmt.Errorf("failed to open iptables lock %s: %v", lockfilePath16x, err) + } + + if err := wait.PollImmediate(200*time.Millisecond, 2*time.Second, func() (bool, error) { + if err := grabIptablesFileLock(l.lock16); err != nil { + return false, nil + } + return true, nil + }); err != nil { + return nil, fmt.Errorf("failed to acquire new iptables lock: %v", err) + } + + // Roughly duplicate iptables 1.4.x xtables_lock() function. + if err := wait.PollImmediate(200*time.Millisecond, 2*time.Second, func() (bool, error) { + l.lock14, err = net.ListenUnix("unix", &net.UnixAddr{Name: lockfilePath14x, Net: "unix"}) + if err != nil { + return false, nil + } + return true, nil + }); err != nil { + return nil, fmt.Errorf("failed to acquire old iptables lock: %v", err) + } + + success = true + return l, nil +} + +func grabIptablesFileLock(f *os.File) error { + return unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB) +} diff --git a/proxy/wireguard/iptables/iptables_unsupported.go b/proxy/wireguard/iptables/iptables_unsupported.go new file mode 100644 index 000000000000..0c7c5ee3e651 --- /dev/null +++ b/proxy/wireguard/iptables/iptables_unsupported.go @@ -0,0 +1,33 @@ +//go:build !linux +// +build !linux + +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptables + +import ( + "fmt" + "os" +) + +func grabIptablesLocks(lock14filePath, lock16filePath string) (iptablesLocker, error) { + return nil, fmt.Errorf("iptables unsupported on this platform") +} + +func grabIptablesFileLock(f *os.File) error { + return fmt.Errorf("iptables unsupported on this platform") +} diff --git a/proxy/wireguard/iptables/save_restore.go b/proxy/wireguard/iptables/save_restore.go new file mode 100644 index 000000000000..b788beb91135 --- /dev/null +++ b/proxy/wireguard/iptables/save_restore.go @@ -0,0 +1,52 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptables + +import ( + "bytes" + "fmt" +) + +// MakeChainLine return an iptables-save/restore formatted chain line given a Chain +func MakeChainLine(chain Chain) string { + return fmt.Sprintf(":%s - [0:0]", chain) +} + +// GetChainsFromTable parses iptables-save data to find the chains that are defined. It +// assumes that save contains a single table's data, and returns a map with keys for every +// chain defined in that table. +func GetChainsFromTable(save []byte) map[Chain]struct{} { + chainsMap := make(map[Chain]struct{}) + + for { + i := bytes.Index(save, []byte("\n:")) + if i == -1 { + break + } + start := i + 2 + save = save[start:] + end := bytes.Index(save, []byte(" ")) + if i == -1 { + // shouldn't happen, but... + break + } + chain := Chain(save[:end]) + chainsMap[chain] = struct{}{} + save = save[end:] + } + return chainsMap +} diff --git a/proxy/wireguard/iptables/sets/byte.go b/proxy/wireguard/iptables/sets/byte.go new file mode 100644 index 000000000000..4d7a17c3afad --- /dev/null +++ b/proxy/wireguard/iptables/sets/byte.go @@ -0,0 +1,137 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// Byte is a set of bytes, implemented via map[byte]struct{} for minimal memory consumption. +// +// Deprecated: use generic Set instead. +// new ways: +// s1 := Set[byte]{} +// s2 := New[byte]() +type Byte map[byte]Empty + +// NewByte creates a Byte from a list of values. +func NewByte(items ...byte) Byte { + return Byte(New[byte](items...)) +} + +// ByteKeySet creates a Byte from a keys of a map[byte](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func ByteKeySet[T any](theMap map[byte]T) Byte { + return Byte(KeySet(theMap)) +} + +// Insert adds items to the set. +func (s Byte) Insert(items ...byte) Byte { + return Byte(cast(s).Insert(items...)) +} + +// Delete removes all items from the set. +func (s Byte) Delete(items ...byte) Byte { + return Byte(cast(s).Delete(items...)) +} + +// Has returns true if and only if item is contained in the set. +func (s Byte) Has(item byte) bool { + return cast(s).Has(item) +} + +// HasAll returns true if and only if all items are contained in the set. +func (s Byte) HasAll(items ...byte) bool { + return cast(s).HasAll(items...) +} + +// HasAny returns true if any items are contained in the set. +func (s Byte) HasAny(items ...byte) bool { + return cast(s).HasAny(items...) +} + +// Clone returns a new set which is a copy of the current set. +func (s Byte) Clone() Byte { + return Byte(cast(s).Clone()) +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 Byte) Difference(s2 Byte) Byte { + return Byte(cast(s1).Difference(cast(s2))) +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 Byte) SymmetricDifference(s2 Byte) Byte { + return Byte(cast(s1).SymmetricDifference(cast(s2))) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 Byte) Union(s2 Byte) Byte { + return Byte(cast(s1).Union(cast(s2))) +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 Byte) Intersection(s2 Byte) Byte { + return Byte(cast(s1).Intersection(cast(s2))) +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 Byte) IsSuperset(s2 Byte) bool { + return cast(s1).IsSuperset(cast(s2)) +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 Byte) Equal(s2 Byte) bool { + return cast(s1).Equal(cast(s2)) +} + +// List returns the contents as a sorted byte slice. +func (s Byte) List() []byte { + return List(cast(s)) +} + +// UnsortedList returns the slice with contents in random order. +func (s Byte) UnsortedList() []byte { + return cast(s).UnsortedList() +} + +// PopAny returns a single element from the set. +func (s Byte) PopAny() (byte, bool) { + return cast(s).PopAny() +} + +// Len returns the size of the set. +func (s Byte) Len() int { + return len(s) +} diff --git a/proxy/wireguard/iptables/sets/doc.go b/proxy/wireguard/iptables/sets/doc.go new file mode 100644 index 000000000000..194883390cf2 --- /dev/null +++ b/proxy/wireguard/iptables/sets/doc.go @@ -0,0 +1,19 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package sets has generic set and specified sets. Generic set will +// replace specified ones over time. And specific ones are deprecated. +package sets diff --git a/proxy/wireguard/iptables/sets/empty.go b/proxy/wireguard/iptables/sets/empty.go new file mode 100644 index 000000000000..fbb1df06d922 --- /dev/null +++ b/proxy/wireguard/iptables/sets/empty.go @@ -0,0 +1,21 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// Empty is public since it is used by some internal API objects for conversions between external +// string arrays and internal sets, and conversion logic requires public types today. +type Empty struct{} diff --git a/proxy/wireguard/iptables/sets/int.go b/proxy/wireguard/iptables/sets/int.go new file mode 100644 index 000000000000..5876fc9deb96 --- /dev/null +++ b/proxy/wireguard/iptables/sets/int.go @@ -0,0 +1,137 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// Int is a set of ints, implemented via map[int]struct{} for minimal memory consumption. +// +// Deprecated: use generic Set instead. +// new ways: +// s1 := Set[int]{} +// s2 := New[int]() +type Int map[int]Empty + +// NewInt creates a Int from a list of values. +func NewInt(items ...int) Int { + return Int(New[int](items...)) +} + +// IntKeySet creates a Int from a keys of a map[int](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func IntKeySet[T any](theMap map[int]T) Int { + return Int(KeySet(theMap)) +} + +// Insert adds items to the set. +func (s Int) Insert(items ...int) Int { + return Int(cast(s).Insert(items...)) +} + +// Delete removes all items from the set. +func (s Int) Delete(items ...int) Int { + return Int(cast(s).Delete(items...)) +} + +// Has returns true if and only if item is contained in the set. +func (s Int) Has(item int) bool { + return cast(s).Has(item) +} + +// HasAll returns true if and only if all items are contained in the set. +func (s Int) HasAll(items ...int) bool { + return cast(s).HasAll(items...) +} + +// HasAny returns true if any items are contained in the set. +func (s Int) HasAny(items ...int) bool { + return cast(s).HasAny(items...) +} + +// Clone returns a new set which is a copy of the current set. +func (s Int) Clone() Int { + return Int(cast(s).Clone()) +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 Int) Difference(s2 Int) Int { + return Int(cast(s1).Difference(cast(s2))) +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 Int) SymmetricDifference(s2 Int) Int { + return Int(cast(s1).SymmetricDifference(cast(s2))) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 Int) Union(s2 Int) Int { + return Int(cast(s1).Union(cast(s2))) +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 Int) Intersection(s2 Int) Int { + return Int(cast(s1).Intersection(cast(s2))) +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 Int) IsSuperset(s2 Int) bool { + return cast(s1).IsSuperset(cast(s2)) +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 Int) Equal(s2 Int) bool { + return cast(s1).Equal(cast(s2)) +} + +// List returns the contents as a sorted int slice. +func (s Int) List() []int { + return List(cast(s)) +} + +// UnsortedList returns the slice with contents in random order. +func (s Int) UnsortedList() []int { + return cast(s).UnsortedList() +} + +// PopAny returns a single element from the set. +func (s Int) PopAny() (int, bool) { + return cast(s).PopAny() +} + +// Len returns the size of the set. +func (s Int) Len() int { + return len(s) +} diff --git a/proxy/wireguard/iptables/sets/int32.go b/proxy/wireguard/iptables/sets/int32.go new file mode 100644 index 000000000000..2c640c5d0f1d --- /dev/null +++ b/proxy/wireguard/iptables/sets/int32.go @@ -0,0 +1,137 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// Int32 is a set of int32s, implemented via map[int32]struct{} for minimal memory consumption. +// +// Deprecated: use generic Set instead. +// new ways: +// s1 := Set[int32]{} +// s2 := New[int32]() +type Int32 map[int32]Empty + +// NewInt32 creates a Int32 from a list of values. +func NewInt32(items ...int32) Int32 { + return Int32(New[int32](items...)) +} + +// Int32KeySet creates a Int32 from a keys of a map[int32](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func Int32KeySet[T any](theMap map[int32]T) Int32 { + return Int32(KeySet(theMap)) +} + +// Insert adds items to the set. +func (s Int32) Insert(items ...int32) Int32 { + return Int32(cast(s).Insert(items...)) +} + +// Delete removes all items from the set. +func (s Int32) Delete(items ...int32) Int32 { + return Int32(cast(s).Delete(items...)) +} + +// Has returns true if and only if item is contained in the set. +func (s Int32) Has(item int32) bool { + return cast(s).Has(item) +} + +// HasAll returns true if and only if all items are contained in the set. +func (s Int32) HasAll(items ...int32) bool { + return cast(s).HasAll(items...) +} + +// HasAny returns true if any items are contained in the set. +func (s Int32) HasAny(items ...int32) bool { + return cast(s).HasAny(items...) +} + +// Clone returns a new set which is a copy of the current set. +func (s Int32) Clone() Int32 { + return Int32(cast(s).Clone()) +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 Int32) Difference(s2 Int32) Int32 { + return Int32(cast(s1).Difference(cast(s2))) +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 Int32) SymmetricDifference(s2 Int32) Int32 { + return Int32(cast(s1).SymmetricDifference(cast(s2))) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 Int32) Union(s2 Int32) Int32 { + return Int32(cast(s1).Union(cast(s2))) +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 Int32) Intersection(s2 Int32) Int32 { + return Int32(cast(s1).Intersection(cast(s2))) +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 Int32) IsSuperset(s2 Int32) bool { + return cast(s1).IsSuperset(cast(s2)) +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 Int32) Equal(s2 Int32) bool { + return cast(s1).Equal(cast(s2)) +} + +// List returns the contents as a sorted int32 slice. +func (s Int32) List() []int32 { + return List(cast(s)) +} + +// UnsortedList returns the slice with contents in random order. +func (s Int32) UnsortedList() []int32 { + return cast(s).UnsortedList() +} + +// PopAny returns a single element from the set. +func (s Int32) PopAny() (int32, bool) { + return cast(s).PopAny() +} + +// Len returns the size of the set. +func (s Int32) Len() int { + return len(s) +} diff --git a/proxy/wireguard/iptables/sets/int64.go b/proxy/wireguard/iptables/sets/int64.go new file mode 100644 index 000000000000..bf3eb3ffa259 --- /dev/null +++ b/proxy/wireguard/iptables/sets/int64.go @@ -0,0 +1,137 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// Int64 is a set of int64s, implemented via map[int64]struct{} for minimal memory consumption. +// +// Deprecated: use generic Set instead. +// new ways: +// s1 := Set[int64]{} +// s2 := New[int64]() +type Int64 map[int64]Empty + +// NewInt64 creates a Int64 from a list of values. +func NewInt64(items ...int64) Int64 { + return Int64(New[int64](items...)) +} + +// Int64KeySet creates a Int64 from a keys of a map[int64](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func Int64KeySet[T any](theMap map[int64]T) Int64 { + return Int64(KeySet(theMap)) +} + +// Insert adds items to the set. +func (s Int64) Insert(items ...int64) Int64 { + return Int64(cast(s).Insert(items...)) +} + +// Delete removes all items from the set. +func (s Int64) Delete(items ...int64) Int64 { + return Int64(cast(s).Delete(items...)) +} + +// Has returns true if and only if item is contained in the set. +func (s Int64) Has(item int64) bool { + return cast(s).Has(item) +} + +// HasAll returns true if and only if all items are contained in the set. +func (s Int64) HasAll(items ...int64) bool { + return cast(s).HasAll(items...) +} + +// HasAny returns true if any items are contained in the set. +func (s Int64) HasAny(items ...int64) bool { + return cast(s).HasAny(items...) +} + +// Clone returns a new set which is a copy of the current set. +func (s Int64) Clone() Int64 { + return Int64(cast(s).Clone()) +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 Int64) Difference(s2 Int64) Int64 { + return Int64(cast(s1).Difference(cast(s2))) +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 Int64) SymmetricDifference(s2 Int64) Int64 { + return Int64(cast(s1).SymmetricDifference(cast(s2))) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 Int64) Union(s2 Int64) Int64 { + return Int64(cast(s1).Union(cast(s2))) +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 Int64) Intersection(s2 Int64) Int64 { + return Int64(cast(s1).Intersection(cast(s2))) +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 Int64) IsSuperset(s2 Int64) bool { + return cast(s1).IsSuperset(cast(s2)) +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 Int64) Equal(s2 Int64) bool { + return cast(s1).Equal(cast(s2)) +} + +// List returns the contents as a sorted int64 slice. +func (s Int64) List() []int64 { + return List(cast(s)) +} + +// UnsortedList returns the slice with contents in random order. +func (s Int64) UnsortedList() []int64 { + return cast(s).UnsortedList() +} + +// PopAny returns a single element from the set. +func (s Int64) PopAny() (int64, bool) { + return cast(s).PopAny() +} + +// Len returns the size of the set. +func (s Int64) Len() int { + return len(s) +} diff --git a/proxy/wireguard/iptables/sets/ordered.go b/proxy/wireguard/iptables/sets/ordered.go new file mode 100644 index 000000000000..443dac62eb35 --- /dev/null +++ b/proxy/wireguard/iptables/sets/ordered.go @@ -0,0 +1,53 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type ordered interface { + integer | float | ~string +} + +// integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type integer interface { + signed | unsigned +} + +// float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type float interface { + ~float32 | ~float64 +} + +// signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} diff --git a/proxy/wireguard/iptables/sets/set.go b/proxy/wireguard/iptables/sets/set.go new file mode 100644 index 000000000000..d50526f4262c --- /dev/null +++ b/proxy/wireguard/iptables/sets/set.go @@ -0,0 +1,241 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +import ( + "sort" +) + +// Set is a set of the same type elements, implemented via map[comparable]struct{} for minimal memory consumption. +type Set[T comparable] map[T]Empty + +// cast transforms specified set to generic Set[T]. +func cast[T comparable](s map[T]Empty) Set[T] { return s } + +// New creates a Set from a list of values. +// NOTE: type param must be explicitly instantiated if given items are empty. +func New[T comparable](items ...T) Set[T] { + ss := make(Set[T], len(items)) + ss.Insert(items...) + return ss +} + +// KeySet creates a Set from a keys of a map[comparable](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func KeySet[T comparable, V any](theMap map[T]V) Set[T] { + ret := Set[T]{} + for keyValue := range theMap { + ret.Insert(keyValue) + } + return ret +} + +// Insert adds items to the set. +func (s Set[T]) Insert(items ...T) Set[T] { + for _, item := range items { + s[item] = Empty{} + } + return s +} + +func Insert[T comparable](set Set[T], items ...T) Set[T] { + return set.Insert(items...) +} + +// Delete removes all items from the set. +func (s Set[T]) Delete(items ...T) Set[T] { + for _, item := range items { + delete(s, item) + } + return s +} + +// Clear empties the set. +// It is preferable to replace the set with a newly constructed set, +// but not all callers can do that (when there are other references to the map). +// In some cases the set *won't* be fully cleared, e.g. a Set[float32] containing NaN +// can't be cleared because NaN can't be removed. +// For sets containing items of a type that is reflexive for ==, +// this is optimized to a single call to runtime.mapclear(). +func (s Set[T]) Clear() Set[T] { + for key := range s { + delete(s, key) + } + return s +} + +// Has returns true if and only if item is contained in the set. +func (s Set[T]) Has(item T) bool { + _, contained := s[item] + return contained +} + +// HasAll returns true if and only if all items are contained in the set. +func (s Set[T]) HasAll(items ...T) bool { + for _, item := range items { + if !s.Has(item) { + return false + } + } + return true +} + +// HasAny returns true if any items are contained in the set. +func (s Set[T]) HasAny(items ...T) bool { + for _, item := range items { + if s.Has(item) { + return true + } + } + return false +} + +// Clone returns a new set which is a copy of the current set. +func (s Set[T]) Clone() Set[T] { + result := make(Set[T], len(s)) + for key := range s { + result.Insert(key) + } + return result +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 Set[T]) Difference(s2 Set[T]) Set[T] { + result := New[T]() + for key := range s1 { + if !s2.Has(key) { + result.Insert(key) + } + } + return result +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 Set[T]) SymmetricDifference(s2 Set[T]) Set[T] { + return s1.Difference(s2).Union(s2.Difference(s1)) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 Set[T]) Union(s2 Set[T]) Set[T] { + result := s1.Clone() + for key := range s2 { + result.Insert(key) + } + return result +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 Set[T]) Intersection(s2 Set[T]) Set[T] { + var walk, other Set[T] + result := New[T]() + if s1.Len() < s2.Len() { + walk = s1 + other = s2 + } else { + walk = s2 + other = s1 + } + for key := range walk { + if other.Has(key) { + result.Insert(key) + } + } + return result +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 Set[T]) IsSuperset(s2 Set[T]) bool { + for item := range s2 { + if !s1.Has(item) { + return false + } + } + return true +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 Set[T]) Equal(s2 Set[T]) bool { + return len(s1) == len(s2) && s1.IsSuperset(s2) +} + +type sortableSliceOfGeneric[T ordered] []T + +func (g sortableSliceOfGeneric[T]) Len() int { return len(g) } +func (g sortableSliceOfGeneric[T]) Less(i, j int) bool { return less[T](g[i], g[j]) } +func (g sortableSliceOfGeneric[T]) Swap(i, j int) { g[i], g[j] = g[j], g[i] } + +// List returns the contents as a sorted T slice. +// +// This is a separate function and not a method because not all types supported +// by Generic are ordered and only those can be sorted. +func List[T ordered](s Set[T]) []T { + res := make(sortableSliceOfGeneric[T], 0, len(s)) + for key := range s { + res = append(res, key) + } + sort.Sort(res) + return res +} + +// UnsortedList returns the slice with contents in random order. +func (s Set[T]) UnsortedList() []T { + res := make([]T, 0, len(s)) + for key := range s { + res = append(res, key) + } + return res +} + +// PopAny returns a single element from the set. +func (s Set[T]) PopAny() (T, bool) { + for key := range s { + s.Delete(key) + return key, true + } + var zeroValue T + return zeroValue, false +} + +// Len returns the size of the set. +func (s Set[T]) Len() int { + return len(s) +} + +func less[T ordered](lhs, rhs T) bool { + return lhs < rhs +} diff --git a/proxy/wireguard/iptables/sets/string.go b/proxy/wireguard/iptables/sets/string.go new file mode 100644 index 000000000000..1dab6d13cc79 --- /dev/null +++ b/proxy/wireguard/iptables/sets/string.go @@ -0,0 +1,137 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +// String is a set of strings, implemented via map[string]struct{} for minimal memory consumption. +// +// Deprecated: use generic Set instead. +// new ways: +// s1 := Set[string]{} +// s2 := New[string]() +type String map[string]Empty + +// NewString creates a String from a list of values. +func NewString(items ...string) String { + return String(New[string](items...)) +} + +// StringKeySet creates a String from a keys of a map[string](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func StringKeySet[T any](theMap map[string]T) String { + return String(KeySet(theMap)) +} + +// Insert adds items to the set. +func (s String) Insert(items ...string) String { + return String(cast(s).Insert(items...)) +} + +// Delete removes all items from the set. +func (s String) Delete(items ...string) String { + return String(cast(s).Delete(items...)) +} + +// Has returns true if and only if item is contained in the set. +func (s String) Has(item string) bool { + return cast(s).Has(item) +} + +// HasAll returns true if and only if all items are contained in the set. +func (s String) HasAll(items ...string) bool { + return cast(s).HasAll(items...) +} + +// HasAny returns true if any items are contained in the set. +func (s String) HasAny(items ...string) bool { + return cast(s).HasAny(items...) +} + +// Clone returns a new set which is a copy of the current set. +func (s String) Clone() String { + return String(cast(s).Clone()) +} + +// Difference returns a set of objects that are not in s2. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s1 String) Difference(s2 String) String { + return String(cast(s1).Difference(cast(s2))) +} + +// SymmetricDifference returns a set of elements which are in either of the sets, but not in their intersection. +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.SymmetricDifference(s2) = {a3, a4, a5} +// s2.SymmetricDifference(s1) = {a3, a4, a5} +func (s1 String) SymmetricDifference(s2 String) String { + return String(cast(s1).SymmetricDifference(cast(s2))) +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (s1 String) Union(s2 String) String { + return String(cast(s1).Union(cast(s2))) +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (s1 String) Intersection(s2 String) String { + return String(cast(s1).Intersection(cast(s2))) +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 String) IsSuperset(s2 String) bool { + return cast(s1).IsSuperset(cast(s2)) +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 String) Equal(s2 String) bool { + return cast(s1).Equal(cast(s2)) +} + +// List returns the contents as a sorted string slice. +func (s String) List() []string { + return List(cast(s)) +} + +// UnsortedList returns the slice with contents in random order. +func (s String) UnsortedList() []string { + return cast(s).UnsortedList() +} + +// PopAny returns a single element from the set. +func (s String) PopAny() (string, bool) { + return cast(s).PopAny() +} + +// Len returns the size of the set. +func (s String) Len() int { + return len(s) +} diff --git a/proxy/wireguard/iptables/version/doc.go b/proxy/wireguard/iptables/version/doc.go new file mode 100644 index 000000000000..5b2b22b6d00c --- /dev/null +++ b/proxy/wireguard/iptables/version/doc.go @@ -0,0 +1,18 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package version provides utilities for version number comparisons +package version // import "k8s.io/apimachinery/pkg/util/version" diff --git a/proxy/wireguard/iptables/version/version.go b/proxy/wireguard/iptables/version/version.go new file mode 100644 index 000000000000..79bb4cfcf193 --- /dev/null +++ b/proxy/wireguard/iptables/version/version.go @@ -0,0 +1,371 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "bytes" + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +// Version is an opaque representation of a version number +type Version struct { + components []uint + semver bool + preRelease string + buildMetadata string +} + +var ( + // versionMatchRE splits a version string into numeric and "extra" parts + versionMatchRE = regexp.MustCompile(`^\s*v?([0-9]+(?:\.[0-9]+)*)(.*)*$`) + // extraMatchRE splits the "extra" part of versionMatchRE into semver pre-release and build metadata; it does not validate the "no leading zeroes" constraint for pre-release + extraMatchRE = regexp.MustCompile(`^(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?\s*$`) +) + +func parse(str string, semver bool) (*Version, error) { + parts := versionMatchRE.FindStringSubmatch(str) + if parts == nil { + return nil, fmt.Errorf("could not parse %q as version", str) + } + numbers, extra := parts[1], parts[2] + + components := strings.Split(numbers, ".") + if (semver && len(components) != 3) || (!semver && len(components) < 2) { + return nil, fmt.Errorf("illegal version string %q", str) + } + + v := &Version{ + components: make([]uint, len(components)), + semver: semver, + } + for i, comp := range components { + if (i == 0 || semver) && strings.HasPrefix(comp, "0") && comp != "0" { + return nil, fmt.Errorf("illegal zero-prefixed version component %q in %q", comp, str) + } + num, err := strconv.ParseUint(comp, 10, 0) + if err != nil { + return nil, fmt.Errorf("illegal non-numeric version component %q in %q: %v", comp, str, err) + } + v.components[i] = uint(num) + } + + if semver && extra != "" { + extraParts := extraMatchRE.FindStringSubmatch(extra) + if extraParts == nil { + return nil, fmt.Errorf("could not parse pre-release/metadata (%s) in version %q", extra, str) + } + v.preRelease, v.buildMetadata = extraParts[1], extraParts[2] + + for _, comp := range strings.Split(v.preRelease, ".") { + if _, err := strconv.ParseUint(comp, 10, 0); err == nil { + if strings.HasPrefix(comp, "0") && comp != "0" { + return nil, fmt.Errorf("illegal zero-prefixed version component %q in %q", comp, str) + } + } + } + } + + return v, nil +} + +// HighestSupportedVersion returns the highest supported version +// This function assumes that the highest supported version must be v1.x. +func HighestSupportedVersion(versions []string) (*Version, error) { + if len(versions) == 0 { + return nil, errors.New("empty array for supported versions") + } + + var ( + highestSupportedVersion *Version + theErr error + ) + + for i := len(versions) - 1; i >= 0; i-- { + currentHighestVer, err := ParseGeneric(versions[i]) + if err != nil { + theErr = err + continue + } + + if currentHighestVer.Major() > 1 { + continue + } + + if highestSupportedVersion == nil || highestSupportedVersion.LessThan(currentHighestVer) { + highestSupportedVersion = currentHighestVer + } + } + + if highestSupportedVersion == nil { + return nil, fmt.Errorf( + "could not find a highest supported version from versions (%v) reported: %+v", + versions, theErr) + } + + if highestSupportedVersion.Major() != 1 { + return nil, fmt.Errorf("highest supported version reported is %v, must be v1.x", highestSupportedVersion) + } + + return highestSupportedVersion, nil +} + +// ParseGeneric parses a "generic" version string. The version string must consist of two +// or more dot-separated numeric fields (the first of which can't have leading zeroes), +// followed by arbitrary uninterpreted data (which need not be separated from the final +// numeric field by punctuation). For convenience, leading and trailing whitespace is +// ignored, and the version can be preceded by the letter "v". See also ParseSemantic. +func ParseGeneric(str string) (*Version, error) { + return parse(str, false) +} + +// MustParseGeneric is like ParseGeneric except that it panics on error +func MustParseGeneric(str string) *Version { + v, err := ParseGeneric(str) + if err != nil { + panic(err) + } + return v +} + +// ParseSemantic parses a version string that exactly obeys the syntax and semantics of +// the "Semantic Versioning" specification (http://semver.org/) (although it ignores +// leading and trailing whitespace, and allows the version to be preceded by "v"). For +// version strings that are not guaranteed to obey the Semantic Versioning syntax, use +// ParseGeneric. +func ParseSemantic(str string) (*Version, error) { + return parse(str, true) +} + +// MustParseSemantic is like ParseSemantic except that it panics on error +func MustParseSemantic(str string) *Version { + v, err := ParseSemantic(str) + if err != nil { + panic(err) + } + return v +} + +// MajorMinor returns a version with the provided major and minor version. +func MajorMinor(major, minor uint) *Version { + return &Version{components: []uint{major, minor}} +} + +// Major returns the major release number +func (v *Version) Major() uint { + return v.components[0] +} + +// Minor returns the minor release number +func (v *Version) Minor() uint { + return v.components[1] +} + +// Patch returns the patch release number if v is a Semantic Version, or 0 +func (v *Version) Patch() uint { + if len(v.components) < 3 { + return 0 + } + return v.components[2] +} + +// BuildMetadata returns the build metadata, if v is a Semantic Version, or "" +func (v *Version) BuildMetadata() string { + return v.buildMetadata +} + +// PreRelease returns the prerelease metadata, if v is a Semantic Version, or "" +func (v *Version) PreRelease() string { + return v.preRelease +} + +// Components returns the version number components +func (v *Version) Components() []uint { + return v.components +} + +// WithMajor returns copy of the version object with requested major number +func (v *Version) WithMajor(major uint) *Version { + result := *v + result.components = []uint{major, v.Minor(), v.Patch()} + return &result +} + +// WithMinor returns copy of the version object with requested minor number +func (v *Version) WithMinor(minor uint) *Version { + result := *v + result.components = []uint{v.Major(), minor, v.Patch()} + return &result +} + +// WithPatch returns copy of the version object with requested patch number +func (v *Version) WithPatch(patch uint) *Version { + result := *v + result.components = []uint{v.Major(), v.Minor(), patch} + return &result +} + +// WithPreRelease returns copy of the version object with requested prerelease +func (v *Version) WithPreRelease(preRelease string) *Version { + result := *v + result.components = []uint{v.Major(), v.Minor(), v.Patch()} + result.preRelease = preRelease + return &result +} + +// WithBuildMetadata returns copy of the version object with requested buildMetadata +func (v *Version) WithBuildMetadata(buildMetadata string) *Version { + result := *v + result.components = []uint{v.Major(), v.Minor(), v.Patch()} + result.buildMetadata = buildMetadata + return &result +} + +// String converts a Version back to a string; note that for versions parsed with +// ParseGeneric, this will not include the trailing uninterpreted portion of the version +// number. +func (v *Version) String() string { + if v == nil { + return "" + } + var buffer bytes.Buffer + + for i, comp := range v.components { + if i > 0 { + buffer.WriteString(".") + } + buffer.WriteString(fmt.Sprintf("%d", comp)) + } + if v.preRelease != "" { + buffer.WriteString("-") + buffer.WriteString(v.preRelease) + } + if v.buildMetadata != "" { + buffer.WriteString("+") + buffer.WriteString(v.buildMetadata) + } + + return buffer.String() +} + +// compareInternal returns -1 if v is less than other, 1 if it is greater than other, or 0 +// if they are equal +func (v *Version) compareInternal(other *Version) int { + vLen := len(v.components) + oLen := len(other.components) + for i := 0; i < vLen && i < oLen; i++ { + switch { + case other.components[i] < v.components[i]: + return 1 + case other.components[i] > v.components[i]: + return -1 + } + } + + // If components are common but one has more items and they are not zeros, it is bigger + switch { + case oLen < vLen && !onlyZeros(v.components[oLen:]): + return 1 + case oLen > vLen && !onlyZeros(other.components[vLen:]): + return -1 + } + + if !v.semver || !other.semver { + return 0 + } + + switch { + case v.preRelease == "" && other.preRelease != "": + return 1 + case v.preRelease != "" && other.preRelease == "": + return -1 + case v.preRelease == other.preRelease: // includes case where both are "" + return 0 + } + + vPR := strings.Split(v.preRelease, ".") + oPR := strings.Split(other.preRelease, ".") + for i := 0; i < len(vPR) && i < len(oPR); i++ { + vNum, err := strconv.ParseUint(vPR[i], 10, 0) + if err == nil { + oNum, err := strconv.ParseUint(oPR[i], 10, 0) + if err == nil { + switch { + case oNum < vNum: + return 1 + case oNum > vNum: + return -1 + default: + continue + } + } + } + if oPR[i] < vPR[i] { + return 1 + } else if oPR[i] > vPR[i] { + return -1 + } + } + + switch { + case len(oPR) < len(vPR): + return 1 + case len(oPR) > len(vPR): + return -1 + } + + return 0 +} + +// returns false if array contain any non-zero element +func onlyZeros(array []uint) bool { + for _, num := range array { + if num != 0 { + return false + } + } + return true +} + +// AtLeast tests if a version is at least equal to a given minimum version. If both +// Versions are Semantic Versions, this will use the Semantic Version comparison +// algorithm. Otherwise, it will compare only the numeric components, with non-present +// components being considered "0" (ie, "1.4" is equal to "1.4.0"). +func (v *Version) AtLeast(min *Version) bool { + return v.compareInternal(min) != -1 +} + +// LessThan tests if a version is less than a given version. (It is exactly the opposite +// of AtLeast, for situations where asking "is v too old?" makes more sense than asking +// "is v new enough?".) +func (v *Version) LessThan(other *Version) bool { + return v.compareInternal(other) == -1 +} + +// Compare compares v against a version string (which will be parsed as either Semantic +// or non-Semantic depending on v). On success it returns -1 if v is less than other, 1 if +// it is greater than other, or 0 if they are equal. +func (v *Version) Compare(other string) (int, error) { + ov, err := parse(other, v.semver) + if err != nil { + return 0, err + } + return v.compareInternal(ov), nil +} diff --git a/proxy/wireguard/iptables/version/version_test.go b/proxy/wireguard/iptables/version/version_test.go new file mode 100644 index 000000000000..d929812f9789 --- /dev/null +++ b/proxy/wireguard/iptables/version/version_test.go @@ -0,0 +1,453 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "fmt" + "reflect" + "testing" +) + +type testItem struct { + version string + unparsed string + equalsPrev bool +} + +func testOne(v *Version, item, prev testItem) error { + str := v.String() + if item.unparsed == "" { + if str != item.version { + return fmt.Errorf("bad round-trip: %q -> %q", item.version, str) + } + } else { + if str != item.unparsed { + return fmt.Errorf("bad unparse: %q -> %q, expected %q", item.version, str, item.unparsed) + } + } + + if prev.version != "" { + cmp, err := v.Compare(prev.version) + if err != nil { + return fmt.Errorf("unexpected parse error: %v", err) + } + rv, err := parse(prev.version, v.semver) + if err != nil { + return fmt.Errorf("unexpected parse error: %v", err) + } + rcmp, err := rv.Compare(item.version) + if err != nil { + return fmt.Errorf("unexpected parse error: %v", err) + } + + switch { + case cmp == -1: + return fmt.Errorf("unexpected ordering %q < %q", item.version, prev.version) + case cmp == 0 && !item.equalsPrev: + return fmt.Errorf("unexpected comparison %q == %q", item.version, prev.version) + case cmp == 1 && item.equalsPrev: + return fmt.Errorf("unexpected comparison %q != %q", item.version, prev.version) + case cmp != -rcmp: + return fmt.Errorf("unexpected reverse comparison %q <=> %q %v %v %v %v", item.version, prev.version, cmp, rcmp, v.Components(), rv.Components()) + } + } + + return nil +} + +func TestSemanticVersions(t *testing.T) { + tests := []testItem{ + // This is every version string that appears in the 2.0 semver spec, + // sorted in strictly increasing order except as noted. + {version: "0.1.0"}, + {version: "1.0.0-0.3.7"}, + {version: "1.0.0-alpha"}, + {version: "1.0.0-alpha+001", equalsPrev: true}, + {version: "1.0.0-alpha.1"}, + {version: "1.0.0-alpha.beta"}, + {version: "1.0.0-beta"}, + {version: "1.0.0-beta+exp.sha.5114f85", equalsPrev: true}, + {version: "1.0.0-beta.2"}, + {version: "1.0.0-beta.11"}, + {version: "1.0.0-rc.1"}, + {version: "1.0.0-x.7.z.92"}, + {version: "1.0.0"}, + {version: "1.0.0+20130313144700", equalsPrev: true}, + {version: "1.8.0-alpha.3"}, + {version: "1.8.0-alpha.3.673+73326ef01d2d7c"}, + {version: "1.9.0"}, + {version: "1.10.0"}, + {version: "1.11.0"}, + {version: "2.0.0"}, + {version: "2.1.0"}, + {version: "2.1.1"}, + {version: "42.0.0"}, + + // We also allow whitespace and "v" prefix + {version: " 42.0.0", unparsed: "42.0.0", equalsPrev: true}, + {version: "\t42.0.0 ", unparsed: "42.0.0", equalsPrev: true}, + {version: "43.0.0-1", unparsed: "43.0.0-1"}, + {version: "43.0.0-1 ", unparsed: "43.0.0-1", equalsPrev: true}, + {version: "v43.0.0-1", unparsed: "43.0.0-1", equalsPrev: true}, + {version: " v43.0.0", unparsed: "43.0.0"}, + {version: " 43.0.0 ", unparsed: "43.0.0", equalsPrev: true}, + } + + var prev testItem + for _, item := range tests { + v, err := ParseSemantic(item.version) + if err != nil { + t.Errorf("unexpected parse error: %v", err) + continue + } + err = testOne(v, item, prev) + if err != nil { + t.Errorf("%v", err) + } + prev = item + } +} + +func TestBadSemanticVersions(t *testing.T) { + tests := []string{ + // "MUST take the form X.Y.Z" + "1", + "1.2", + "1.2.3.4", + ".2.3", + "1..3", + "1.2.", + "", + "..", + // "where X, Y, and Z are non-negative integers" + "-1.2.3", + "1.-2.3", + "1.2.-3", + "1a.2.3", + "1.2a.3", + "1.2.3a", + "a1.2.3", + "a.b.c", + "1 .2.3", + "1. 2.3", + // "and MUST NOT contain leading zeroes." + "01.2.3", + "1.02.3", + "1.2.03", + // "[pre-release] identifiers MUST comprise only ASCII alphanumerics and hyphen" + "1.2.3-/", + // "[pre-release] identifiers MUST NOT be empty" + "1.2.3-", + "1.2.3-.", + "1.2.3-foo.", + "1.2.3-.foo", + // "Numeric [pre-release] identifiers MUST NOT include leading zeroes" + "1.2.3-01", + // "[build metadata] identifiers MUST comprise only ASCII alphanumerics and hyphen" + "1.2.3+/", + // "[build metadata] identifiers MUST NOT be empty" + "1.2.3+", + "1.2.3+.", + "1.2.3+foo.", + "1.2.3+.foo", + + // whitespace/"v"-prefix checks + "v 1.2.3", + "vv1.2.3", + } + + for i := range tests { + _, err := ParseSemantic(tests[i]) + if err == nil { + t.Errorf("unexpected success parsing invalid semver %q", tests[i]) + } + } +} + +func TestGenericVersions(t *testing.T) { + tests := []testItem{ + // This is all of the strings from TestSemanticVersions, plus some strings + // from TestBadSemanticVersions that should parse as generic versions, + // plus some additional strings. + {version: "0.1.0", unparsed: "0.1.0"}, + {version: "1.0.0-0.3.7", unparsed: "1.0.0"}, + {version: "1.0.0-alpha", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0-alpha+001", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0-alpha.1", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0-alpha.beta", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0.beta", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0-beta+exp.sha.5114f85", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0.beta.2", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0.beta.11", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0.rc.1", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0-x.7.z.92", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.0.0+20130313144700", unparsed: "1.0.0", equalsPrev: true}, + {version: "1.2", unparsed: "1.2"}, + {version: "1.2a.3", unparsed: "1.2", equalsPrev: true}, + {version: "1.2.3", unparsed: "1.2.3"}, + {version: "1.2.3.0", unparsed: "1.2.3.0", equalsPrev: true}, + {version: "1.2.3a", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3-foo.", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3-.foo", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3-01", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3+", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3+foo.", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3+.foo", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.02.3", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.03", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.003", unparsed: "1.2.3", equalsPrev: true}, + {version: "1.2.3.4", unparsed: "1.2.3.4"}, + {version: "1.2.3.4b3", unparsed: "1.2.3.4", equalsPrev: true}, + {version: "1.2.3.4.5", unparsed: "1.2.3.4.5"}, + {version: "1.9.0", unparsed: "1.9.0"}, + {version: "1.9.0.0.0.0.0.0", unparsed: "1.9.0.0.0.0.0.0", equalsPrev: true}, + {version: "1.10.0", unparsed: "1.10.0"}, + {version: "1.11.0", unparsed: "1.11.0"}, + {version: "1.11.0.0.5", unparsed: "1.11.0.0.5"}, + {version: "2.0.0", unparsed: "2.0.0"}, + {version: "2.1.0", unparsed: "2.1.0"}, + {version: "2.1.1", unparsed: "2.1.1"}, + {version: "42.0.0", unparsed: "42.0.0"}, + {version: " 42.0.0", unparsed: "42.0.0", equalsPrev: true}, + {version: "\t42.0.0 ", unparsed: "42.0.0", equalsPrev: true}, + {version: "42.0.0-1", unparsed: "42.0.0", equalsPrev: true}, + {version: "42.0.0-1 ", unparsed: "42.0.0", equalsPrev: true}, + {version: "v42.0.0-1", unparsed: "42.0.0", equalsPrev: true}, + {version: " v43.0.0", unparsed: "43.0.0"}, + {version: " 43.0.0 ", unparsed: "43.0.0", equalsPrev: true}, + } + + var prev testItem + for _, item := range tests { + v, err := ParseGeneric(item.version) + if err != nil { + t.Errorf("unexpected parse error: %v", err) + continue + } + err = testOne(v, item, prev) + if err != nil { + t.Errorf("%v", err) + } + prev = item + } +} + +func TestBadGenericVersions(t *testing.T) { + tests := []string{ + "1", + "01.2.3", + "-1.2.3", + "1.-2.3", + ".2.3", + "1..3", + "1a.2.3", + "a1.2.3", + "1 .2.3", + "1. 2.3", + "1.bob", + "bob", + "v 1.2.3", + "vv1.2.3", + "", + ".", + } + + for i := range tests { + _, err := ParseGeneric(tests[i]) + if err == nil { + t.Errorf("unexpected success parsing invalid version %q", tests[i]) + } + } +} + +func TestComponents(t *testing.T) { + tests := []struct { + version string + semver bool + expectedComponents []uint + expectedMajor uint + expectedMinor uint + expectedPatch uint + expectedPreRelease string + expectedBuildMetadata string + }{ + { + version: "1.0.2", + semver: true, + expectedComponents: []uint{1, 0, 2}, + expectedMajor: 1, + expectedMinor: 0, + expectedPatch: 2, + }, + { + version: "1.0.2-alpha+001", + semver: true, + expectedComponents: []uint{1, 0, 2}, + expectedMajor: 1, + expectedMinor: 0, + expectedPatch: 2, + expectedPreRelease: "alpha", + expectedBuildMetadata: "001", + }, + { + version: "1.2", + semver: false, + expectedComponents: []uint{1, 2}, + expectedMajor: 1, + expectedMinor: 2, + }, + { + version: "1.0.2-beta+exp.sha.5114f85", + semver: true, + expectedComponents: []uint{1, 0, 2}, + expectedMajor: 1, + expectedMinor: 0, + expectedPatch: 2, + expectedPreRelease: "beta", + expectedBuildMetadata: "exp.sha.5114f85", + }, + } + + for _, test := range tests { + version, _ := parse(test.version, test.semver) + if !reflect.DeepEqual(test.expectedComponents, version.Components()) { + t.Error("parse returned un'expected components") + } + if test.expectedMajor != version.Major() { + t.Errorf("parse returned version.Major %d, expected %d", test.expectedMajor, version.Major()) + } + if test.expectedMinor != version.Minor() { + t.Errorf("parse returned version.Minor %d, expected %d", test.expectedMinor, version.Minor()) + } + if test.expectedPatch != version.Patch() { + t.Errorf("parse returned version.Patch %d, expected %d", test.expectedPatch, version.Patch()) + } + if test.expectedPreRelease != version.PreRelease() { + t.Errorf("parse returned version.PreRelease %s, expected %s", test.expectedPreRelease, version.PreRelease()) + } + if test.expectedBuildMetadata != version.BuildMetadata() { + t.Errorf("parse returned version.BuildMetadata %s, expected %s", test.expectedBuildMetadata, version.BuildMetadata()) + } + } +} + +func TestHighestSupportedVersion(t *testing.T) { + testCases := []struct { + versions []string + expectedHighestSupportedVersion string + shouldFail bool + }{ + { + versions: []string{"v1.0.0"}, + expectedHighestSupportedVersion: "1.0.0", + shouldFail: false, + }, + { + versions: []string{"0.3.0"}, + shouldFail: true, + }, + { + versions: []string{"0.2.0"}, + shouldFail: true, + }, + { + versions: []string{"1.0.0"}, + expectedHighestSupportedVersion: "1.0.0", + shouldFail: false, + }, + { + versions: []string{"v0.3.0"}, + shouldFail: true, + }, + { + versions: []string{"v0.2.0"}, + shouldFail: true, + }, + { + versions: []string{"0.2.0", "v0.3.0"}, + shouldFail: true, + }, + { + versions: []string{"0.2.0", "v1.0.0"}, + expectedHighestSupportedVersion: "1.0.0", + shouldFail: false, + }, + { + versions: []string{"0.2.0", "v1.2.3"}, + expectedHighestSupportedVersion: "1.2.3", + shouldFail: false, + }, + { + versions: []string{"v1.2.3", "v0.3.0"}, + expectedHighestSupportedVersion: "1.2.3", + shouldFail: false, + }, + { + versions: []string{"v1.2.3", "v0.3.0", "2.0.1"}, + expectedHighestSupportedVersion: "1.2.3", + shouldFail: false, + }, + { + versions: []string{"v1.2.3", "4.9.12", "v0.3.0", "2.0.1"}, + expectedHighestSupportedVersion: "1.2.3", + shouldFail: false, + }, + { + versions: []string{"4.9.12", "2.0.1"}, + expectedHighestSupportedVersion: "", + shouldFail: true, + }, + { + versions: []string{"v1.2.3", "boo", "v0.3.0", "2.0.1"}, + expectedHighestSupportedVersion: "1.2.3", + shouldFail: false, + }, + { + versions: []string{}, + expectedHighestSupportedVersion: "", + shouldFail: true, + }, + { + versions: []string{"var", "boo", "foo"}, + expectedHighestSupportedVersion: "", + shouldFail: true, + }, + } + + for _, tc := range testCases { + // Arrange & Act + actual, err := HighestSupportedVersion(tc.versions) + + // Assert + if tc.shouldFail && err == nil { + t.Fatalf("expecting highestSupportedVersion to fail, but got nil error for testcase: %#v", tc) + } + if !tc.shouldFail && err != nil { + t.Fatalf("unexpected error during ValidatePlugin for testcase: %#v\r\n err:%v", tc, err) + } + if tc.expectedHighestSupportedVersion != "" { + result, err := actual.Compare(tc.expectedHighestSupportedVersion) + if err != nil { + t.Fatalf("comparison failed with %v for testcase %#v", err, tc) + } + if result != 0 { + t.Fatalf("expectedHighestSupportedVersion %v, but got %v for tc: %#v", tc.expectedHighestSupportedVersion, actual, tc) + } + } + } +} diff --git a/proxy/wireguard/iptables/wait/backoff.go b/proxy/wireguard/iptables/wait/backoff.go new file mode 100644 index 000000000000..67cc4438f6f7 --- /dev/null +++ b/proxy/wireguard/iptables/wait/backoff.go @@ -0,0 +1,500 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "math" + "sync" + "time" + + "github.com/xtls/xray-core/proxy/wireguard/iptables/wait/clock" +) + +// Backoff holds parameters applied to a Backoff function. +type Backoff struct { + // The initial duration. + Duration time.Duration + // Duration is multiplied by factor each iteration, if factor is not zero + // and the limits imposed by Steps and Cap have not been reached. + // Should not be negative. + // The jitter does not contribute to the updates to the duration parameter. + Factor float64 + // The sleep at each iteration is the duration plus an additional + // amount chosen uniformly at random from the interval between + // zero and `jitter*duration`. + Jitter float64 + // The remaining number of iterations in which the duration + // parameter may change (but progress can be stopped earlier by + // hitting the cap). If not positive, the duration is not + // changed. Used for exponential backoff in combination with + // Factor and Cap. + Steps int + // A limit on revised values of the duration parameter. If a + // multiplication by the factor parameter would make the duration + // exceed the cap then the duration is set to the cap and the + // steps parameter is set to zero. + Cap time.Duration +} + +// Step returns an amount of time to sleep determined by the original +// Duration and Jitter. The backoff is mutated to update its Steps and +// Duration. A nil Backoff always has a zero-duration step. +func (b *Backoff) Step() time.Duration { + if b == nil { + return 0 + } + var nextDuration time.Duration + nextDuration, b.Duration, b.Steps = delay(b.Steps, b.Duration, b.Cap, b.Factor, b.Jitter) + return nextDuration +} + +// DelayFunc returns a function that will compute the next interval to +// wait given the arguments in b. It does not mutate the original backoff +// but the function is safe to use only from a single goroutine. +func (b Backoff) DelayFunc() DelayFunc { + steps := b.Steps + duration := b.Duration + cap := b.Cap + factor := b.Factor + jitter := b.Jitter + + return func() time.Duration { + var nextDuration time.Duration + // jitter is applied per step and is not cumulative over multiple steps + nextDuration, duration, steps = delay(steps, duration, cap, factor, jitter) + return nextDuration + } +} + +// Timer returns a timer implementation appropriate to this backoff's parameters +// for use with wait functions. +func (b Backoff) Timer() Timer { + if b.Steps > 1 || b.Jitter != 0 { + return &variableTimer{new: internalClock.NewTimer, fn: b.DelayFunc()} + } + if b.Duration > 0 { + return &fixedTimer{new: internalClock.NewTicker, interval: b.Duration} + } + return newNoopTimer() +} + +// delay implements the core delay algorithm used in this package. +func delay(steps int, duration, cap time.Duration, factor, jitter float64) (_, next time.Duration, nextSteps int) { + // when steps is non-positive, do not alter the base duration + if steps < 1 { + if jitter > 0 { + return Jitter(duration, jitter), duration, 0 + } + return duration, duration, 0 + } + steps-- + + // calculate the next step's interval + if factor != 0 { + next = time.Duration(float64(duration) * factor) + if cap > 0 && next > cap { + next = cap + steps = 0 + } + } else { + next = duration + } + + // add jitter for this step + if jitter > 0 { + duration = Jitter(duration, jitter) + } + + return duration, next, steps +} + +// DelayWithReset returns a DelayFunc that will return the appropriate next interval to +// wait. Every resetInterval the backoff parameters are reset to their initial state. +// This method is safe to invoke from multiple goroutines, but all calls will advance +// the backoff state when Factor is set. If Factor is zero, this method is the same as +// invoking b.DelayFunc() since Steps has no impact without Factor. If resetInterval is +// zero no backoff will be performed as the same calling DelayFunc with a zero factor +// and steps. +func (b Backoff) DelayWithReset(c clock.Clock, resetInterval time.Duration) DelayFunc { + if b.Factor <= 0 { + return b.DelayFunc() + } + if resetInterval <= 0 { + b.Steps = 0 + b.Factor = 0 + return b.DelayFunc() + } + return (&backoffManager{ + backoff: b, + initialBackoff: b, + resetInterval: resetInterval, + + clock: c, + lastStart: c.Now(), + timer: nil, + }).Step +} + +// Until loops until stop channel is closed, running f every period. +// +// Until is syntactic sugar on top of JitterUntil with zero jitter factor and +// with sliding = true (which means the timer for period starts after the f +// completes). +func Until(f func(), period time.Duration, stopCh <-chan struct{}) { + JitterUntil(f, period, 0.0, true, stopCh) +} + +// UntilWithContext loops until context is done, running f every period. +// +// UntilWithContext is syntactic sugar on top of JitterUntilWithContext +// with zero jitter factor and with sliding = true (which means the timer +// for period starts after the f completes). +func UntilWithContext(ctx context.Context, f func(context.Context), period time.Duration) { + JitterUntilWithContext(ctx, f, period, 0.0, true) +} + +// NonSlidingUntil loops until stop channel is closed, running f every +// period. +// +// NonSlidingUntil is syntactic sugar on top of JitterUntil with zero jitter +// factor, with sliding = false (meaning the timer for period starts at the same +// time as the function starts). +func NonSlidingUntil(f func(), period time.Duration, stopCh <-chan struct{}) { + JitterUntil(f, period, 0.0, false, stopCh) +} + +// NonSlidingUntilWithContext loops until context is done, running f every +// period. +// +// NonSlidingUntilWithContext is syntactic sugar on top of JitterUntilWithContext +// with zero jitter factor, with sliding = false (meaning the timer for period +// starts at the same time as the function starts). +func NonSlidingUntilWithContext(ctx context.Context, f func(context.Context), period time.Duration) { + JitterUntilWithContext(ctx, f, period, 0.0, false) +} + +// JitterUntil loops until stop channel is closed, running f every period. +// +// If jitterFactor is positive, the period is jittered before every run of f. +// If jitterFactor is not positive, the period is unchanged and not jittered. +// +// If sliding is true, the period is computed after f runs. If it is false then +// period includes the runtime for f. +// +// Close stopCh to stop. f may not be invoked if stop channel is already +// closed. Pass NeverStop to if you don't want it stop. +func JitterUntil(f func(), period time.Duration, jitterFactor float64, sliding bool, stopCh <-chan struct{}) { + BackoffUntil(f, NewJitteredBackoffManager(period, jitterFactor, &clock.RealClock{}), sliding, stopCh) +} + +// BackoffUntil loops until stop channel is closed, run f every duration given by BackoffManager. +// +// If sliding is true, the period is computed after f runs. If it is false then +// period includes the runtime for f. +func BackoffUntil(f func(), backoff BackoffManager, sliding bool, stopCh <-chan struct{}) { + var t clock.Timer + for { + select { + case <-stopCh: + return + default: + } + + if !sliding { + t = backoff.Backoff() + } + + func() { + // defer runtime.HandleCrash() + f() + }() + + if sliding { + t = backoff.Backoff() + } + + // NOTE: b/c there is no priority selection in golang + // it is possible for this to race, meaning we could + // trigger t.C and stopCh, and t.C select falls through. + // In order to mitigate we re-check stopCh at the beginning + // of every loop to prevent extra executions of f(). + select { + case <-stopCh: + if !t.Stop() { + <-t.C() + } + return + case <-t.C(): + } + } +} + +// JitterUntilWithContext loops until context is done, running f every period. +// +// If jitterFactor is positive, the period is jittered before every run of f. +// If jitterFactor is not positive, the period is unchanged and not jittered. +// +// If sliding is true, the period is computed after f runs. If it is false then +// period includes the runtime for f. +// +// Cancel context to stop. f may not be invoked if context is already expired. +func JitterUntilWithContext(ctx context.Context, f func(context.Context), period time.Duration, jitterFactor float64, sliding bool) { + JitterUntil(func() { f(ctx) }, period, jitterFactor, sliding, ctx.Done()) +} + +// backoffManager provides simple backoff behavior in a threadsafe manner to a caller. +type backoffManager struct { + backoff Backoff + initialBackoff Backoff + resetInterval time.Duration + + clock clock.Clock + + lock sync.Mutex + lastStart time.Time + timer clock.Timer +} + +// Step returns the expected next duration to wait. +func (b *backoffManager) Step() time.Duration { + b.lock.Lock() + defer b.lock.Unlock() + + switch { + case b.resetInterval == 0: + b.backoff = b.initialBackoff + case b.clock.Now().Sub(b.lastStart) > b.resetInterval: + b.backoff = b.initialBackoff + b.lastStart = b.clock.Now() + } + return b.backoff.Step() +} + +// Backoff implements BackoffManager.Backoff, it returns a timer so caller can block on the timer +// for exponential backoff. The returned timer must be drained before calling Backoff() the second +// time. +func (b *backoffManager) Backoff() clock.Timer { + b.lock.Lock() + defer b.lock.Unlock() + if b.timer == nil { + b.timer = b.clock.NewTimer(b.Step()) + } else { + b.timer.Reset(b.Step()) + } + return b.timer +} + +// Timer returns a new Timer instance that shares the clock and the reset behavior with all other +// timers. +func (b *backoffManager) Timer() Timer { + return DelayFunc(b.Step).Timer(b.clock) +} + +// BackoffManager manages backoff with a particular scheme based on its underlying implementation. +type BackoffManager interface { + // Backoff returns a shared clock.Timer that is Reset on every invocation. This method is not + // safe for use from multiple threads. It returns a timer for backoff, and caller shall backoff + // until Timer.C() drains. If the second Backoff() is called before the timer from the first + // Backoff() call finishes, the first timer will NOT be drained and result in undetermined + // behavior. + Backoff() clock.Timer +} + +// Deprecated: Will be removed when the legacy polling functions are removed. +type exponentialBackoffManagerImpl struct { + backoff *Backoff + backoffTimer clock.Timer + lastBackoffStart time.Time + initialBackoff time.Duration + backoffResetDuration time.Duration + clock clock.Clock +} + +// NewExponentialBackoffManager returns a manager for managing exponential backoff. Each backoff is jittered and +// backoff will not exceed the given max. If the backoff is not called within resetDuration, the backoff is reset. +// This backoff manager is used to reduce load during upstream unhealthiness. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct, use DelayWithReset() to get a DelayFunc that periodically resets itself, and then +// invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewExponentialBackoffManager(init, max, reset, factor, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// delayFn := wait.Backoff{ +// Duration: init, +// Cap: max, +// Steps: int(math.Ceil(float64(max) / float64(init))), // now a required argument +// Factor: factor, +// Jitter: jitter, +// }.DelayWithReset(reset, clock) +// wait.BackoffUntil(..., delayFn.Timer(), ...) +func NewExponentialBackoffManager(initBackoff, maxBackoff, resetDuration time.Duration, backoffFactor, jitter float64, c clock.Clock) BackoffManager { + return &exponentialBackoffManagerImpl{ + backoff: &Backoff{ + Duration: initBackoff, + Factor: backoffFactor, + Jitter: jitter, + + // the current impl of wait.Backoff returns Backoff.Duration once steps are used up, which is not + // what we ideally need here, we set it to max int and assume we will never use up the steps + Steps: math.MaxInt32, + Cap: maxBackoff, + }, + backoffTimer: nil, + initialBackoff: initBackoff, + lastBackoffStart: c.Now(), + backoffResetDuration: resetDuration, + clock: c, + } +} + +func (b *exponentialBackoffManagerImpl) getNextBackoff() time.Duration { + if b.clock.Now().Sub(b.lastBackoffStart) > b.backoffResetDuration { + b.backoff.Steps = math.MaxInt32 + b.backoff.Duration = b.initialBackoff + } + b.lastBackoffStart = b.clock.Now() + return b.backoff.Step() +} + +// Backoff implements BackoffManager.Backoff, it returns a timer so caller can block on the timer for exponential backoff. +// The returned timer must be drained before calling Backoff() the second time +func (b *exponentialBackoffManagerImpl) Backoff() clock.Timer { + if b.backoffTimer == nil { + b.backoffTimer = b.clock.NewTimer(b.getNextBackoff()) + } else { + b.backoffTimer.Reset(b.getNextBackoff()) + } + return b.backoffTimer +} + +// Deprecated: Will be removed when the legacy polling functions are removed. +type jitteredBackoffManagerImpl struct { + clock clock.Clock + duration time.Duration + jitter float64 + backoffTimer clock.Timer +} + +// NewJitteredBackoffManager returns a BackoffManager that backoffs with given duration plus given jitter. If the jitter +// is negative, backoff will not be jittered. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct and invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewJitteredBackoffManager(duration, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// wait.BackoffUntil(..., wait.Backoff{Duration: duration, Jitter: jitter}.Timer(), ...) +func NewJitteredBackoffManager(duration time.Duration, jitter float64, c clock.Clock) BackoffManager { + return &jitteredBackoffManagerImpl{ + clock: c, + duration: duration, + jitter: jitter, + backoffTimer: nil, + } +} + +func (j *jitteredBackoffManagerImpl) getNextBackoff() time.Duration { + jitteredPeriod := j.duration + if j.jitter > 0.0 { + jitteredPeriod = Jitter(j.duration, j.jitter) + } + return jitteredPeriod +} + +// Backoff implements BackoffManager.Backoff, it returns a timer so caller can block on the timer for jittered backoff. +// The returned timer must be drained before calling Backoff() the second time +func (j *jitteredBackoffManagerImpl) Backoff() clock.Timer { + backoff := j.getNextBackoff() + if j.backoffTimer == nil { + j.backoffTimer = j.clock.NewTimer(backoff) + } else { + j.backoffTimer.Reset(backoff) + } + return j.backoffTimer +} + +// ExponentialBackoff repeats a condition check with exponential backoff. +// +// It repeatedly checks the condition and then sleeps, using `backoff.Step()` +// to determine the length of the sleep and adjust Duration and Steps. +// Stops and returns as soon as: +// 1. the condition check returns true or an error, +// 2. `backoff.Steps` checks of the condition have been done, or +// 3. a sleep truncated by the cap on duration has been completed. +// In case (1) the returned error is what the condition function returned. +// In all other cases, ErrWaitTimeout is returned. +// +// Since backoffs are often subject to cancellation, we recommend using +// ExponentialBackoffWithContext and passing a context to the method. +func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { + for backoff.Steps > 0 { + if ok, err := runConditionWithCrashProtection(condition); err != nil || ok { + return err + } + if backoff.Steps == 1 { + break + } + time.Sleep(backoff.Step()) + } + return ErrWaitTimeout +} + +// ExponentialBackoffWithContext repeats a condition check with exponential backoff. +// It immediately returns an error if the condition returns an error, the context is cancelled +// or hits the deadline, or if the maximum attempts defined in backoff is exceeded (ErrWaitTimeout). +// If an error is returned by the condition the backoff stops immediately. The condition will +// never be invoked more than backoff.Steps times. +func ExponentialBackoffWithContext(ctx context.Context, backoff Backoff, condition ConditionWithContextFunc) error { + for backoff.Steps > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if ok, err := runConditionWithCrashProtectionWithContext(ctx, condition); err != nil || ok { + return err + } + + if backoff.Steps == 1 { + break + } + + waitBeforeRetry := backoff.Step() + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitBeforeRetry): + } + } + + return ErrWaitTimeout +} diff --git a/proxy/wireguard/iptables/wait/clock/clock.go b/proxy/wireguard/iptables/wait/clock/clock.go new file mode 100644 index 000000000000..dd181ce8d8b0 --- /dev/null +++ b/proxy/wireguard/iptables/wait/clock/clock.go @@ -0,0 +1,168 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package clock + +import "time" + +// PassiveClock allows for injecting fake or real clocks into code +// that needs to read the current time but does not support scheduling +// activity in the future. +type PassiveClock interface { + Now() time.Time + Since(time.Time) time.Duration +} + +// Clock allows for injecting fake or real clocks into code that +// needs to do arbitrary things based on time. +type Clock interface { + PassiveClock + // After returns the channel of a new Timer. + // This method does not allow to free/GC the backing timer before it fires. Use + // NewTimer instead. + After(d time.Duration) <-chan time.Time + // NewTimer returns a new Timer. + NewTimer(d time.Duration) Timer + // Sleep sleeps for the provided duration d. + // Consider making the sleep interruptible by using 'select' on a context channel and a timer channel. + Sleep(d time.Duration) + // Tick returns the channel of a new Ticker. + // This method does not allow to free/GC the backing ticker. Use + // NewTicker from WithTicker instead. + Tick(d time.Duration) <-chan time.Time +} + +// WithTicker allows for injecting fake or real clocks into code that +// needs to do arbitrary things based on time. +type WithTicker interface { + Clock + // NewTicker returns a new Ticker. + NewTicker(time.Duration) Ticker +} + +// WithDelayedExecution allows for injecting fake or real clocks into +// code that needs to make use of AfterFunc functionality. +type WithDelayedExecution interface { + Clock + // AfterFunc executes f in its own goroutine after waiting + // for d duration and returns a Timer whose channel can be + // closed by calling Stop() on the Timer. + AfterFunc(d time.Duration, f func()) Timer +} + +// Ticker defines the Ticker interface. +type Ticker interface { + C() <-chan time.Time + Stop() +} + +var _ = WithTicker(RealClock{}) + +// RealClock really calls time.Now() +type RealClock struct{} + +// Now returns the current time. +func (RealClock) Now() time.Time { + return time.Now() +} + +// Since returns time since the specified timestamp. +func (RealClock) Since(ts time.Time) time.Duration { + return time.Since(ts) +} + +// After is the same as time.After(d). +// This method does not allow to free/GC the backing timer before it fires. Use +// NewTimer instead. +func (RealClock) After(d time.Duration) <-chan time.Time { + return time.After(d) +} + +// NewTimer is the same as time.NewTimer(d) +func (RealClock) NewTimer(d time.Duration) Timer { + return &realTimer{ + timer: time.NewTimer(d), + } +} + +// AfterFunc is the same as time.AfterFunc(d, f). +func (RealClock) AfterFunc(d time.Duration, f func()) Timer { + return &realTimer{ + timer: time.AfterFunc(d, f), + } +} + +// Tick is the same as time.Tick(d) +// This method does not allow to free/GC the backing ticker. Use +// NewTicker instead. +func (RealClock) Tick(d time.Duration) <-chan time.Time { + return time.Tick(d) +} + +// NewTicker returns a new Ticker. +func (RealClock) NewTicker(d time.Duration) Ticker { + return &realTicker{ + ticker: time.NewTicker(d), + } +} + +// Sleep is the same as time.Sleep(d) +// Consider making the sleep interruptible by using 'select' on a context channel and a timer channel. +func (RealClock) Sleep(d time.Duration) { + time.Sleep(d) +} + +// Timer allows for injecting fake or real timers into code that +// needs to do arbitrary things based on time. +type Timer interface { + C() <-chan time.Time + Stop() bool + Reset(d time.Duration) bool +} + +var _ = Timer(&realTimer{}) + +// realTimer is backed by an actual time.Timer. +type realTimer struct { + timer *time.Timer +} + +// C returns the underlying timer's channel. +func (r *realTimer) C() <-chan time.Time { + return r.timer.C +} + +// Stop calls Stop() on the underlying timer. +func (r *realTimer) Stop() bool { + return r.timer.Stop() +} + +// Reset calls Reset() on the underlying timer. +func (r *realTimer) Reset(d time.Duration) bool { + return r.timer.Reset(d) +} + +type realTicker struct { + ticker *time.Ticker +} + +func (r *realTicker) C() <-chan time.Time { + return r.ticker.C +} + +func (r *realTicker) Stop() { + r.ticker.Stop() +} diff --git a/proxy/wireguard/iptables/wait/delay.go b/proxy/wireguard/iptables/wait/delay.go new file mode 100644 index 000000000000..b003c7f6076e --- /dev/null +++ b/proxy/wireguard/iptables/wait/delay.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "sync" + "time" + + "github.com/xtls/xray-core/proxy/wireguard/iptables/wait/clock" +) + +// DelayFunc returns the next time interval to wait. +type DelayFunc func() time.Duration + +// Timer takes an arbitrary delay function and returns a timer that can handle arbitrary interval changes. +// Use Backoff{...}.Timer() for simple delays and more efficient timers. +func (fn DelayFunc) Timer(c clock.Clock) Timer { + return &variableTimer{fn: fn, new: c.NewTimer} +} + +// Until takes an arbitrary delay function and runs until cancelled or the condition indicates exit. This +// offers all of the functionality of the methods in this package. +func (fn DelayFunc) Until(ctx context.Context, immediate, sliding bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, &variableTimer{fn: fn, new: internalClock.NewTimer}, immediate, sliding, condition) +} + +// Concurrent returns a version of this DelayFunc that is safe for use by multiple goroutines that +// wish to share a single delay timer. +func (fn DelayFunc) Concurrent() DelayFunc { + var lock sync.Mutex + return func() time.Duration { + lock.Lock() + defer lock.Unlock() + return fn() + } +} diff --git a/proxy/wireguard/iptables/wait/error.go b/proxy/wireguard/iptables/wait/error.go new file mode 100644 index 000000000000..dd75801d829e --- /dev/null +++ b/proxy/wireguard/iptables/wait/error.go @@ -0,0 +1,96 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "errors" +) + +// ErrWaitTimeout is returned when the condition was not satisfied in time. +// +// Deprecated: This type will be made private in favor of Interrupted() +// for checking errors or ErrorInterrupted(err) for returning a wrapped error. +var ErrWaitTimeout = ErrorInterrupted(errors.New("timed out waiting for the condition")) + +// Interrupted returns true if the error indicates a Poll, ExponentialBackoff, or +// Until loop exited for any reason besides the condition returning true or an +// error. A loop is considered interrupted if the calling context is cancelled, +// the context reaches its deadline, or a backoff reaches its maximum allowed +// steps. +// +// Callers should use this method instead of comparing the error value directly to +// ErrWaitTimeout, as methods that cancel a context may not return that error. +// +// Instead of: +// +// err := wait.Poll(...) +// if err == wait.ErrWaitTimeout { +// log.Infof("Wait for operation exceeded") +// } else ... +// +// Use: +// +// err := wait.Poll(...) +// if wait.Interrupted(err) { +// log.Infof("Wait for operation exceeded") +// } else ... +func Interrupted(err error) bool { + switch { + case errors.Is(err, errWaitTimeout), + errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded): + return true + default: + return false + } +} + +// errInterrupted +type errInterrupted struct { + cause error +} + +// ErrorInterrupted returns an error that indicates the wait was ended +// early for a given reason. If no cause is provided a generic error +// will be used but callers are encouraged to provide a real cause for +// clarity in debugging. +func ErrorInterrupted(cause error) error { + switch cause.(type) { + case errInterrupted: + // no need to wrap twice since errInterrupted is only needed + // once in a chain + return cause + default: + return errInterrupted{cause} + } +} + +// errWaitTimeout is the private version of the previous ErrWaitTimeout +// and is private to prevent direct comparison. Use ErrorInterrupted(err) +// to get an error that will return true for Interrupted(err). +var errWaitTimeout = errInterrupted{} + +func (e errInterrupted) Unwrap() error { return e.cause } +func (e errInterrupted) Is(target error) bool { return target == errWaitTimeout } +func (e errInterrupted) Error() string { + if e.cause == nil { + // returns the same error message as historical behavior + return "timed out waiting for the condition" + } + return e.cause.Error() +} diff --git a/proxy/wireguard/iptables/wait/loop.go b/proxy/wireguard/iptables/wait/loop.go new file mode 100644 index 000000000000..44ca1d46c08c --- /dev/null +++ b/proxy/wireguard/iptables/wait/loop.go @@ -0,0 +1,94 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "time" + // "k8s.io/apimachinery/pkg/util/runtime" +) + +// loopConditionUntilContext executes the provided condition at intervals defined by +// the provided timer until the provided context is cancelled, the condition returns +// true, or the condition returns an error. If sliding is true, the period is computed +// after condition runs. If it is false then period includes the runtime for condition. +// If immediate is false the first delay happens before any call to condition, if +// immediate is true the condition will be invoked before waiting and guarantees that +// the condition is invoked at least once, regardless of whether the context has been +// cancelled. The returned error is the error returned by the last condition or the +// context error if the context was terminated. +// +// This is the common loop construct for all polling in the wait package. +func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding bool, condition ConditionWithContextFunc) error { + defer t.Stop() + + var timeCh <-chan time.Time + doneCh := ctx.Done() + + if !sliding { + timeCh = t.C() + } + + // if immediate is true the condition is + // guaranteed to be executed at least once, + // if we haven't requested immediate execution, delay once + if immediate { + if ok, err := func() (bool, error) { + // defer runtime.HandleCrash() + return condition(ctx) + }(); err != nil || ok { + return err + } + } + + if sliding { + timeCh = t.C() + } + + for { + + // Wait for either the context to be cancelled or the next invocation be called + select { + case <-doneCh: + return ctx.Err() + case <-timeCh: + } + + // IMPORTANT: Because there is no channel priority selection in golang + // it is possible for very short timers to "win" the race in the previous select + // repeatedly even when the context has been canceled. We therefore must + // explicitly check for context cancellation on every loop and exit if true to + // guarantee that we don't invoke condition more than once after context has + // been cancelled. + if err := ctx.Err(); err != nil { + return err + } + + if !sliding { + t.Next() + } + if ok, err := func() (bool, error) { + // defer runtime.HandleCrash() + return condition(ctx) + }(); err != nil || ok { + return err + } + if sliding { + t.Next() + } + } +} diff --git a/proxy/wireguard/iptables/wait/poll.go b/proxy/wireguard/iptables/wait/poll.go new file mode 100644 index 000000000000..231d4c384239 --- /dev/null +++ b/proxy/wireguard/iptables/wait/poll.go @@ -0,0 +1,315 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "time" +) + +// PollUntilContextCancel tries a condition func until it returns true, an error, or the context +// is cancelled or hits a deadline. condition will be invoked after the first interval if the +// context is not cancelled first. The returned error will be from ctx.Err(), the condition's +// err return value, or nil. If invoking condition takes longer than interval the next condition +// will be invoked immediately. When using very short intervals, condition may be invoked multiple +// times before a context cancellation is detected. If immediate is true, condition will be +// invoked before waiting and guarantees that condition is invoked at least once, regardless of +// whether the context has been cancelled. +func PollUntilContextCancel(ctx context.Context, interval time.Duration, immediate bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + +// PollUntilContextTimeout will terminate polling after timeout duration by setting a context +// timeout. This is provided as a convenience function for callers not currently executing under +// a deadline and is equivalent to: +// +// deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) +// err := PollUntilContextCancel(deadlineCtx, interval, immediate, condition) +// +// The deadline context will be cancelled if the Poll succeeds before the timeout, simplifying +// inline usage. All other behavior is identical to PollUntilContextCancel. +func PollUntilContextTimeout(ctx context.Context, interval, timeout time.Duration, immediate bool, condition ConditionWithContextFunc) error { + deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) + defer deadlineCancel() + return loopConditionUntilContext(deadlineCtx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + +// Poll tries a condition func until it returns true, an error, or the timeout +// is reached. +// +// Poll always waits the interval before the run of 'condition'. +// 'condition' will always be invoked at least once. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollUntilContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func Poll(interval, timeout time.Duration, condition ConditionFunc) error { + return PollWithContext(context.Background(), interval, timeout, condition.WithContext()) +} + +// PollWithContext tries a condition func until it returns true, an error, +// or when the context expires or the timeout is reached, whichever +// happens first. +// +// PollWithContext always waits the interval before the run of 'condition'. +// 'condition' will always be invoked at least once. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollUntilContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, timeout), condition) +} + +// PollUntil tries a condition func until it returns true, an error or stopCh is +// closed. +// +// PollUntil always waits interval before the first run of 'condition'. +// 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { + return PollUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) +} + +// PollUntilWithContext tries a condition func until it returns true, +// an error or the specified context is cancelled or expired. +// +// PollUntilWithContext always waits interval before the first run of 'condition'. +// 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, 0), condition) +} + +// PollInfinite tries a condition func until it returns true or an error +// +// PollInfinite always waits the interval before the run of 'condition'. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollInfinite(interval time.Duration, condition ConditionFunc) error { + return PollInfiniteWithContext(context.Background(), interval, condition.WithContext()) +} + +// PollInfiniteWithContext tries a condition func until it returns true or an error +// +// PollInfiniteWithContext always waits the interval before the run of 'condition'. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, 0), condition) +} + +// PollImmediate tries a condition func until it returns true, an error, or the timeout +// is reached. +// +// PollImmediate always checks 'condition' before waiting for the interval. 'condition' +// will always be invoked at least once. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollUntilContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { + return PollImmediateWithContext(context.Background(), interval, timeout, condition.WithContext()) +} + +// PollImmediateWithContext tries a condition func until it returns true, an error, +// or the timeout is reached or the specified context expires, whichever happens first. +// +// PollImmediateWithContext always checks 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollUntilContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediateWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, timeout), condition) +} + +// PollImmediateUntil tries a condition func until it returns true, an error or stopCh is closed. +// +// PollImmediateUntil runs the 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { + return PollImmediateUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) +} + +// PollImmediateUntilWithContext tries a condition func until it returns true, +// an error or the specified context is cancelled or expired. +// +// PollImmediateUntilWithContext runs the 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, 0), condition) +} + +// PollImmediateInfinite tries a condition func until it returns true or an error +// +// PollImmediateInfinite runs the 'condition' before waiting for the interval. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) error { + return PollImmediateInfiniteWithContext(context.Background(), interval, condition.WithContext()) +} + +// PollImmediateInfiniteWithContext tries a condition func until it returns true +// or an error or the specified context gets cancelled or expired. +// +// PollImmediateInfiniteWithContext runs the 'condition' before waiting for the interval. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// Deprecated: This method does not return errors from context, use PollUntilContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. +func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, 0), condition) +} + +// Internally used, each of the public 'Poll*' function defined in this +// package should invoke this internal function with appropriate parameters. +// ctx: the context specified by the caller, for infinite polling pass +// a context that never gets cancelled or expired. +// immediate: if true, the 'condition' will be invoked before waiting for the interval, +// in this case 'condition' will always be invoked at least once. +// wait: user specified WaitFunc function that controls at what interval the condition +// function should be invoked periodically and whether it is bound by a timeout. +// condition: user specified ConditionWithContextFunc function. +// +// Deprecated: will be removed in favor of loopConditionUntilContext. +func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, condition ConditionWithContextFunc) error { + if immediate { + done, err := runConditionWithCrashProtectionWithContext(ctx, condition) + if err != nil { + return err + } + if done { + return nil + } + } + + select { + case <-ctx.Done(): + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead + return ErrWaitTimeout + default: + return waitForWithContext(ctx, wait, condition) + } +} + +// poller returns a WaitFunc that will send to the channel every interval until +// timeout has elapsed and then closes the channel. +// +// Over very short intervals you may receive no ticks before the channel is +// closed. A timeout of 0 is interpreted as an infinity, and in such a case +// it would be the caller's responsibility to close the done channel. +// Failure to do so would result in a leaked goroutine. +// +// Output ticks are not buffered. If the channel is not ready to receive an +// item, the tick is skipped. +// +// Deprecated: Will be removed in a future release. +func poller(interval, timeout time.Duration) waitWithContextFunc { + return waitWithContextFunc(func(ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) + + go func() { + defer close(ch) + + tick := time.NewTicker(interval) + defer tick.Stop() + + var after <-chan time.Time + if timeout != 0 { + // time.After is more convenient, but it + // potentially leaves timers around much longer + // than necessary if we exit early. + timer := time.NewTimer(timeout) + after = timer.C + defer timer.Stop() + } + + for { + select { + case <-tick.C: + // If the consumer isn't ready for this signal drop it and + // check the other channels. + select { + case ch <- struct{}{}: + default: + } + case <-after: + return + case <-ctx.Done(): + return + } + } + }() + + return ch + }) +} diff --git a/proxy/wireguard/iptables/wait/timer.go b/proxy/wireguard/iptables/wait/timer.go new file mode 100644 index 000000000000..80ba82c2ff6f --- /dev/null +++ b/proxy/wireguard/iptables/wait/timer.go @@ -0,0 +1,121 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "time" + + "github.com/xtls/xray-core/proxy/wireguard/iptables/wait/clock" +) + +// Timer abstracts how wait functions interact with time runtime efficiently. Test +// code may implement this interface directly but package consumers are encouraged +// to use the Backoff type as the primary mechanism for acquiring a Timer. The +// interface is a simplification of clock.Timer to prevent misuse. Timers are not +// expected to be safe for calls from multiple goroutines. +type Timer interface { + // C returns a channel that will receive a struct{} each time the timer fires. + // The channel should not be waited on after Stop() is invoked. It is allowed + // to cache the returned value of C() for the lifetime of the Timer. + C() <-chan time.Time + // Next is invoked by wait functions to signal timers that the next interval + // should begin. You may only use Next() if you have drained the channel C(). + // You should not call Next() after Stop() is invoked. + Next() + // Stop releases the timer. It is safe to invoke if no other methods have been + // called. + Stop() +} + +type noopTimer struct { + closedCh <-chan time.Time +} + +// newNoopTimer creates a timer with a unique channel to avoid contention +// for the channel's lock across multiple unrelated timers. +func newNoopTimer() noopTimer { + ch := make(chan time.Time) + close(ch) + return noopTimer{closedCh: ch} +} + +func (t noopTimer) C() <-chan time.Time { + return t.closedCh +} +func (noopTimer) Next() {} +func (noopTimer) Stop() {} + +type variableTimer struct { + fn DelayFunc + t clock.Timer + new func(time.Duration) clock.Timer +} + +func (t *variableTimer) C() <-chan time.Time { + if t.t == nil { + d := t.fn() + t.t = t.new(d) + } + return t.t.C() +} + +func (t *variableTimer) Next() { + if t.t == nil { + return + } + d := t.fn() + t.t.Reset(d) +} + +func (t *variableTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +type fixedTimer struct { + interval time.Duration + t clock.Ticker + new func(time.Duration) clock.Ticker +} + +func (t *fixedTimer) C() <-chan time.Time { + if t.t == nil { + t.t = t.new(t.interval) + } + return t.t.C() +} + +func (t *fixedTimer) Next() { + // no-op for fixed timers +} + +func (t *fixedTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +// RealTimer can be passed to methods that need a clock.Timer. +var RealTimer = clock.RealClock{}.NewTimer + +// internalClock is used for test injection of clocks +var internalClock = clock.RealClock{} diff --git a/proxy/wireguard/iptables/wait/wait.go b/proxy/wireguard/iptables/wait/wait.go new file mode 100644 index 000000000000..e6eaac555000 --- /dev/null +++ b/proxy/wireguard/iptables/wait/wait.go @@ -0,0 +1,222 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "math/rand" + "sync" + "time" + // "k8s.io/apimachinery/pkg/util/runtime" +) + +// For any test of the style: +// +// ... +// <- time.After(timeout): +// t.Errorf("Timed out") +// +// The value for timeout should effectively be "forever." Obviously we don't want our tests to truly lock up forever, but 30s +// is long enough that it is effectively forever for the things that can slow down a run on a heavily contended machine +// (GC, seeks, etc), but not so long as to make a developer ctrl-c a test run if they do happen to break that test. +var ForeverTestTimeout = time.Second * 30 + +// NeverStop may be passed to Until to make it never stop. +var NeverStop <-chan struct{} = make(chan struct{}) + +// Group allows to start a group of goroutines and wait for their completion. +type Group struct { + wg sync.WaitGroup +} + +func (g *Group) Wait() { + g.wg.Wait() +} + +// StartWithChannel starts f in a new goroutine in the group. +// stopCh is passed to f as an argument. f should stop when stopCh is available. +func (g *Group) StartWithChannel(stopCh <-chan struct{}, f func(stopCh <-chan struct{})) { + g.Start(func() { + f(stopCh) + }) +} + +// StartWithContext starts f in a new goroutine in the group. +// ctx is passed to f as an argument. f should stop when ctx.Done() is available. +func (g *Group) StartWithContext(ctx context.Context, f func(context.Context)) { + g.Start(func() { + f(ctx) + }) +} + +// Start starts f in a new goroutine in the group. +func (g *Group) Start(f func()) { + g.wg.Add(1) + go func() { + defer g.wg.Done() + f() + }() +} + +// Forever calls f every period for ever. +// +// Forever is syntactic sugar on top of Until. +func Forever(f func(), period time.Duration) { + Until(f, period, NeverStop) +} + +// Jitter returns a time.Duration between duration and duration + maxFactor * +// duration. +// +// This allows clients to avoid converging on periodic behavior. If maxFactor +// is 0.0, a suggested default value will be chosen. +func Jitter(duration time.Duration, maxFactor float64) time.Duration { + if maxFactor <= 0.0 { + maxFactor = 1.0 + } + wait := duration + time.Duration(rand.Float64()*maxFactor*float64(duration)) + return wait +} + +// ConditionFunc returns true if the condition is satisfied, or an error +// if the loop should be aborted. +type ConditionFunc func() (done bool, err error) + +// ConditionWithContextFunc returns true if the condition is satisfied, or an error +// if the loop should be aborted. +// +// The caller passes along a context that can be used by the condition function. +type ConditionWithContextFunc func(context.Context) (done bool, err error) + +// WithContext converts a ConditionFunc into a ConditionWithContextFunc +func (cf ConditionFunc) WithContext() ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + return cf() + } +} + +// ContextForChannel provides a context that will be treated as cancelled +// when the provided parentCh is closed. The implementation returns +// context.Canceled for Err() if and only if the parentCh is closed. +func ContextForChannel(parentCh <-chan struct{}) context.Context { + return channelContext{stopCh: parentCh} +} + +var _ context.Context = channelContext{} + +// channelContext will behave as if the context were cancelled when stopCh is +// closed. +type channelContext struct { + stopCh <-chan struct{} +} + +func (c channelContext) Done() <-chan struct{} { return c.stopCh } +func (c channelContext) Err() error { + select { + case <-c.stopCh: + return context.Canceled + default: + return nil + } +} +func (c channelContext) Deadline() (time.Time, bool) { return time.Time{}, false } +func (c channelContext) Value(key any) any { return nil } + +// runConditionWithCrashProtection runs a ConditionFunc with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. +func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) { + // defer runtime.HandleCrash() + return condition() +} + +// runConditionWithCrashProtectionWithContext runs a ConditionWithContextFunc +// with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. +func runConditionWithCrashProtectionWithContext(ctx context.Context, condition ConditionWithContextFunc) (bool, error) { + // defer runtime.HandleCrash() + return condition(ctx) +} + +// waitFunc creates a channel that receives an item every time a test +// should be executed and is closed when the last test should be invoked. +// +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. +type waitFunc func(done <-chan struct{}) <-chan struct{} + +// WithContext converts the WaitFunc to an equivalent WaitWithContextFunc +func (w waitFunc) WithContext() waitWithContextFunc { + return func(ctx context.Context) <-chan struct{} { + return w(ctx.Done()) + } +} + +// waitWithContextFunc creates a channel that receives an item every time a test +// should be executed and is closed when the last test should be invoked. +// +// When the specified context gets cancelled or expires the function +// stops sending item and returns immediately. +// +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. +type waitWithContextFunc func(ctx context.Context) <-chan struct{} + +// waitForWithContext continually checks 'fn' as driven by 'wait'. +// +// waitForWithContext gets a channel from 'wait()”, and then invokes 'fn' +// once for every value placed on the channel and once more when the +// channel is closed. If the channel is closed and 'fn' +// returns false without error, waitForWithContext returns ErrWaitTimeout. +// +// If 'fn' returns an error the loop ends and that error is returned. If +// 'fn' returns true the loop ends and nil is returned. +// +// context.Canceled will be returned if the ctx.Done() channel is closed +// without fn ever returning true. +// +// When the ctx.Done() channel is closed, because the golang `select` statement is +// "uniform pseudo-random", the `fn` might still run one or multiple times, +// though eventually `waitForWithContext` will return. +// +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. +func waitForWithContext(ctx context.Context, wait waitWithContextFunc, fn ConditionWithContextFunc) error { + waitCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := wait(waitCtx) + for { + select { + case _, open := <-c: + ok, err := runConditionWithCrashProtectionWithContext(ctx, fn) + if err != nil { + return err + } + if ok { + return nil + } + if !open { + return ErrWaitTimeout + } + case <-ctx.Done(): + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead + return ErrWaitTimeout + } + } +} diff --git a/proxy/wireguard/tun_default.go b/proxy/wireguard/tun_kernel_default.go similarity index 58% rename from proxy/wireguard/tun_default.go rename to proxy/wireguard/tun_kernel_default.go index 4d0567af0029..61836ebd371f 100644 --- a/proxy/wireguard/tun_default.go +++ b/proxy/wireguard/tun_kernel_default.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !linux package wireguard @@ -8,9 +8,5 @@ import ( ) func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { - return nil, errors.New("not implemented") -} - -func KernelTunSupported() bool { - return false + return nil, errors.New("not implemented kernel tunnel for non-linux system") } diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_kernel_linux.go similarity index 73% rename from proxy/wireguard/tun_linux.go rename to proxy/wireguard/tun_kernel_linux.go index b85a9d097e5c..c03ac2848675 100644 --- a/proxy/wireguard/tun_linux.go +++ b/proxy/wireguard/tun_kernel_linux.go @@ -1,4 +1,4 @@ -//go:build linux && !android +//go:build linux package wireguard @@ -8,12 +8,13 @@ import ( "fmt" "net" "net/netip" - "os" "golang.org/x/sys/unix" "github.com/sagernet/sing/common/control" "github.com/vishvananda/netlink" + "github.com/xtls/xray-core/proxy/wireguard/iptables" + iptexec "github.com/xtls/xray-core/proxy/wireguard/iptables/exec" wgtun "golang.zx2c4.com/wireguard/tun" ) @@ -25,6 +26,10 @@ type deviceNet struct { linkAddrs []netlink.Addr routes []*netlink.Route rules []*netlink.Rule + + ipt iptables.Interface + + iptManglePreRoutingRules [][]string } func newDeviceNet(interfaceName string) *deviceNet { @@ -82,38 +87,12 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo x := prefixes v4 = &x } - if v6 == nil && prefixes.Is6() { + if v6 == nil && prefixes.Is6() && CheckUnixKernelIPv6IsEnabled() { x := prefixes v6 = &x } } - writeSysctlZero := func(path string) error { - _, err := os.Stat(path) - if os.IsNotExist(err) { - return nil - } - if err != nil { - return err - } - return os.WriteFile(path, []byte("0"), 0o644) - } - - // system configs. - if v4 != nil { - if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err) - } - } - if v6 != nil { - if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil { - return nil, fmt.Errorf("failed to enable ipv6: %w", err) - } - if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err) - } - } - n := CalculateInterfaceName("wg") wgt, err := wgtun.CreateTUN(n, mtu) if err != nil { @@ -125,16 +104,18 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo } }() - // disable linux rp_filter for tunnel device to avoid packet drop. - // the operation require root privilege on container require '--privileged' flag. + ipv4TableIndex := 1023 if v4 != nil { - if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err) - } - } - if v6 != nil { - if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err) + r := &netlink.Route{Table: ipv4TableIndex} + for { + routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V4, r, netlink.RT_FILTER_TABLE) + if len(routeList) == 0 || fErr != nil { + break + } + ipv4TableIndex-- + if ipv4TableIndex < 0 { + return nil, fmt.Errorf("failed to find available ipv4 table index") + } } } @@ -164,6 +145,11 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo } }() + out.ipt = iptables.New(iptexec.New(), iptables.ProtocolIPv4) + if exist := out.ipt.Present(); !exist { + return nil, fmt.Errorf("iptables is not available") + } + l, err := netlink.LinkByName(n) if err != nil { return nil, err @@ -177,6 +163,25 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo }, } out.linkAddrs = append(out.linkAddrs, addr) + + rt := &netlink.Route{ + LinkIndex: l.Attrs().Index, + Dst: &net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + }, + Table: ipv4TableIndex, + } + out.routes = append(out.routes, rt) + + r := netlink.NewRule() + r.Table, r.Family, r.Mark = ipv4TableIndex, unix.AF_INET, ipv4TableIndex + out.rules = append(out.rules, r) + + // -i wg0 -j MARK --set-xmark 0x334/0xffffffff + out.iptManglePreRoutingRules = append(out.iptManglePreRoutingRules, []string{ + "-i", n, "-j", "MARK", "--set-xmark", fmt.Sprintf("0x%x/0xffffffff", ipv4TableIndex), + }) } if v6 != nil { addr := netlink.Addr{ @@ -224,14 +229,13 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo return nil, fmt.Errorf("failed to add rule %s: %w", rule, err) } } + for _, rule := range out.iptManglePreRoutingRules { + _, err = out.ipt.EnsureRule(iptables.Append, iptables.TableMangle, + iptables.ChainPrerouting, rule...) + if err != nil { + return nil, fmt.Errorf("failed to add rule %s: %w", rule, err) + } + } out.tun = wgt return out, nil } - -func KernelTunSupported() bool { - // run a superuser permission check to check - // if the current user has the sufficient permission - // to create a tun device. - - return unix.Geteuid() == 0 // 0 means root -} diff --git a/proxy/wireguard/wireguard_linux.go b/proxy/wireguard/wireguard_linux.go new file mode 100644 index 000000000000..2d84b0f1d5b5 --- /dev/null +++ b/proxy/wireguard/wireguard_linux.go @@ -0,0 +1,61 @@ +//go:build linux + +package wireguard + +import ( + "os" + "os/exec" + "strings" + + "kernel.org/pub/linux/libs/security/libcap/cap" +) + +func IsLinux() bool { + return true +} + +func CheckUnixKernelTunDeviceEnabled() bool { + if _, err := os.Stat("/dev/net/tun"); err != nil { + return false + } + return true +} + +func CheckUnixKernelNetAdminCapEnabled() bool { + orig := cap.GetProc() + c, err := orig.Dup() + if err != nil { + return false + } + on, _ := c.GetFlag(cap.Effective, cap.NET_ADMIN) + return on +} + +func CheckUnixKernelIPv4SrcValidMarkEnabled() bool { + buf, _ := os.ReadFile("/proc/sys/net/ipv4/conf/all/src_valid_mark") + value := strings.TrimSpace(string(buf)) + return value == "1" +} + +func CheckUnixKernelIPv6IsEnabled() bool { + buf, _ := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6") + value := strings.TrimSpace(string(buf)) + return value == "0" +} + +// CheckUnixKernelTunSupported returns true if kernel tun is supported. +// 1. check if the current process has CAP_NET_ADMIN capability +// 2. check if /proc/sys/net/ipv4/conf/all/src_valid_mark exists and is set to 1 +// 3. check if iptables is available +func CheckUnixKernelTunSupported() bool { + if !CheckUnixKernelTunDeviceEnabled() || !CheckUnixKernelNetAdminCapEnabled() { + return false + } + outCmd := exec.Command("sh", "-c", "command -v iptables") + outBuffer, err := outCmd.CombinedOutput() + if err != nil { + return false + } + iptablesPath := strings.TrimSpace(string(outBuffer)) + return iptablesPath != "" +} diff --git a/proxy/wireguard/wireguard_others.go b/proxy/wireguard/wireguard_others.go new file mode 100644 index 000000000000..9dd8bef1890e --- /dev/null +++ b/proxy/wireguard/wireguard_others.go @@ -0,0 +1,27 @@ +//go:build !linux + +package wireguard + +func IsLinux() bool { + return false +} + +func CheckUnixKernelTunDeviceEnabled() bool { + return true +} + +func CheckUnixKernelNetAdminCapEnabled() bool { + return false +} + +func CheckUnixKernelIPv6IsEnabled() bool { + return false +} + +func CheckUnixKernelIPv4SrcValidMarkEnabled() bool { + return false +} + +func CheckUnixKernelTunSupported() bool { + return false +} From 25e59ce61a66a8a7ac3913041de7fe12eb65f8c5 Mon Sep 17 00:00:00 2001 From: kunsonxs Date: Mon, 27 Nov 2023 11:08:01 +0800 Subject: [PATCH 2/3] fix : wireguard kernel mode cleanup iptables rules --- proxy/wireguard/tun.go | 1 + proxy/wireguard/tun_kernel_linux.go | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index c2d303236494..01a8b96f06d3 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -37,6 +37,7 @@ type Tunnel interface { Close() error } +// tunnel is a wrapper of wireguard device and tun device for gvisorNet and deviceNet type tunnel struct { tun tun.Device device *device.Device diff --git a/proxy/wireguard/tun_kernel_linux.go b/proxy/wireguard/tun_kernel_linux.go index c03ac2848675..a1ec9267dcbf 100644 --- a/proxy/wireguard/tun_kernel_linux.go +++ b/proxy/wireguard/tun_kernel_linux.go @@ -18,6 +18,8 @@ import ( wgtun "golang.zx2c4.com/wireguard/tun" ) +var _ Tunnel = (*deviceNet)(nil) + type deviceNet struct { tunnel dialer net.Dialer @@ -63,6 +65,11 @@ func (d *deviceNet) Close() (err error) { errs = append(errs, fmt.Errorf("failed to delete route: %w", err)) } } + for _, rule := range d.iptManglePreRoutingRules { + if err = d.ipt.DeleteRule(iptables.TableMangle, iptables.ChainPrerouting, rule...); err != nil { + errs = append(errs, fmt.Errorf("failed to delete iptables rule: %w", err)) + } + } if err = d.tunnel.Close(); err != nil { errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err)) } @@ -233,7 +240,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo _, err = out.ipt.EnsureRule(iptables.Append, iptables.TableMangle, iptables.ChainPrerouting, rule...) if err != nil { - return nil, fmt.Errorf("failed to add rule %s: %w", rule, err) + return nil, fmt.Errorf("failed to add iptable rule %s: %w", rule, err) } } out.tun = wgt From e8c8fac969a2f87009e7b6229b421d9f44b14d79 Mon Sep 17 00:00:00 2001 From: kunsonxs Date: Mon, 27 Nov 2023 11:40:10 +0800 Subject: [PATCH 3/3] fix : wireguard kernel mode checks --- infra/conf/wireguard.go | 12 ++++++------ proxy/wireguard/wireguard_linux.go | 7 +++++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index 4a189b54e93d..a0f0d46179af 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -116,12 +116,6 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { return nil, newError("unsupported domain strategy: ", c.DomainStrategy) } - // check device exist for wireguard setup - // module "golang.zx2c4.com/wireguard" only support linux and require /dev/net/tun - if wireguard.IsLinux() && !wireguard.CheckUnixKernelTunDeviceEnabled() { - return nil, newError("wireguard module require device /dev/net/tun") - } - config.IsClient = c.IsClient if c.IsClient { if support := wireguard.CheckUnixKernelTunSupported(); c.KernelMode == nil { @@ -135,6 +129,12 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { if !c.IsClient { config.KernelMode = false } + + // check device exist for wireguard setup + // module "golang.zx2c4.com/wireguard" on linux require /dev/net/tun for userspace implementation + if wireguard.IsLinux() && !wireguard.CheckUnixKernelTunDeviceEnabled() { + return nil, newError(`wireguard userspace mode require device "/dev/net/tun"`) + } return config, nil } diff --git a/proxy/wireguard/wireguard_linux.go b/proxy/wireguard/wireguard_linux.go index 2d84b0f1d5b5..34ce8f920806 100644 --- a/proxy/wireguard/wireguard_linux.go +++ b/proxy/wireguard/wireguard_linux.go @@ -46,9 +46,12 @@ func CheckUnixKernelIPv6IsEnabled() bool { // CheckUnixKernelTunSupported returns true if kernel tun is supported. // 1. check if the current process has CAP_NET_ADMIN capability // 2. check if /proc/sys/net/ipv4/conf/all/src_valid_mark exists and is set to 1 -// 3. check if iptables is available +// 3. check if /dev/net/tun exists +// 4. check if iptables is available func CheckUnixKernelTunSupported() bool { - if !CheckUnixKernelTunDeviceEnabled() || !CheckUnixKernelNetAdminCapEnabled() { + if !CheckUnixKernelNetAdminCapEnabled() || + !CheckUnixKernelIPv4SrcValidMarkEnabled() || + !CheckUnixKernelTunDeviceEnabled() { return false } outCmd := exec.Command("sh", "-c", "command -v iptables")