@@ -99,10 +99,7 @@ pub enum ListenerEvent {
99
99
#[ cfg_attr( test, derive( PartialEq ) ) ]
100
100
pub enum SendError {
101
101
/// the chat service is no longer connected
102
- Disconnected {
103
- #[ cfg( test) ] // Useful for testing but otherwise unused
104
- reason : & ' static str ,
105
- } ,
102
+ Disconnected ( DisconnectedReason ) ,
106
103
/// an OS-level I/O error occurred
107
104
Io ( IoErrorKind ) ,
108
105
/// the message is larger than the configured limit
@@ -115,6 +112,20 @@ pub enum SendError {
115
112
InvalidRequest ( InvalidRequestError ) ,
116
113
}
117
114
115
+ #[ derive( Debug ) ]
116
+ #[ cfg_attr( test, derive( PartialEq ) ) ]
117
+ pub enum DisconnectedReason {
118
+ /// the server explicitly disconnected us because we connected elsewhere with the same credentials
119
+ ConnectedElsewhere ,
120
+ /// the server has disconnect us because the credentials we used to connect have become invalidated
121
+ ConnectionInvalidated ,
122
+ // the socket was closed, either by us or by the server, for some other reason.
123
+ SocketClosed {
124
+ #[ cfg( test) ] // Useful for testing but otherwise unused
125
+ reason : & ' static str ,
126
+ } ,
127
+ }
128
+
118
129
#[ derive( Debug ) ]
119
130
#[ cfg_attr( test, derive( PartialEq ) ) ]
120
131
pub enum InvalidRequestError {
@@ -378,10 +389,10 @@ impl Responder {
378
389
}
379
390
}
380
391
381
- Err ( SendError :: Disconnected {
392
+ Err ( SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
382
393
#[ cfg( test) ]
383
394
reason : "task exited without receiving response" ,
384
- } )
395
+ } ) )
385
396
}
386
397
}
387
398
@@ -668,6 +679,29 @@ async fn spawned_task_body<I: InnerConnection>(
668
679
task_result
669
680
}
670
681
682
+ /// Retrieves the final task error state and converts it to a `SendError`.
683
+ ///
684
+ /// This function waits for the task to finish if it hasn't already, and then
685
+ /// extracts the error state. It should only be called when we know the task is
686
+ /// ending or has already ended, such as when a send operation to the task has
687
+ /// failed, or it will hold the state lock for an unboundedly long time.
688
+ async fn get_task_finish_error ( state : & TokioMutex < TaskState > , _reason : & ' static str ) -> SendError {
689
+ // We're holding the lock here across an await point to prevent
690
+ // another method from also trying to wait for the task result and
691
+ // update state. Since the earlier send failed, the task must have
692
+ // dropped its receiver, and it doesn't do much after that so this
693
+ // should be a short wait.
694
+ let mut guard = state. lock ( ) . await ;
695
+ let finished_state = wait_for_task_to_finish ( & mut guard) . await . as_ref ( ) ;
696
+ match finished_state {
697
+ Ok ( _) => SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
698
+ #[ cfg( test) ]
699
+ reason : _reason,
700
+ } ) ,
701
+ Err ( err) => SendError :: from ( err) ,
702
+ }
703
+ }
704
+
671
705
async fn send_request (
672
706
state : & TokioMutex < TaskState > ,
673
707
request : PartialRequestProto ,
@@ -682,16 +716,16 @@ async fn send_request(
682
716
task : _,
683
717
} => request_tx. clone ( ) ,
684
718
TaskState :: SignaledToEnd ( _) => {
685
- return Err ( SendError :: Disconnected {
719
+ return Err ( SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
686
720
#[ cfg( test) ]
687
721
reason : "task was already signalled to end" ,
688
- } )
722
+ } ) )
689
723
}
690
724
TaskState :: Finished ( Ok ( _reason) ) => {
691
- return Err ( SendError :: Disconnected {
725
+ return Err ( SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
692
726
#[ cfg( test) ]
693
727
reason : "task already ended gracefully" ,
694
- } )
728
+ } ) )
695
729
}
696
730
TaskState :: Finished ( Err ( err) ) => return Err ( SendError :: from ( & * err) ) ,
697
731
}
@@ -708,36 +742,21 @@ async fn send_request(
708
742
. is_ok ( )
709
743
{
710
744
// The request was sent, now wait for the response to be sent back.
711
- let response =
712
- receiver
713
- . await
714
- . map_err ( |_: oneshot:: error:: RecvError | SendError :: Disconnected {
715
- #[ cfg( test) ]
716
- reason : "response channel sender was dropped" ,
717
- } ) ?;
718
- response. map_err ( SendError :: from)
719
- } else {
720
- // The request couldn't be sent to the task. We could give up now
721
- // and return SendError::Disconnected but that's not as useful as
722
- // something derived from the actual end status.
723
- let mut guard = state. lock ( ) . await ;
724
-
725
- // We're holding the lock here across an await point to prevent
726
- // another method from also trying to wait for the task result and
727
- // update state. Since the earlier send failed, the task must have
728
- // dropped its receiver, and it doesn't do much after that so this
729
- // should be a short wait.
730
- let finished_state = wait_for_task_to_finish ( & mut guard) . await . as_ref ( ) ;
731
-
732
- let send_error = finished_state. map_or_else ( SendError :: from, |_reason| {
733
- // The task exited successfully but our send still didn't go
734
- // through, so return an error.
735
- SendError :: Disconnected {
736
- #[ cfg( test) ]
737
- reason : "task ended gracefully before sending request" ,
745
+ match receiver. await {
746
+ Ok ( response) => response. map_err ( SendError :: from) ,
747
+ Err ( _) => {
748
+ // The sender was dropped without sending a response.
749
+ // This happens when the connection is closed while our request is in flight.
750
+ // Fetch the reason for the underlying connection failure, and return that as the
751
+ // reason for the request failure, to be most useful.
752
+ Err ( get_task_finish_error ( state, "response channel sender was dropped" ) . await )
738
753
}
739
- } ) ;
740
- Err ( send_error)
754
+ }
755
+ } else {
756
+ // We could not send the request at all, so the task must have ended, probably due to the connection
757
+ // closing. Fetch the reason for the underlying connection failure, and return that as the reason for
758
+ // the request failure.
759
+ Err ( get_task_finish_error ( state, "task ended gracefully before sending request" ) . await )
741
760
}
742
761
}
743
762
@@ -1081,23 +1100,37 @@ pub(super) fn decode_and_validate(data: &[u8]) -> Result<ChatMessageProto, ChatP
1081
1100
1082
1101
impl From < & TaskErrorState > for SendError {
1083
1102
fn from ( value : & TaskErrorState ) -> Self {
1084
- let _ = value;
1085
- SendError :: Disconnected {
1086
- #[ cfg( test) ]
1087
- reason : match value {
1088
- TaskErrorState :: SendFailed => "send failed" ,
1089
- TaskErrorState :: Panic ( _) => "chat task panicked" ,
1090
- TaskErrorState :: AbnormalServerClose { .. } => "server closed abnormally" ,
1091
- TaskErrorState :: ReceiveFailed => "receive failed" ,
1092
- TaskErrorState :: ServerIdleTooLong ( _) => "server idle too long" ,
1093
- TaskErrorState :: UnexpectedConnectionClose => "server closed unexpectedly" ,
1103
+ match value {
1104
+ TaskErrorState :: AbnormalServerClose { code, reason : _ } => match code {
1105
+ CloseCode :: Library ( CONNECTED_ELSEWHERE_CLOSE_CODE ) => {
1106
+ SendError :: Disconnected ( DisconnectedReason :: ConnectedElsewhere )
1107
+ }
1108
+ CloseCode :: Library ( CONNECTION_INVALIDATED_CLOSE_CODE ) => {
1109
+ SendError :: Disconnected ( DisconnectedReason :: ConnectionInvalidated )
1110
+ }
1111
+ _ => SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
1112
+ #[ cfg( test) ]
1113
+ reason : "server closed abnormally" ,
1114
+ } ) ,
1094
1115
} ,
1116
+ _ => SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
1117
+ #[ cfg( test) ]
1118
+ reason : match value {
1119
+ TaskErrorState :: SendFailed => "send failed" ,
1120
+ TaskErrorState :: Panic ( _) => "chat task panicked" ,
1121
+ // Already handled above, this is test-only code so fail-fast is desirable.
1122
+ TaskErrorState :: AbnormalServerClose { .. } => unreachable ! ( ) ,
1123
+ TaskErrorState :: ReceiveFailed => "receive failed" ,
1124
+ TaskErrorState :: ServerIdleTooLong ( _) => "server idle too long" ,
1125
+ TaskErrorState :: UnexpectedConnectionClose => "server closed unexpectedly" ,
1126
+ } ,
1127
+ } ) ,
1095
1128
}
1096
1129
}
1097
1130
}
1098
1131
1099
1132
impl From < TaskSendError > for SendError {
1100
- fn from ( value : TaskSendError ) -> SendError {
1133
+ fn from ( value : TaskSendError ) -> Self {
1101
1134
match value {
1102
1135
TaskSendError :: StreamSendFailed ( send_error) => send_error. into ( ) ,
1103
1136
TaskSendError :: InvalidResponse => SendError :: InvalidResponse ,
@@ -1135,10 +1168,12 @@ impl From<&TungsteniteSendError> for SendError {
1135
1168
fn from ( value : & TungsteniteSendError ) -> Self {
1136
1169
match value {
1137
1170
TungsteniteSendError :: Io ( io) => SendError :: Io ( io. kind ( ) ) ,
1138
- TungsteniteSendError :: ConnectionAlreadyClosed => SendError :: Disconnected {
1139
- #[ cfg( test) ]
1140
- reason : "task failure due to send failure" ,
1141
- } ,
1171
+ TungsteniteSendError :: ConnectionAlreadyClosed => {
1172
+ SendError :: Disconnected ( DisconnectedReason :: SocketClosed {
1173
+ #[ cfg( test) ]
1174
+ reason : "task failure due to send failure" ,
1175
+ } )
1176
+ }
1142
1177
TungsteniteSendError :: MessageTooLarge { size, max_size } => {
1143
1178
SendError :: MessageTooLarge {
1144
1179
size : * size,
@@ -1189,7 +1224,13 @@ impl From<TaskExitError> for crate::chat::SendError {
1189
1224
impl From < SendError > for super :: SendError {
1190
1225
fn from ( value : SendError ) -> Self {
1191
1226
match value {
1192
- SendError :: Disconnected { .. } => Self :: Disconnected ,
1227
+ SendError :: Disconnected ( DisconnectedReason :: SocketClosed { .. } ) => Self :: Disconnected ,
1228
+ SendError :: Disconnected ( DisconnectedReason :: ConnectedElsewhere ) => {
1229
+ Self :: ConnectedElsewhere
1230
+ }
1231
+ SendError :: Disconnected ( DisconnectedReason :: ConnectionInvalidated ) => {
1232
+ Self :: ConnectionInvalidated
1233
+ }
1193
1234
SendError :: Io ( error_kind) => {
1194
1235
Self :: WebSocket ( WebSocketServiceError :: Io ( error_kind. into ( ) ) )
1195
1236
}
@@ -2135,4 +2176,137 @@ mod test {
2135
2176
) ;
2136
2177
assert_matches ! ( listener_rx. try_recv( ) , Err ( TryRecvError :: Empty ) ) ;
2137
2178
}
2179
+
2180
+ #[ test_case(
2181
+ CloseCode :: from( CONNECTION_INVALIDATED_CLOSE_CODE ) , SendError :: Disconnected ( DisconnectedReason :: ConnectionInvalidated ) ;
2182
+ "CONNECTION_INVALIDATED_CLOSE_CODE results in ConnectionInvalidated"
2183
+ ) ]
2184
+ #[ test_case(
2185
+ CloseCode :: from( CONNECTED_ELSEWHERE_CLOSE_CODE ) , SendError :: Disconnected ( DisconnectedReason :: ConnectedElsewhere ) ;
2186
+ "CONNECTED_ELSEWHERE_CLOSE_CODE results in ConnectedElsewhere"
2187
+ ) ]
2188
+ #[ test_case(
2189
+ CloseCode :: Normal , SendError :: Disconnected ( DisconnectedReason :: SocketClosed { #[ cfg( test) ] reason: "server closed abnormally" } ) ;
2190
+ "Normal close results in Disconnected"
2191
+ ) ]
2192
+ #[ test_log:: test( tokio:: test( start_paused = true ) ) ]
2193
+ async fn send_after_ws_close_returns_proper_error (
2194
+ close_code : CloseCode ,
2195
+ expected_error : SendError ,
2196
+ ) {
2197
+ let ( chat, ( _inner_events, inner_responses) ) = fake:: new_chat ( Box :: new ( |_| ( ) ) ) ;
2198
+ assert ! ( chat. is_connected( ) . await ) ;
2199
+
2200
+ // Close the connection with the specific close code
2201
+ inner_responses
2202
+ . send (
2203
+ Outcome :: Finished ( Err ( NextEventError :: AbnormalServerClose {
2204
+ code : close_code,
2205
+ reason : format ! ( "close code: {close_code}" ) ,
2206
+ } ) )
2207
+ . into ( ) ,
2208
+ )
2209
+ . expect ( "can send close event" ) ;
2210
+
2211
+ let wait_for_disconnect = async {
2212
+ while chat. is_connected ( ) . await {
2213
+ tokio:: task:: yield_now ( ) . await ;
2214
+ }
2215
+ } ;
2216
+
2217
+ tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , wait_for_disconnect)
2218
+ . await
2219
+ . expect ( "chat disconnect does not take long" ) ;
2220
+
2221
+ // Try to send a request, which should fail with the expected error
2222
+ let send_result = chat
2223
+ . send ( Request {
2224
+ method : Method :: GET ,
2225
+ path : PathAndQuery :: from_static ( "/test" ) ,
2226
+ headers : HeaderMap :: default ( ) ,
2227
+ body : None ,
2228
+ } )
2229
+ . await ;
2230
+
2231
+ assert_eq ! ( send_result, Err ( expected_error) ) ;
2232
+ }
2233
+
2234
+ #[ test_case(
2235
+ CloseCode :: from( CONNECTION_INVALIDATED_CLOSE_CODE ) ,
2236
+ SendError :: Disconnected ( DisconnectedReason :: ConnectionInvalidated ) ;
2237
+ "CONNECTION_INVALIDATED_CLOSE_CODE should propagate correctly"
2238
+ ) ]
2239
+ #[ test_case(
2240
+ CloseCode :: from( CONNECTED_ELSEWHERE_CLOSE_CODE ) ,
2241
+ SendError :: Disconnected ( DisconnectedReason :: ConnectedElsewhere ) ;
2242
+ "CONNECTED_ELSEWHERE_CLOSE_CODE should propagate correctly"
2243
+ ) ]
2244
+ #[ test_case(
2245
+ CloseCode :: Normal ,
2246
+ SendError :: Disconnected ( DisconnectedReason :: SocketClosed { #[ cfg( test) ] reason: "server closed abnormally" } ) ;
2247
+ "Normal close results in Disconnected"
2248
+ ) ]
2249
+ #[ test_log:: test( tokio:: test( start_paused = true ) ) ]
2250
+ async fn connection_close_with_in_flight_request (
2251
+ close_code : CloseCode ,
2252
+ expected_error : SendError ,
2253
+ ) {
2254
+ // Create channels for listener events
2255
+ let ( listener_tx, mut listener_rx) = mpsc:: unbounded_channel ( ) ;
2256
+ let ( chat, ( mut chat_events, inner_responses) ) = fake:: new_chat ( Box :: new ( move |evt| {
2257
+ let _ = listener_tx. send ( evt) ;
2258
+ } ) ) ;
2259
+
2260
+ // Start a request but don't complete it right away
2261
+ let send_task = tokio:: spawn ( async move {
2262
+ chat. send ( Request {
2263
+ method : Method :: GET ,
2264
+ path : PathAndQuery :: from_static ( "/test" ) ,
2265
+ headers : HeaderMap :: default ( ) ,
2266
+ body : None ,
2267
+ } )
2268
+ . await
2269
+ } ) ;
2270
+
2271
+ // Take the outbound message and acknowledge it was sent
2272
+ if let Some ( fake:: OutgoingMessage ( _, meta) ) = chat_events. recv ( ) . await {
2273
+ inner_responses
2274
+ . send ( Outcome :: Continue ( MessageEvent :: SentMessage ( meta) ) . into ( ) )
2275
+ . expect ( "Should be able to send event" ) ;
2276
+ } else {
2277
+ panic ! ( "Failed to receive outbound message" ) ;
2278
+ }
2279
+
2280
+ // Now close the connection with the specified close code
2281
+ inner_responses
2282
+ . send (
2283
+ Outcome :: Finished ( Err ( NextEventError :: AbnormalServerClose {
2284
+ code : close_code,
2285
+ reason : format ! ( "Test close with code: {close_code}" ) ,
2286
+ } ) )
2287
+ . into ( ) ,
2288
+ )
2289
+ . expect ( "Should be able to send close event" ) ;
2290
+
2291
+ // Wait for the listener to receive the close event
2292
+ let mut received_close = false ;
2293
+ while let Some ( event) = listener_rx. recv ( ) . await {
2294
+ if let ListenerEvent :: Finished ( _) = event {
2295
+ received_close = true ;
2296
+ break ;
2297
+ }
2298
+ }
2299
+ assert ! (
2300
+ received_close,
2301
+ "Listener should have received a close event"
2302
+ ) ;
2303
+
2304
+ // Wait for the send task to complete and verify the error type
2305
+ let send_result = tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , send_task)
2306
+ . await
2307
+ . expect ( "send completes within timeout" )
2308
+ . expect ( "Task should not panic" ) ;
2309
+
2310
+ assert_eq ! ( send_result, Err ( expected_error) ) ;
2311
+ }
2138
2312
}
0 commit comments