Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ocpp1.6_test/ocpp16_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
websocketServer.MethodCalled("NewClient", websocketId, client)
}

func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
websocketServer.CheckClientHandler = handler
}

Expand Down
2 changes: 1 addition & 1 deletion ocpp2.0.1_test/ocpp2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
websocketServer.MethodCalled("NewClient", websocketId, client)
}

func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
websocketServer.CheckClientHandler = handler
}

Expand Down
2 changes: 1 addition & 1 deletion ocppj/ocppj_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
websocketServer.MethodCalled("NewClient", websocketId, client)
}

func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
websocketServer.CheckClientHandler = handler
}

Expand Down
13 changes: 8 additions & 5 deletions ws/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (e HttpConnectionError) Error() string {

// ---------------------- SERVER ----------------------

type CheckClientHandler func(id string, r *http.Request) bool
type CheckClientHandler func(id string, r *http.Request) (string, bool)

// WsServer defines a websocket server, which passively listens for incoming connections on ws or wss protocol.
// The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks.
Expand Down Expand Up @@ -249,7 +249,7 @@ type WsServer interface {
SetCheckOriginHandler(handler func(r *http.Request) bool)
// SetCheckClientHandler sets a handler for validate incoming websocket connections, allowing to perform
// custom client connection checks.
SetCheckClientHandler(handler func(id string, r *http.Request) bool)
SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool))
// Addr gives the address on which the server is listening, useful if, for
// example, the port is system-defined (set to 0).
Addr() *net.TCPAddr
Expand All @@ -262,7 +262,7 @@ type Server struct {
connections map[string]*WebSocket
httpServer *http.Server
messageHandler func(ws Channel, data []byte) error
checkClientHandler func(id string, r *http.Request) bool
checkClientHandler func(id string, r *http.Request) (string, bool)
newClientHandler func(ws Channel)
disconnectedHandler func(ws Channel)
basicAuthHandler func(username string, password string) bool
Expand Down Expand Up @@ -319,7 +319,7 @@ func (server *Server) SetMessageHandler(handler func(ws Channel, data []byte) er
server.messageHandler = handler
}

func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
server.checkClientHandler = handler
}

Expand Down Expand Up @@ -502,12 +502,15 @@ out:
}

if server.checkClientHandler != nil {
ok := server.checkClientHandler(id, r)
newId, ok := server.checkClientHandler(id, r)
if !ok {
server.error(fmt.Errorf("client validation: invalid client"))
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if len(newId) > 0 {
id = newId
}
}

// Upgrade websocket
Expand Down
4 changes: 2 additions & 2 deletions ws/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,8 @@ func TestCustomCheckClientHandler(t *testing.T) {
wsServer.SetNewClientHandler(func(ws Channel) {
connected <- true
})
wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) bool {
return id == clientId
wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) (string, bool) {
return clientId, id == clientId
})
go wsServer.Start(serverPort, serverPath)
time.Sleep(500 * time.Millisecond)
Expand Down