Skip to content

Commit 03ab154

Browse files
Allow intercepting the client id inside the checkClientHandler
1 parent 717e26d commit 03ab154

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

ocpp1.6_test/ocpp16_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
9393
websocketServer.MethodCalled("NewClient", websocketId, client)
9494
}
9595

96-
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
96+
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
9797
websocketServer.CheckClientHandler = handler
9898
}
9999

ocpp2.0.1_test/ocpp2_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
106106
websocketServer.MethodCalled("NewClient", websocketId, client)
107107
}
108108

109-
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
109+
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
110110
websocketServer.CheckClientHandler = handler
111111
}
112112

ocppj/ocppj_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client
103103
websocketServer.MethodCalled("NewClient", websocketId, client)
104104
}
105105

106-
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
106+
func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
107107
websocketServer.CheckClientHandler = handler
108108
}
109109

ws/websocket.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func (e HttpConnectionError) Error() string {
160160

161161
// ---------------------- SERVER ----------------------
162162

163-
type CheckClientHandler func(id string, r *http.Request) bool
163+
type CheckClientHandler func(id string, r *http.Request) (string, bool)
164164

165165
// WsServer defines a websocket server, which passively listens for incoming connections on ws or wss protocol.
166166
// The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks.
@@ -249,7 +249,7 @@ type WsServer interface {
249249
SetCheckOriginHandler(handler func(r *http.Request) bool)
250250
// SetCheckClientHandler sets a handler for validate incoming websocket connections, allowing to perform
251251
// custom client connection checks.
252-
SetCheckClientHandler(handler func(id string, r *http.Request) bool)
252+
SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool))
253253
// Addr gives the address on which the server is listening, useful if, for
254254
// example, the port is system-defined (set to 0).
255255
Addr() *net.TCPAddr
@@ -262,7 +262,7 @@ type Server struct {
262262
connections map[string]*WebSocket
263263
httpServer *http.Server
264264
messageHandler func(ws Channel, data []byte) error
265-
checkClientHandler func(id string, r *http.Request) bool
265+
checkClientHandler func(id string, r *http.Request) (string, bool)
266266
newClientHandler func(ws Channel)
267267
disconnectedHandler func(ws Channel)
268268
basicAuthHandler func(username string, password string) bool
@@ -319,7 +319,7 @@ func (server *Server) SetMessageHandler(handler func(ws Channel, data []byte) er
319319
server.messageHandler = handler
320320
}
321321

322-
func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
322+
func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) (string, bool)) {
323323
server.checkClientHandler = handler
324324
}
325325

@@ -502,12 +502,15 @@ out:
502502
}
503503

504504
if server.checkClientHandler != nil {
505-
ok := server.checkClientHandler(id, r)
505+
newId, ok := server.checkClientHandler(id, r)
506506
if !ok {
507507
server.error(fmt.Errorf("client validation: invalid client"))
508508
http.Error(w, "Unauthorized", http.StatusUnauthorized)
509509
return
510510
}
511+
if len(newId) > 0 {
512+
id = newId
513+
}
511514
}
512515

513516
// Upgrade websocket

ws/websocket_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -715,8 +715,8 @@ func TestCustomCheckClientHandler(t *testing.T) {
715715
wsServer.SetNewClientHandler(func(ws Channel) {
716716
connected <- true
717717
})
718-
wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) bool {
719-
return id == clientId
718+
wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) (string, bool) {
719+
return clientId, id == clientId
720720
})
721721
go wsServer.Start(serverPort, serverPath)
722722
time.Sleep(500 * time.Millisecond)

0 commit comments

Comments
 (0)