diff --git a/ws/server.go b/ws/server.go index 7aaf03b2..240ff139 100644 --- a/ws/server.go +++ b/ws/server.go @@ -19,6 +19,13 @@ import ( type CheckClientHandler func(id string, r *http.Request) bool +type DuplicateConnectionBehavior uint8 + +const ( + DuplicateConnectionBehaviorKeepCurrent DuplicateConnectionBehavior = iota + DuplicateConnectionBehaviorKeepNew +) + // Server defines a websocket server, which passively listens for incoming connections on ws or wss protocol. // The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks. // @@ -117,6 +124,13 @@ type Server interface { // // Changes to the http request at runtime may lead to undefined behavior. SetCheckClientHandler(handler CheckClientHandler) + // SetDuplicateConnectionBehavior sets the behavior for how duplicate connections from the same charge point ID should be + // handled. The default is to not allow the new connection, as long as the charge point ID is already connected, but this can + // be overridden to instead close the current connection and allow the new connection. + // + // This has some important security considerations; it could allow malicious parties from forcefully disconnecting valid + // chargers from the server, especially if not running with aut + SetDuplicateConnectionBehavior(behavior DuplicateConnectionBehavior) // Addr gives the address on which the server is listening, useful if, for // example, the port is system-defined (set to 0). Addr() *net.TCPAddr @@ -130,22 +144,23 @@ type Server interface { // // Use the NewServer function to create a new server. type server struct { - connections map[string]*webSocket - httpServer *http.Server - messageHandler func(ws Channel, data []byte) error - chargePointIdResolver func(*http.Request) (string, error) - checkClientHandler CheckClientHandler - newClientHandler func(ws Channel) - disconnectedHandler func(ws Channel) - basicAuthHandler func(username string, password string) bool - tlsCertificatePath string - tlsCertificateKey string - timeoutConfig ServerTimeoutConfig - upgrader websocket.Upgrader - errC chan error - connMutex sync.RWMutex - addr *net.TCPAddr - httpHandler *mux.Router + connections map[string]*webSocket + httpServer *http.Server + messageHandler func(ws Channel, data []byte) error + chargePointIdResolver func(*http.Request) (string, error) + checkClientHandler CheckClientHandler + duplicateConnectionBehavior DuplicateConnectionBehavior + newClientHandler func(ws Channel) + disconnectedHandler func(ws Channel) + basicAuthHandler func(username string, password string) bool + tlsCertificatePath string + tlsCertificateKey string + timeoutConfig ServerTimeoutConfig + upgrader websocket.Upgrader + errC chan error + connMutex sync.RWMutex + addr *net.TCPAddr + httpHandler *mux.Router } // ServerOpt is a function that can be used to set options on a server during creation. @@ -183,10 +198,11 @@ func WithServerTLSConfig(certificatePath string, certificateKey string, tlsConfi func NewServer(opts ...ServerOpt) Server { router := mux.NewRouter() s := &server{ - httpServer: &http.Server{}, - timeoutConfig: NewServerTimeoutConfig(), - upgrader: websocket.Upgrader{Subprotocols: []string{}}, - httpHandler: router, + httpServer: &http.Server{}, + timeoutConfig: NewServerTimeoutConfig(), + upgrader: websocket.Upgrader{Subprotocols: []string{}}, + httpHandler: router, + duplicateConnectionBehavior: DuplicateConnectionBehaviorKeepCurrent, chargePointIdResolver: func(r *http.Request) (string, error) { url := r.URL return path.Base(url.Path), nil @@ -206,6 +222,10 @@ func (s *server) SetCheckClientHandler(handler CheckClientHandler) { s.checkClientHandler = handler } +func (s *server) SetDuplicateConnectionBehavior(behavior DuplicateConnectionBehavior) { + s.duplicateConnectionBehavior = behavior +} + func (s *server) SetNewClientHandler(handler ConnectedHandler) { s.newClientHandler = handler } @@ -425,15 +445,28 @@ out: } // Check whether client exists s.connMutex.Lock() - // There is already a connection with the same ID. Close the new one immediately with a PolicyViolation. - if _, exists := s.connections[id]; exists { - s.connMutex.Unlock() - s.error(fmt.Errorf("client %s already exists, closing duplicate client", id)) - _ = conn.WriteControl(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"), - time.Now().Add(s.timeoutConfig.WriteWait)) - _ = conn.Close() - return + switch s.duplicateConnectionBehavior { + case DuplicateConnectionBehaviorKeepNew: + // There is already a connection with the same ID. Close the old one, and allow the new connection. This has security + // implications, see the note on the SetDuplicateConnectionBehavior func. + if currentConn, exists := s.connections[id]; exists { + s.error(fmt.Errorf("client %s already exists, closing existing client", id)) + _ = currentConn.connection.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID has reconnected"), + time.Now().Add(s.timeoutConfig.WriteWait)) + _ = currentConn.connection.Close() + } + default: + // There is already a connection with the same ID. Close the new one immediately with a PolicyViolation. + if _, exists := s.connections[id]; exists { + s.connMutex.Unlock() + s.error(fmt.Errorf("client %s already exists, closing duplicate client", id)) + _ = conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"), + time.Now().Add(s.timeoutConfig.WriteWait)) + _ = conn.Close() + return + } } // Create web socket for client, state is automatically set to connected ws := newWebSocket( diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 7cde5c1c..96d8af06 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -499,7 +499,7 @@ func (s *WebSocketSuite) TestServerStartErrors() { s.NotNil(r) } -func (s *WebSocketSuite) TestClientDuplicateConnection() { +func (s *WebSocketSuite) TestClientDuplicateConnectionWithKeepCurrentConnectionBehavior() { s.server = newWebsocketServer(s.T(), nil) s.server.SetNewClientHandler(func(ws Channel) { }) @@ -531,9 +531,55 @@ func (s *WebSocketSuite) TestClientDuplicateConnection() { }) err = wsClient2.Start(u.String()) s.NoError(err) - // Expect connection to be closed immediately - _, ok := <-disconnectC - s.True(ok) + // Expect new connection to be closed immediately + select { + case _, ok := <-disconnectC: + s.True(ok, "expected new client to have been disconnected") + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for new client to disconnect") + } +} + +func (s *WebSocketSuite) TestClientDuplicateConnectionWithKeepNewConnectionBehavior() { + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { + }) + s.server.SetDuplicateConnectionBehavior(DuplicateConnectionBehaviorKeepNew) + // Start server + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + // Connect client 1 + disconnectC := make(chan struct{}) + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + return nil, nil + }) + s.client.SetDisconnectedHandler(func(err error) { + s.IsType(&websocket.CloseError{}, err) + var wsErr *websocket.CloseError + ok := errors.As(err, &wsErr) + s.True(ok) + s.Equal(websocket.ClosePolicyViolation, wsErr.Code) + s.Equal("a connection with this ID has reconnected", wsErr.Text) + s.client.SetDisconnectedHandler(nil) + disconnectC <- struct{}{} + }) + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + err := s.client.Start(u.String()) + s.NoError(err) + // Connect client 2 + wsClient2 := newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + return nil, nil + }) + err = wsClient2.Start(u.String()) + s.NoError(err) + // Expect current connection to be closed immediately + select { + case _, ok := <-disconnectC: + s.True(ok, "expected current client to have been disconnected") + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for current client to disconnect") + } } func (s *WebSocketSuite) TestServerStopConnection() {