Skip to content

Commit 9cbfe1f

Browse files
committed
Add basic tests for the verifier middleware
1 parent e6b8584 commit 9cbfe1f

File tree

6 files changed

+360
-39
lines changed

6 files changed

+360
-39
lines changed

go.mod

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ go 1.14
44

55
require (
66
github.com/coreos/go-oidc v2.2.1+incompatible
7-
github.com/google/go-cmp v0.4.1 // indirect
7+
github.com/google/go-cmp v0.4.1
8+
github.com/kelseyhightower/envconfig v1.4.0
89
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
9-
github.com/stretchr/testify v1.6.1 // indirect
10+
github.com/quay/jwtproxy v0.0.4
11+
github.com/stretchr/testify v1.6.1
1012
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect
1113
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect
12-
gopkg.in/square/go-jose.v2 v2.5.1 // indirect
14+
gopkg.in/square/go-jose.v2 v2.5.1
15+
gopkg.in/yaml.v2 v2.3.0 // indirect
1316
)

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM
77
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
88
github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0=
99
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
10+
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
11+
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
1012
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
1113
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1214
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
1315
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
16+
github.com/quay/jwtproxy v0.0.4 h1:M7YZxrqLaY0MA20AkWqH+1HGFjxQPLmNrC8TjrkfbwQ=
17+
github.com/quay/jwtproxy v0.0.4/go.mod h1:Q0Zg96r0uvf49Ny3uRJ0Y09CCdtXU54LBntn6NZLShg=
1418
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
1519
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
1620
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -37,5 +41,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
3741
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3842
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
3943
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
44+
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
45+
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
4046
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
4147
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

main.go

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,45 +20,44 @@ import (
2020
"log"
2121
"net/http"
2222
"net/http/httputil"
23-
"os"
2423
"strings"
2524

2625
"github.com/coreos/go-oidc"
26+
"github.com/kelseyhightower/envconfig"
2727
)
2828

