diff --git a/ocpp1.6_test/ocpp16_test.go b/ocpp1.6_test/ocpp16_test.go index 4a826bd1..d913f0f2 100644 --- a/ocpp1.6_test/ocpp16_test.go +++ b/ocpp1.6_test/ocpp16_test.go @@ -93,7 +93,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) { websocketServer.CheckClientHandler = handler } diff --git a/ocpp2.0.1_test/ocpp2_test.go b/ocpp2.0.1_test/ocpp2_test.go index edf56201..0642cbc8 100644 --- a/ocpp2.0.1_test/ocpp2_test.go +++ b/ocpp2.0.1_test/ocpp2_test.go @@ -106,7 +106,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) { websocketServer.CheckClientHandler = handler } diff --git a/ocppj/ocppj_test.go b/ocppj/ocppj_test.go index e01cb257..dbaecec7 100644 --- a/ocppj/ocppj_test.go +++ b/ocppj/ocppj_test.go @@ -103,7 +103,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) { websocketServer.CheckClientHandler = handler } diff --git a/ws/websocket.go b/ws/websocket.go index fd8d6831..07fc03ec 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -160,7 +160,7 @@ func (e HttpConnectionError) Error() string { // ---------------------- SERVER ---------------------- -type CheckClientHandler func(id string, r *http.Request) bool +type CheckClientHandler func(id string, r *http.Request) (string, bool) // WsServer defines a websocket server, which passively listens for incoming connections on ws or wss protocol. // The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks. @@ -249,7 +249,7 @@ type WsServer interface { SetCheckOriginHandler(handler func(r *http.Request) bool) // SetCheckClientHandler sets a handler for validate incoming websocket connections, allowing to perform // custom client connection checks. - SetCheckClientHandler(handler func(id string, r *http.Request) bool) + SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) // Addr gives the address on which the server is listening, useful if, for // example, the port is system-defined (set to 0). Addr() *net.TCPAddr @@ -262,7 +262,7 @@ type Server struct { connections map[string]*WebSocket httpServer *http.Server messageHandler func(ws Channel, data []byte) error - checkClientHandler func(id string, r *http.Request) bool + checkClientHandler func(id string, r *http.Request) (string, bool) newClientHandler func(ws Channel) disconnectedHandler func(ws Channel) basicAuthHandler func(username string, password string) bool @@ -319,7 +319,7 @@ func (server *Server) SetMessageHandler(handler func(ws Channel, data []byte) er server.messageHandler = handler } -func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) { server.checkClientHandler = handler } @@ -502,12 +502,15 @@ out: } if server.checkClientHandler != nil { - ok := server.checkClientHandler(id, r) + newId, ok := server.checkClientHandler(id, r) if !ok { server.error(fmt.Errorf("client validation: invalid client")) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } + if len(newId) > 0 { + id = newId + } } // Upgrade websocket diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 399f7259..dbcf70c8 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -715,8 +715,8 @@ func TestCustomCheckClientHandler(t *testing.T) { wsServer.SetNewClientHandler(func(ws Channel) { connected <- true }) - wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) bool { - return id == clientId + wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) (string, bool) { + return clientId, id == clientId }) go wsServer.Start(serverPort, serverPath) time.Sleep(500 * time.Millisecond)