Skip to content

Commit 6944194

Browse files
committed
allow overriding duplicate connection handling behavior
1 parent 607009a commit 6944194

File tree

2 files changed

+113
-33
lines changed

2 files changed

+113
-33
lines changed

ws/server.go

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ import (
1919

2020
type CheckClientHandler func(id string, r *http.Request) bool
2121

22+
type DuplicateConnectionBehavior uint8
23+
24+
const (
25+
DuplicateConnectionBehaviorKeepCurrent DuplicateConnectionBehavior = iota
26+
DuplicateConnectionBehaviorKeepNew
27+
)
28+
2229
// Server defines a websocket server, which passively listens for incoming connections on ws or wss protocol.
2330
// The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks.
2431
//
@@ -117,6 +124,13 @@ type Server interface {
117124
//
118125
// Changes to the http request at runtime may lead to undefined behavior.
119126
SetCheckClientHandler(handler CheckClientHandler)
127+
// SetDuplicateConnectionBehavior sets the behavior for how duplicate connections from the same charge point ID should be
128+
// handled. The default is to not allow the new connection, as long as the charge point ID is already connected, but this can
129+
// be overridden to instead close the current connection and allow the new connection.
130+
//
131+
// This has some important security considerations; it could allow malicious parties from forcefully disconnecting valid
132+
// chargers from the server, especially if not running with aut
133+
SetDuplicateConnectionBehavior(behavior DuplicateConnectionBehavior)
120134
// Addr gives the address on which the server is listening, useful if, for
121135
// example, the port is system-defined (set to 0).
122136
Addr() *net.TCPAddr
@@ -130,22 +144,23 @@ type Server interface {
130144
//
131145
// Use the NewServer function to create a new server.
132146
type server struct {
133-
connections map[string]*webSocket
134-
httpServer *http.Server
135-
messageHandler func(ws Channel, data []byte) error
136-
chargePointIdResolver func(*http.Request) (string, error)
137-
checkClientHandler CheckClientHandler
138-
newClientHandler func(ws Channel)
139-
disconnectedHandler func(ws Channel)
140-
basicAuthHandler func(username string, password string) bool
141-
tlsCertificatePath string
142-
tlsCertificateKey string
143-
timeoutConfig ServerTimeoutConfig
144-
upgrader websocket.Upgrader
145-
errC chan error
146-
connMutex sync.RWMutex
147-
addr *net.TCPAddr
148-
httpHandler *mux.Router
147+
connections map[string]*webSocket
148+
httpServer *http.Server
149+
messageHandler func(ws Channel, data []byte) error
150+
chargePointIdResolver func(*http.Request) (string, error)
151+
checkClientHandler CheckClientHandler
152+
duplicateConnectionBehavior DuplicateConnectionBehavior
153+
newClientHandler func(ws Channel)
154+
disconnectedHandler func(ws Channel)
155+
basicAuthHandler func(username string, password string) bool
156+
tlsCertificatePath string
157+
tlsCertificateKey string
158+
timeoutConfig ServerTimeoutConfig
159+
upgrader websocket.Upgrader
160+
errC chan error
161+
connMutex sync.RWMutex
162+
addr *net.TCPAddr
163+
httpHandler *mux.Router
149164
}
150165

151166
// ServerOpt is a function that can be used to set options on a server during creation.
@@ -183,10 +198,11 @@ func WithServerTLSConfig(certificatePath string, certificateKey string, tlsConfi
183198
func NewServer(opts ...ServerOpt) Server {
184199
router := mux.NewRouter()
185200
s := &server{
186-
httpServer: &http.Server{},
187-
timeoutConfig: NewServerTimeoutConfig(),
188-
upgrader: websocket.Upgrader{Subprotocols: []string{}},
189-
httpHandler: router,
201+
httpServer: &http.Server{},
202+
timeoutConfig: NewServerTimeoutConfig(),
203+
upgrader: websocket.Upgrader{Subprotocols: []string{}},
204+
httpHandler: router,
205+
duplicateConnectionBehavior: DuplicateConnectionBehaviorKeepCurrent,
190206
chargePointIdResolver: func(r *http.Request) (string, error) {
191207
url := r.URL
192208
return path.Base(url.Path), nil
@@ -206,6 +222,10 @@ func (s *server) SetCheckClientHandler(handler CheckClientHandler) {
206222
s.checkClientHandler = handler
207223
}
208224

225+
func (s *server) SetDuplicateConnectionBehavior(behavior DuplicateConnectionBehavior) {
226+
s.duplicateConnectionBehavior = behavior
227+
}
228+
209229
func (s *server) SetNewClientHandler(handler ConnectedHandler) {
210230
s.newClientHandler = handler
211231
}
@@ -425,15 +445,29 @@ out:
425445
}
426446
// Check whether client exists
427447
s.connMutex.Lock()
428-
// There is already a connection with the same ID. Close the new one immediately with a PolicyViolation.
429-
if _, exists := s.connections[id]; exists {
430-
s.connMutex.Unlock()
431-
s.error(fmt.Errorf("client %s already exists, closing duplicate client", id))
432-
_ = conn.WriteControl(websocket.CloseMessage,
433-
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"),
434-
time.Now().Add(s.timeoutConfig.WriteWait))
435-
_ = conn.Close()
436-
return
448+
switch s.duplicateConnectionBehavior {
449+
case DuplicateConnectionBehaviorKeepNew:
450+
// There is already a connection with the same ID. Close the old one, and allow the new connection. This has security
451+
// implications, see the note on the SetDuplicateConnectionBehavior func.
452+
if currentConn, exists := s.connections[id]; exists {
453+
s.connMutex.Unlock()
454+
s.error(fmt.Errorf("client %s already exists, closing existing client", id))
455+
_ = currentConn.connection.WriteControl(websocket.CloseMessage,
456+
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID has reconnected"),
457+
time.Now().Add(s.timeoutConfig.WriteWait))
458+
_ = currentConn.connection.Close()
459+
}
460+
default:
461+
// There is already a connection with the same ID. Close the new one immediately with a PolicyViolation.
462+
if _, exists := s.connections[id]; exists {
463+
s.connMutex.Unlock()
464+
s.error(fmt.Errorf("client %s already exists, closing duplicate client", id))
465+
_ = conn.WriteControl(websocket.CloseMessage,
466+
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"),
467+
time.Now().Add(s.timeoutConfig.WriteWait))
468+
_ = conn.Close()
469+
return
470+
}
437471
}
438472
// Create web socket for client, state is automatically set to connected
439473
ws := newWebSocket(

ws/websocket_test.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ func (s *WebSocketSuite) TestServerStartErrors() {
499499
s.NotNil(r)
500500
}
501501

502-
func (s *WebSocketSuite) TestClientDuplicateConnection() {
502+
func (s *WebSocketSuite) TestClientDuplicateConnectionWithKeepCurrentConnectionBehavior() {
503503
s.server = newWebsocketServer(s.T(), nil)
504504
s.server.SetNewClientHandler(func(ws Channel) {
505505
})
@@ -531,9 +531,55 @@ func (s *WebSocketSuite) TestClientDuplicateConnection() {
531531
})
532532
err = wsClient2.Start(u.String())
533533
s.NoError(err)
534-
// Expect connection to be closed immediately
535-
_, ok := <-disconnectC
536-
s.True(ok)
534+
// Expect new connection to be closed immediately
535+
select {
536+
case _, ok := <-disconnectC:
537+
s.True(ok, "expected new client to have been disconnected")
538+
case <-time.After(1 * time.Second):
539+
s.Fail("timeout waiting for new client to disconnect")
540+
}
541+
}
542+
543+
func (s *WebSocketSuite) TestClientDuplicateConnectionWithKeepNewConnectionBehavior() {
544+
s.server = newWebsocketServer(s.T(), nil)
545+
s.server.SetNewClientHandler(func(ws Channel) {
546+
})
547+
s.server.SetDuplicateConnectionBehavior(DuplicateConnectionBehaviorKeepNew)
548+
// Start server
549+
go s.server.Start(serverPort, serverPath)
550+
time.Sleep(100 * time.Millisecond)
551+
// Connect client 1
552+
disconnectC := make(chan struct{})
553+
s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) {
554+
return nil, nil
555+
})
556+
s.client.SetDisconnectedHandler(func(err error) {
557+
s.IsType(&websocket.CloseError{}, err)
558+
var wsErr *websocket.CloseError
559+
ok := errors.As(err, &wsErr)
560+
s.True(ok)
561+
s.Equal(websocket.ClosePolicyViolation, wsErr.Code)
562+
s.Equal("a connection with this ID has reconnected", wsErr.Text)
563+
s.client.SetDisconnectedHandler(nil)
564+
disconnectC <- struct{}{}
565+
})
566+
host := fmt.Sprintf("localhost:%v", serverPort)
567+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
568+
err := s.client.Start(u.String())
569+
s.NoError(err)
570+
// Connect client 2
571+
wsClient2 := newWebsocketClient(s.T(), func(data []byte) ([]byte, error) {
572+
return nil, nil
573+
})
574+
err = wsClient2.Start(u.String())
575+
s.NoError(err)
576+
// Expect current connection to be closed immediately
577+
select {
578+
case _, ok := <-disconnectC:
579+
s.True(ok, "expected current client to have been disconnected")
580+
case <-time.After(1 * time.Second):
581+
s.Fail("timeout waiting for current client to disconnect")
582+
}
537583
}
538584

539585
func (s *WebSocketSuite) TestServerStopConnection() {

0 commit comments

Comments
 (0)