29-
var (
30-
ctx = context.Background()
31-
authDomain = os.Getenv("AUTHDOMAIN")
32-
certsURL = fmt.Sprintf("%s/cdn-cgi/access/certs", authDomain)
33-
34-
// policyAUD is your application AUD value
35-
policyAUD = os.Getenv("POLICYAUD")
36-
37-
// forwardHeader is the header to be set from the email claim embedded in the JWT token
38-
forwardHeader = os.Getenv("FORWARDHEADER")
29+
const (
30+
// CFJWTHeader is the header key set by Cloudflare Access after a successful authentication
31+
CFJWTHeader = "Cf-Access-Jwt-Assertion"
32+
)
3933

40-
// forwardHost is the host to bet used to forward the request. If set it will override the Host
41-
// header of the original request
42-
forwardHost = os.Getenv("FORWARDHOST")
34+
// CloudflareClaim holds the claims about the End-User/Authentication event.
35+
type CloudflareClaim struct {
36+
Email string `json:"email"`
37+
Type string `json:"type"`
38+
}
4339

44-
// listenAddr is the port where this proxy will be listening
45-
listenAddr = os.Getenv("ADDR")
40+
// Config is the general configuration (read from environment variables)
41+
type Config struct {
42+
AuthDomain string
43+
PolicyAUD string
44+
ForwardHeader string
45+
ForwardHost string
46+
ListenAddr string `envconfig:"ADDR"`
47+
}
4648

47-
config = &oidc.Config{
48-
ClientID: policyAUD,
49-
}
50-
keySet = oidc.NewRemoteKeySet(ctx, certsURL)
51-
verifier = oidc.NewVerifier(authDomain, keySet, config)
49+
var (
50+
ctx = context.Background()
5251
)
5352

5453
// VerifyToken is a middleware to verify a CF Access token
55-
func VerifyToken(next http.Handler) http.Handler {
54+
func VerifyToken(next http.Handler, tokenVerifier *oidc.IDTokenVerifier, cfg *Config) http.Handler {
5655
fn := func(w http.ResponseWriter, r *http.Request) {
5756
headers := r.Header
5857

5958
// Make sure that the incoming request has our token header
6059
// Could also look in the cookies for CF_AUTHORIZATION
61-
accessJWT := headers.Get("Cf-Access-Jwt-Assertion")
60+
accessJWT := headers.Get(CFJWTHeader)
6261
if accessJWT == "" {
6362
w.WriteHeader(http.StatusUnauthorized)
6463
w.Write([]byte("No token on the request"))
@@ -67,26 +66,22 @@ func VerifyToken(next http.Handler) http.Handler {
6766

6867
// Verify the access token
6968
ctx := r.Context()
70-
token, err := verifier.Verify(ctx, accessJWT)
69+
token, err := tokenVerifier.Verify(ctx, accessJWT)
7170
if err != nil {
7271
w.WriteHeader(http.StatusUnauthorized)
7372
w.Write([]byte(fmt.Sprintf("Invalid token: %s", err.Error())))
7473
return
7574
}
7675

7776
// Extract custom claims
78-
var claims struct {
79-
Email string `json:"email"`
80-
Type string `json:"type"`
81-
}
82-
77+
var claims CloudflareClaim
8378
if err := token.Claims(&claims); err != nil {
8479
w.WriteHeader(http.StatusUnauthorized)
8580
w.Write([]byte(fmt.Sprintf("Invalid claims in token: %s", err.Error())))
8681
}
8782

8883
// set the authentication forward header before proxying the request
89-
r.Header.Add(forwardHeader, claims.Email)
84+
r.Header.Add(cfg.ForwardHeader, claims.Email)
9085
log.Printf("Authenticated as: %s", claims.Email)
9186

9287
next.ServeHTTP(w, r)
@@ -96,23 +91,40 @@ func VerifyToken(next http.Handler) http.Handler {
9691
}
9792

9893
func main() {
94+
var cfg Config
95+
err := envconfig.Process("", &cfg)
96+
if err != nil {
97+
log.Fatal(err.Error())
98+
}
99+
100+
var (
101+
certsURL = fmt.Sprintf("%s/cdn-cgi/access/certs", cfg.AuthDomain)
102+
103+
config = &oidc.Config{
104+
ClientID: cfg.PolicyAUD,
105+
}
106+
keySet = oidc.NewRemoteKeySet(ctx, certsURL)
107+
verifier = oidc.NewVerifier(cfg.AuthDomain, keySet, config)
108+
)
109+
99110
director := func(req *http.Request) {
100111
req.Header.Add("X-Forwarded-Host", req.Host)
101112
req.Header.Add("X-Origin-Host", "cloudflare-access-proxy")
113+
// TODO: should we trust on the Schema of the original request?
102114
req.URL.Scheme = "http"
103115

104-
if len(strings.TrimSpace(forwardHost)) > 0 {
105-
req.URL.Host = forwardHost
116+
if len(strings.TrimSpace(cfg.ForwardHost)) > 0 {
117+
req.URL.Host = cfg.ForwardHost
106118
}
107119
}
108120

109121
proxy := &httputil.ReverseProxy{Director: director}
110122
http.Handle("/", VerifyToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111123
proxy.ServeHTTP(w, r)
112-
})))
124+
}), verifier, &cfg))
113125

114-
log.Printf("Listening on %s", listenAddr)
115-
if err := http.ListenAndServe(listenAddr, nil); err != nil {
116-
log.Fatalf("Unable to start server on [%s], error: %s", listenAddr, err.Error())
126+
log.Printf("Listening on %s", cfg.ListenAddr)
127+
if err := http.ListenAndServe(cfg.ListenAddr, nil); err != nil {
128+
log.Fatalf("Unable to start server on [%s], error: %s", cfg.ListenAddr, err.Error())
117129
}
118130
}

0 commit comments

Comments
 (0)