diff --git a/client/client.go b/client/client.go index 82c3e29..be3d0ca 100644 --- a/client/client.go +++ b/client/client.go @@ -41,11 +41,19 @@ func WithHttpClient(client *http.Client) LogtoClientOption { } } +// WithTrustForwardedHeader sets whether to trust X-Forwarded-* headers for checking the request's origin, useful when behind a reverse proxy. +func WithTrustForwardedHeader(b bool) LogtoClientOption { + return func(c *LogtoClient) { + c.trustForwardedHeader = b + } +} + type LogtoClient struct { - httpClient *http.Client - logtoConfig *LogtoConfig - storage Storage - accessTokenMap map[string]AccessToken + httpClient *http.Client + logtoConfig *LogtoConfig + storage Storage + accessTokenMap map[string]AccessToken + trustForwardedHeader bool } func NewLogtoClient(config *LogtoConfig, storage Storage, opts ...LogtoClientOption) *LogtoClient { diff --git a/client/handle_sign_in_callback.go b/client/handle_sign_in_callback.go index 65e5fec..fcc90fa 100644 --- a/client/handle_sign_in_callback.go +++ b/client/handle_sign_in_callback.go @@ -17,6 +17,9 @@ func (logtoClient *LogtoClient) HandleSignInCallback(request *http.Request) erro } callbackUri := GetOriginRequestUrl(request) + if logtoClient.trustForwardedHeader { + callbackUri = getForwaredRequestUrl(request) + } code, retrieveCodeErr := core.VerifyAndParseCodeFromCallbackUri(callbackUri, signInSession.RedirectUri, signInSession.State) if retrieveCodeErr != nil { return retrieveCodeErr diff --git a/client/util.go b/client/util.go index 5f12894..b8ef656 100644 --- a/client/util.go +++ b/client/util.go @@ -12,6 +12,26 @@ func GetOriginRequestUrl(request *http.Request) string { return getRequestProtocol(request) + "://" + request.Host + request.RequestURI } +func getForwaredRequestUrl(request *http.Request) string { + proto := getRequestProtocol(request) + host := getForwaredRequestHost(request) + uri := getForwaredRequestRequestUri(request) + return proto + "://" + host + uri +} +func getForwaredRequestHost(request *http.Request) string { + host := request.Header.Get("X-Forwarded-Host") + if host != "" { + return host + } + return request.Host +} +func getForwaredRequestRequestUri(request *http.Request) string { + uri := request.Header.Get("X-Forwarded-Url") + if uri != "" { + return uri + } + return request.RequestURI +} func getRequestProtocol(request *http.Request) string { if request.TLS != nil { return "https" diff --git a/client/util_test.go b/client/util_test.go index 9aba797..c92f8e1 100644 --- a/client/util_test.go +++ b/client/util_test.go @@ -130,3 +130,43 @@ func createTestToken(resource string) (string, error) { return token, nil } + +func TestGetForwaredRequestUrlShouldReturnXForwardedIfPresent(t *testing.T) { + req, err := http.NewRequest("GET", "http://example.com/path?query=1", nil) + assert.Nil(t, err) + // Ensure RequestURI is set like in real servers + req.RequestURI = "/path?query=1" + + req.Header.Add("X-Forwarded-Host", "forwarded.example.com") + req.Header.Add("X-Forwarded-Url", "/forwarded-path?query=2") + req.Header.Add("X-Forwarded-Proto", "https") + + url := getForwaredRequestUrl(req) + + assert.Equal(t, "https://forwarded.example.com/forwarded-path?query=2", url) +} + +func TestGetForwaredRequestHostShouldFallbackToRequestHost(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com/", nil) + req.RequestURI = "/" + + host := getForwaredRequestHost(req) + + assert.Equal(t, "example.com", host) + + req.Header.Add("X-Forwarded-Host", "proxied.example.com") + host = getForwaredRequestHost(req) + assert.Equal(t, "proxied.example.com", host) +} + +func TestGetForwaredRequestRequestUriShouldFallbackToRequestURI(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com/api", nil) + req.RequestURI = "/api" + + uri := getForwaredRequestRequestUri(req) + assert.Equal(t, "/api", uri) + + req.Header.Add("X-Forwarded-Url", "/proxied/api?x=1") + uri = getForwaredRequestRequestUri(req) + assert.Equal(t, "/proxied/api?x=1", uri) +}