Skip to content

Commit d557675

Browse files
Return specific SendError when request fails due to underlying transport failure
1 parent ec49709 commit d557675

File tree

2 files changed

+231
-55
lines changed

2 files changed

+231
-55
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ v0.71.0
1212

1313
- net: Connections to Signal services (and to Cloudflare's DNS-over-HTTPS server) will now require TLS v1.3, which they would already have been using.
1414

15+
- net: Futures returned by ChatConnection.send() will now return more specific errors on failure
16+
1517
- New SVR2 enclaves for staging and production.

rust/net/src/chat/ws2.rs

Lines changed: 229 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,7 @@ pub enum ListenerEvent {
9999
#[cfg_attr(test, derive(PartialEq))]
100100
pub enum SendError {
101101
/// the chat service is no longer connected
102-
Disconnected {
103-
#[cfg(test)] // Useful for testing but otherwise unused
104-
reason: &'static str,
105-
},
102+
Disconnected(DisconnectedReason),
106103
/// an OS-level I/O error occurred
107104
Io(IoErrorKind),
108105
/// the message is larger than the configured limit
@@ -115,6 +112,20 @@ pub enum SendError {
115112
InvalidRequest(InvalidRequestError),
116113
}
117114

115+
#[derive(Debug)]
116+
#[cfg_attr(test, derive(PartialEq))]
117+
pub enum DisconnectedReason {
118+
/// the server explicitly disconnected us because we connected elsewhere with the same credentials
119+
ConnectedElsewhere,
120+
/// the server has disconnect us because the credentials we used to connect have become invalidated
121+
ConnectionInvalidated,
122+
// the socket was closed, either by us or by the server, for some other reason.
123+
SocketClosed {
124+
#[cfg(test)] // Useful for testing but otherwise unused
125+
reason: &'static str,
126+
},
127+
}
128+
118129
#[derive(Debug)]
119130
#[cfg_attr(test, derive(PartialEq))]
120131
pub enum InvalidRequestError {
@@ -378,10 +389,10 @@ impl Responder {
378389
}
379390
}
380391

381-
Err(SendError::Disconnected {
392+
Err(SendError::Disconnected(DisconnectedReason::SocketClosed {
382393
#[cfg(test)]
383394
reason: "task exited without receiving response",
384-
})
395+
}))
385396
}
386397
}
387398

@@ -668,6 +679,29 @@ async fn spawned_task_body<I: InnerConnection>(
668679
task_result
669680
}
670681

682+
/// Retrieves the final task error state and converts it to a `SendError`.
683+
///
684+
/// This function waits for the task to finish if it hasn't already, and then
685+
/// extracts the error state. It should only be called when we know the task is
686+
/// ending or has already ended, such as when a send operation to the task has
687+
/// failed, or it will hold the state lock for an unboundedly long time.
688+
async fn get_task_finish_error(state: &TokioMutex<TaskState>, _reason: &'static str) -> SendError {
689+
// We're holding the lock here across an await point to prevent
690+
// another method from also trying to wait for the task result and
691+
// update state. Since the earlier send failed, the task must have
692+
// dropped its receiver, and it doesn't do much after that so this
693+
// should be a short wait.
694+
let mut guard = state.lock().await;
695+
let finished_state = wait_for_task_to_finish(&mut guard).await.as_ref();
696+
match finished_state {
697+
Ok(_) => SendError::Disconnected(DisconnectedReason::SocketClosed {
698+
#[cfg(test)]
699+
reason: _reason,
700+
}),
701+
Err(err) => SendError::from(err),
702+
}
703+
}
704+
671705
async fn send_request(
672706
state: &TokioMutex<TaskState>,
673707
request: PartialRequestProto,
@@ -682,16 +716,16 @@ async fn send_request(
682716
task: _,
683717
} => request_tx.clone(),
684718
TaskState::SignaledToEnd(_) => {
685-
return Err(SendError::Disconnected {
719+
return Err(SendError::Disconnected(DisconnectedReason::SocketClosed {
686720
#[cfg(test)]
687721
reason: "task was already signalled to end",
688-
})
722+
}))
689723
}
690724
TaskState::Finished(Ok(_reason)) => {
691-
return Err(SendError::Disconnected {
725+
return Err(SendError::Disconnected(DisconnectedReason::SocketClosed {
692726
#[cfg(test)]
693727
reason: "task already ended gracefully",
694-
})
728+
}))
695729
}
696730
TaskState::Finished(Err(err)) => return Err(SendError::from(&*err)),
697731
}
@@ -708,36 +742,21 @@ async fn send_request(
708742
.is_ok()
709743
{
710744
// The request was sent, now wait for the response to be sent back.
711-
let response =
712-
receiver
713-
.await
714-
.map_err(|_: oneshot::error::RecvError| SendError::Disconnected {
715-
#[cfg(test)]
716-
reason: "response channel sender was dropped",
717-
})?;
718-
response.map_err(SendError::from)
719-
} else {
720-
// The request couldn't be sent to the task. We could give up now
721-
// and return SendError::Disconnected but that's not as useful as
722-
// something derived from the actual end status.
723-
let mut guard = state.lock().await;
724-
725-
// We're holding the lock here across an await point to prevent
726-
// another method from also trying to wait for the task result and
727-
// update state. Since the earlier send failed, the task must have
728-
// dropped its receiver, and it doesn't do much after that so this
729-
// should be a short wait.
730-
let finished_state = wait_for_task_to_finish(&mut guard).await.as_ref();
731-
732-
let send_error = finished_state.map_or_else(SendError::from, |_reason| {
733-
// The task exited successfully but our send still didn't go
734-
// through, so return an error.
735-
SendError::Disconnected {
736-
#[cfg(test)]
737-
reason: "task ended gracefully before sending request",
745+
match receiver.await {
746+
Ok(response) => response.map_err(SendError::from),
747+
Err(_) => {
748+
// The sender was dropped without sending a response.
749+
// This happens when the connection is closed while our request is in flight.
750+
// Fetch the reason for the underlying connection failure, and return that as the
751+
// reason for the request failure, to be most useful.
752+
Err(get_task_finish_error(state, "response channel sender was dropped").await)
738753
}
739-
});
740-
Err(send_error)
754+
}
755+
} else {
756+
// We could not send the request at all, so the task must have ended, probably due to the connection
757+
// closing. Fetch the reason for the underlying connection failure, and return that as the reason for
758+
// the request failure.
759+
Err(get_task_finish_error(state, "task ended gracefully before sending request").await)
741760
}
742761
}
743762

@@ -1081,23 +1100,37 @@ pub(super) fn decode_and_validate(data: &[u8]) -> Result<ChatMessageProto, ChatP
10811100

10821101
impl From<&TaskErrorState> for SendError {
10831102
fn from(value: &TaskErrorState) -> Self {
1084-
let _ = value;
1085-
SendError::Disconnected {
1086-
#[cfg(test)]
1087-
reason: match value {
1088-
TaskErrorState::SendFailed => "send failed",
1089-
TaskErrorState::Panic(_) => "chat task panicked",
1090-
TaskErrorState::AbnormalServerClose { .. } => "server closed abnormally",
1091-
TaskErrorState::ReceiveFailed => "receive failed",
1092-
TaskErrorState::ServerIdleTooLong(_) => "server idle too long",
1093-
TaskErrorState::UnexpectedConnectionClose => "server closed unexpectedly",
1103+
match value {
1104+
TaskErrorState::AbnormalServerClose { code, reason: _ } => match code {
1105+
CloseCode::Library(CONNECTED_ELSEWHERE_CLOSE_CODE) => {
1106+
SendError::Disconnected(DisconnectedReason::ConnectedElsewhere)
1107+
}
1108+
CloseCode::Library(CONNECTION_INVALIDATED_CLOSE_CODE) => {
1109+
SendError::Disconnected(DisconnectedReason::ConnectionInvalidated)
1110+
}
1111+
_ => SendError::Disconnected(DisconnectedReason::SocketClosed {
1112+
#[cfg(test)]
1113+
reason: "server closed abnormally",
1114+
}),
10941115
},
1116+
_ => SendError::Disconnected(DisconnectedReason::SocketClosed {
1117+
#[cfg(test)]
1118+
reason: match value {
1119+
TaskErrorState::SendFailed => "send failed",
1120+
TaskErrorState::Panic(_) => "chat task panicked",
1121+
// Already handled above, this is test-only code so fail-fast is desirable.
1122+
TaskErrorState::AbnormalServerClose { .. } => unreachable!(),
1123+
TaskErrorState::ReceiveFailed => "receive failed",
1124+
TaskErrorState::ServerIdleTooLong(_) => "server idle too long",
1125+
TaskErrorState::UnexpectedConnectionClose => "server closed unexpectedly",
1126+
},
1127+
}),
10951128
}
10961129
}
10971130
}
10981131

10991132
impl From<TaskSendError> for SendError {
1100-
fn from(value: TaskSendError) -> SendError {
1133+
fn from(value: TaskSendError) -> Self {
11011134
match value {
11021135
TaskSendError::StreamSendFailed(send_error) => send_error.into(),
11031136
TaskSendError::InvalidResponse => SendError::InvalidResponse,
@@ -1135,10 +1168,12 @@ impl From<&TungsteniteSendError> for SendError {
11351168
fn from(value: &TungsteniteSendError) -> Self {
11361169
match value {
11371170
TungsteniteSendError::Io(io) => SendError::Io(io.kind()),
1138-
TungsteniteSendError::ConnectionAlreadyClosed => SendError::Disconnected {
1139-
#[cfg(test)]
1140-
reason: "task failure due to send failure",
1141-
},
1171+
TungsteniteSendError::ConnectionAlreadyClosed => {
1172+
SendError::Disconnected(DisconnectedReason::SocketClosed {
1173+
#[cfg(test)]
1174+
reason: "task failure due to send failure",
1175+
})
1176+
}
11421177
TungsteniteSendError::MessageTooLarge { size, max_size } => {
11431178
SendError::MessageTooLarge {
11441179
size: *size,
@@ -1189,7 +1224,13 @@ impl From<TaskExitError> for crate::chat::SendError {
11891224
impl From<SendError> for super::SendError {
11901225
fn from(value: SendError) -> Self {
11911226
match value {
1192-
SendError::Disconnected { .. } => Self::Disconnected,
1227+
SendError::Disconnected(DisconnectedReason::SocketClosed { .. }) => Self::Disconnected,
1228+
SendError::Disconnected(DisconnectedReason::ConnectedElsewhere) => {
1229+
Self::ConnectedElsewhere
1230+
}
1231+
SendError::Disconnected(DisconnectedReason::ConnectionInvalidated) => {
1232+
Self::ConnectionInvalidated
1233+
}
11931234
SendError::Io(error_kind) => {
11941235
Self::WebSocket(WebSocketServiceError::Io(error_kind.into()))
11951236
}
@@ -2135,4 +2176,137 @@ mod test {
21352176
);
21362177
assert_matches!(listener_rx.try_recv(), Err(TryRecvError::Empty));
21372178
}
2179+
2180+
#[test_case(
2181+
CloseCode::from(CONNECTION_INVALIDATED_CLOSE_CODE), SendError::Disconnected(DisconnectedReason::ConnectionInvalidated);
2182+
"CONNECTION_INVALIDATED_CLOSE_CODE results in ConnectionInvalidated"
2183+
)]
2184+
#[test_case(
2185+
CloseCode::from(CONNECTED_ELSEWHERE_CLOSE_CODE), SendError::Disconnected(DisconnectedReason::ConnectedElsewhere);
2186+
"CONNECTED_ELSEWHERE_CLOSE_CODE results in ConnectedElsewhere"
2187+
)]
2188+
#[test_case(
2189+
CloseCode::Normal, SendError::Disconnected(DisconnectedReason::SocketClosed { #[cfg(test)] reason: "server closed abnormally" });
2190+
"Normal close results in Disconnected"
2191+
)]
2192+
#[test_log::test(tokio::test(start_paused = true))]
2193+
async fn send_after_ws_close_returns_proper_error(
2194+
close_code: CloseCode,
2195+
expected_error: SendError,
2196+
) {
2197+
let (chat, (_inner_events, inner_responses)) = fake::new_chat(Box::new(|_| ()));
2198+
assert!(chat.is_connected().await);
2199+
2200+
// Close the connection with the specific close code
2201+
inner_responses
2202+
.send(
2203+
Outcome::Finished(Err(NextEventError::AbnormalServerClose {
2204+
code: close_code,
2205+
reason: format!("close code: {close_code}"),
2206+
}))
2207+
.into(),
2208+
)
2209+
.expect("can send close event");
2210+
2211+
let wait_for_disconnect = async {
2212+
while chat.is_connected().await {
2213+
tokio::task::yield_now().await;
2214+
}
2215+
};
2216+
2217+
tokio::time::timeout(Duration::from_secs(1), wait_for_disconnect)
2218+
.await
2219+
.expect("chat disconnect does not take long");
2220+
2221+
// Try to send a request, which should fail with the expected error
2222+
let send_result = chat
2223+
.send(Request {
2224+
method: Method::GET,
2225+
path: PathAndQuery::from_static("/test"),
2226+
headers: HeaderMap::default(),
2227+
body: None,
2228+
})
2229+
.await;
2230+
2231+
assert_eq!(send_result, Err(expected_error));
2232+
}
2233+
2234+
#[test_case(
2235+
CloseCode::from(CONNECTION_INVALIDATED_CLOSE_CODE),
2236+
SendError::Disconnected(DisconnectedReason::ConnectionInvalidated);
2237+
"CONNECTION_INVALIDATED_CLOSE_CODE should propagate correctly"
2238+
)]
2239+
#[test_case(
2240+
CloseCode::from(CONNECTED_ELSEWHERE_CLOSE_CODE),
2241+
SendError::Disconnected(DisconnectedReason::ConnectedElsewhere);
2242+
"CONNECTED_ELSEWHERE_CLOSE_CODE should propagate correctly"
2243+
)]
2244+
#[test_case(
2245+
CloseCode::Normal,
2246+
SendError::Disconnected(DisconnectedReason::SocketClosed { #[cfg(test)] reason: "server closed abnormally" });
2247+
"Normal close results in Disconnected"
2248+
)]
2249+
#[test_log::test(tokio::test(start_paused = true))]
2250+
async fn connection_close_with_in_flight_request(
2251+
close_code: CloseCode,
2252+
expected_error: SendError,
2253+
) {
2254+
// Create channels for listener events
2255+
let (listener_tx, mut listener_rx) = mpsc::unbounded_channel();
2256+
let (chat, (mut chat_events, inner_responses)) = fake::new_chat(Box::new(move |evt| {
2257+
let _ = listener_tx.send(evt);
2258+
}));
2259+
2260+
// Start a request but don't complete it right away
2261+
let send_task = tokio::spawn(async move {
2262+
chat.send(Request {
2263+
method: Method::GET,
2264+
path: PathAndQuery::from_static("/test"),
2265+
headers: HeaderMap::default(),
2266+
body: None,
2267+
})
2268+
.await
2269+
});
2270+
2271+
// Take the outbound message and acknowledge it was sent
2272+
if let Some(fake::OutgoingMessage(_, meta)) = chat_events.recv().await {
2273+
inner_responses
2274+
.send(Outcome::Continue(MessageEvent::SentMessage(meta)).into())
2275+
.expect("Should be able to send event");
2276+
} else {
2277+
panic!("Failed to receive outbound message");
2278+
}
2279+
2280+
// Now close the connection with the specified close code
2281+
inner_responses
2282+
.send(
2283+
Outcome::Finished(Err(NextEventError::AbnormalServerClose {
2284+
code: close_code,
2285+
reason: format!("Test close with code: {close_code}"),
2286+
}))
2287+
.into(),
2288+
)
2289+
.expect("Should be able to send close event");
2290+
2291+
// Wait for the listener to receive the close event
2292+
let mut received_close = false;
2293+
while let Some(event) = listener_rx.recv().await {
2294+
if let ListenerEvent::Finished(_) = event {
2295+
received_close = true;
2296+
break;
2297+
}
2298+
}
2299+
assert!(
2300+
received_close,
2301+
"Listener should have received a close event"
2302+
);
2303+
2304+
// Wait for the send task to complete and verify the error type
2305+
let send_result = tokio::time::timeout(Duration::from_secs(1), send_task)
2306+
.await
2307+
.expect("send completes within timeout")
2308+
.expect("Task should not panic");
2309+
2310+
assert_eq!(send_result, Err(expected_error));
2311+
}
21382312
}

0 commit comments

Comments
 (0)