@@ -31,6 +31,37 @@ func init() {
3131 * Tests
3232 */
3333
34+ func TestServerRootHandler (t * testing.T ) {
35+ assert := assert .New (t )
36+ config = newDefaultConfig ()
37+
38+ // X-Forwarded headers should be read into request
39+ req := httptest .NewRequest ("POST" , "http://should-use-x-forwarded.com/should?ignore=me" , nil )
40+ req .Header .Add ("X-Forwarded-Method" , "GET" )
41+ req .Header .Add ("X-Forwarded-Proto" , "https" )
42+ req .Header .Add ("X-Forwarded-Host" , "example.com" )
43+ req .Header .Add ("X-Forwarded-Uri" , "/foo?q=bar" )
44+ NewServer ().RootHandler (httptest .NewRecorder (), req )
45+
46+ assert .Equal ("GET" , req .Method , "x-forwarded-method should be read into request" )
47+ assert .Equal ("example.com" , req .Host , "x-forwarded-host should be read into request" )
48+ assert .Equal ("/foo" , req .URL .Path , "x-forwarded-uri should be read into request" )
49+ assert .Equal ("/foo?q=bar" , req .URL .RequestURI (), "x-forwarded-uri should be read into request" )
50+
51+ // Other X-Forwarded headers should be read in into request and original URL
52+ // should be preserved if X-Forwarded-Uri not present
53+ req = httptest .NewRequest ("POST" , "http://should-use-x-forwarded.com/should-not?ignore=me" , nil )
54+ req .Header .Add ("X-Forwarded-Method" , "GET" )
55+ req .Header .Add ("X-Forwarded-Proto" , "https" )
56+ req .Header .Add ("X-Forwarded-Host" , "example.com" )
57+ NewServer ().RootHandler (httptest .NewRecorder (), req )
58+
59+ assert .Equal ("GET" , req .Method , "x-forwarded-method should be read into request" )
60+ assert .Equal ("example.com" , req .Host , "x-forwarded-host should be read into request" )
61+ assert .Equal ("/should-not" , req .URL .Path , "request url should be preserved if x-forwarded-uri not present" )
62+ assert .Equal ("/should-not?ignore=me" , req .URL .RequestURI (), "request url should be preserved if x-forwarded-uri not present" )
63+ }
64+
3465func TestServerAuthHandlerInvalid (t * testing.T ) {
3566 assert := assert .New (t )
3667 config = newDefaultConfig ()
@@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
90121 config .Domains = []string {"test.com" }
91122
92123 // Should redirect expired cookie
93- req := newDefaultHttpRequest ( " /foo" )
124+ req := newHTTPRequest ( "GET" , "http://example.com /foo" )
94125 c := MakeCookie (
req ,
"[email protected] " )
95126 res , _ := doHttpRequest (req , c )
96- assert .Equal (307 , res .StatusCode , "request with expired cookie should be redirected" )
127+ require .Equal (t , 307 , res .StatusCode , "request with expired cookie should be redirected" )
97128
98129 // Check for CSRF cookie
99130 var cookie * http.Cookie
@@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
116147 config = newDefaultConfig ()
117148
118149 // Should allow valid request email
119- req := newDefaultHttpRequest ( " /foo" )
150+ req := newHTTPRequest ( "GET" , "http://example.com /foo" )
120151 c := MakeCookie (
req ,
"[email protected] " )
121152 config .Domains = []string {}
122153
@@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
131162
132163func TestServerAuthCallback (t * testing.T ) {
133164 assert := assert .New (t )
165+ require := require .New (t )
134166 config = newDefaultConfig ()
135167
136168 // Setup OAuth server
@@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
148180 }
149181
150182 // Should pass auth response request to callback
151- req := newDefaultHttpRequest ( " /_oauth" )
183+ req := newHTTPRequest ( "GET" , "http://example.com /_oauth" )
152184 res , _ := doHttpRequest (req , nil )
153185 assert .Equal (401 , res .StatusCode , "auth callback without cookie shouldn't be authorised" )
154186
155187 // Should catch invalid csrf cookie
156- req = newDefaultHttpRequest ("/_oauth?state=12345678901234567890123456789012:http://redirect" )
188+ nonce := "12345678901234567890123456789012"
189+ req = newHTTPRequest ("GET" , "http://example.com/_oauth?state=" + nonce + ":http://redirect" )
157190 c := MakeCSRFCookie (req , "nononononononononononononononono" )
158191 res , _ = doHttpRequest (req , c )
159192 assert .Equal (401 , res .StatusCode , "auth callback with invalid cookie shouldn't be authorised" )
160193
161194 // Should catch invalid provider cookie
162- req = newDefaultHttpRequest ( "/ _oauth?state=12345678901234567890123456789012 :invalid:http://redirect" )
163- c = MakeCSRFCookie (req , "12345678901234567890123456789012" )
195+ req = newHTTPRequest ( "GET" , "http://example.com/ _oauth?state=" + nonce + " :invalid:http://redirect" )
196+ c = MakeCSRFCookie (req , nonce )
164197 res , _ = doHttpRequest (req , c )
165198 assert .Equal (401 , res .StatusCode , "auth callback with invalid provider shouldn't be authorised" )
166199
167200 // Should redirect valid request
168- req = newDefaultHttpRequest ( "/ _oauth?state=12345678901234567890123456789012 :google:http://redirect" )
169- c = MakeCSRFCookie (req , "12345678901234567890123456789012" )
201+ req = newHTTPRequest ( "GET" , "http://example.com/ _oauth?state=" + nonce + " :google:http://redirect" )
202+ c = MakeCSRFCookie (req , nonce )
170203 res , _ = doHttpRequest (req , c )
171- assert .Equal (307 , res .StatusCode , "valid auth callback should be allowed" )
204+ require .Equal (307 , res .StatusCode , "valid auth callback should be allowed" )
172205
173206 fwd , _ := res .Location ()
174207 assert .Equal ("http" , fwd .Scheme , "valid request should be redirected to return url" )
@@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
360393 }
361394
362395 // Should block any request
363- req := newHttpRequest ("GET" , "https://example.com/" , " /" )
396+ req := newHTTPRequest ("GET" , "https://example.com/" )
364397 res , _ := doHttpRequest (req , nil )
365398 assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
366399
367400 // Should allow matching request
368- req = newHttpRequest ("GET" , "https://api.example.com/" , " /" )
401+ req = newHTTPRequest ("GET" , "https://api.example.com/" )
369402 res , _ = doHttpRequest (req , nil )
370403 assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
371404
372405 // Should allow matching request
373- req = newHttpRequest ("GET" , "https://sub8.example.com/" , " /" )
406+ req = newHTTPRequest ("GET" , "https://sub8.example.com/" )
374407 res , _ = doHttpRequest (req , nil )
375408 assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
376409}
@@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
386419 }
387420
388421 // Should block any request
389- req := newHttpRequest ("GET" , "https://example.com/" , " /" )
422+ req := newHTTPRequest ("GET" , "https://example.com/" )
390423 res , _ := doHttpRequest (req , nil )
391424 assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
392425
393426 // Should allow matching request
394- req = newHttpRequest ("PUT" , "https://example.com/" , " /" )
427+ req = newHTTPRequest ("PUT" , "https://example.com/" )
395428 res , _ = doHttpRequest (req , nil )
396429 assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
397430}
@@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
441474 }
442475
443476 // Should block any request
444- req := newHttpRequest ("GET" , "https://example.com/" , " /?q=no" )
477+ req := newHTTPRequest ("GET" , "https://example.com/?q=no" )
445478 res , _ := doHttpRequest (req , nil )
446479 assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
447480
448481 // Should allow matching request
449- req = newHttpRequest ("GET" , "https://api.example.com/" , " /?q=test123" )
482+ req = newHTTPRequest ("GET" , "https://api.example.com/?q=test123" )
450483 res , _ = doHttpRequest (req , nil )
451484 assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
452485}
@@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
531564 return config
532565}
533566
567+ // TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
534568func newDefaultHttpRequest (uri string ) * http.Request {
535- return newHttpRequest ( " " , "http://example.com/" , uri )
569+ return newHTTPRequest ( "GET " , "http://example.com" + uri )
536570}
537571
538- func newHttpRequest (method , dest , uri string ) * http.Request {
539- r := httptest . NewRequest ( "" , "http://should-use-x-forwarded.com" , nil )
540- p , _ := url . Parse ( dest )
572+ func newHTTPRequest (method , target string ) * http.Request {
573+ u , _ := url . Parse ( target )
574+ r := httptest . NewRequest ( method , target , nil )
541575 r .Header .Add ("X-Forwarded-Method" , method )
542- r .Header .Add ("X-Forwarded-Proto" , p .Scheme )
543- r .Header .Add ("X-Forwarded-Host" , p .Host )
544- r .Header .Add ("X-Forwarded-Uri" , uri )
576+ r .Header .Add ("X-Forwarded-Proto" , u .Scheme )
577+ r .Header .Add ("X-Forwarded-Host" , u .Host )
578+ r .Header .Add ("X-Forwarded-Uri" , u . RequestURI () )
545579 return r
546580}
0 commit comments