Skip to content

Commit e8cc85a

Browse files
committed
fix: harden listen loops
1 parent 64ad596 commit e8cc85a

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

irpc-iroh/src/lib.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use irpc::{
1818
};
1919
use n0_future::{future::Boxed as BoxFuture, TryFutureExt};
2020
use serde::de::DeserializeOwned;
21-
use tracing::{trace, trace_span, warn, Instrument};
21+
use tracing::{debug, error_span, trace, trace_span, warn, Instrument};
2222

2323
/// Returns a client that connects to a irpc service using an [`iroh::Endpoint`].
2424
pub fn client<S: irpc::Service>(
@@ -207,6 +207,10 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(
207207
connection: Connection,
208208
handler: Handler<R>,
209209
) -> io::Result<()> {
210+
if let Ok(remote) = connection.remote_node_id() {
211+
tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short()));
212+
}
213+
debug!("connection accepted");
210214
loop {
211215
let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
212216
return Ok(());
@@ -270,19 +274,32 @@ pub async fn read_request_raw<R: DeserializeOwned + 'static>(
270274
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
271275
let mut request_id = 0u64;
272276
let mut tasks = n0_future::task::JoinSet::new();
273-
while let Some(incoming) = endpoint.accept().await {
277+
loop {
278+
let incoming = tokio::select! {
279+
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
280+
res.expect("irpc connection task panicked");
281+
continue;
282+
}
283+
incoming = endpoint.accept() => {
284+
match incoming {
285+
None => break,
286+
Some(incoming) => incoming
287+
}
288+
}
289+
};
274290
let handler = handler.clone();
275291
let fut = async move {
276-
let connection = match incoming.await {
277-
Ok(connection) => connection,
292+
match incoming.await {
293+
Ok(connection) => match handle_connection(connection, handler).await {
294+
Err(err) => warn!("connection closed with error: {err:?}"),
295+
Ok(()) => debug!("connection closed"),
296+
},
278297
Err(cause) => {
279-
warn!("failed to accept connection {cause:?}");
280-
return io::Result::Ok(());
298+
warn!("failed to accept connection: {cause:?}");
281299
}
282300
};
283-
handle_connection(connection, handler).await
284301
};
285-
let span = trace_span!("rpc", id = request_id);
302+
let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
286303
tasks.spawn(fut.instrument(span));
287304
request_id += 1;
288305
}

src/lib.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,7 +1534,7 @@ pub mod rpc {
15341534
use quinn::ConnectionError;
15351535
use serde::de::DeserializeOwned;
15361536
use smallvec::SmallVec;
1537-
use tracing::{trace, trace_span, warn, Instrument};
1537+
use tracing::{debug, error_span, trace, warn, Instrument};
15381538

15391539
use crate::{
15401540
channel::{
@@ -2054,19 +2054,32 @@ pub mod rpc {
20542054
) {
20552055
let mut request_id = 0u64;
20562056
let mut tasks = JoinSet::new();
2057-
while let Some(incoming) = endpoint.accept().await {
2057+
loop {
2058+
let incoming = tokio::select! {
2059+
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
2060+
res.expect("irpc connection task panicked");
2061+
continue;
2062+
}
2063+
incoming = endpoint.accept() => {
2064+
match incoming {
2065+
None => break,
2066+
Some(incoming) => incoming
2067+
}
2068+
}
2069+
};
20582070
let handler = handler.clone();
20592071
let fut = async move {
2060-
let connection = match incoming.await {
2061-
Ok(connection) => connection,
2072+
match incoming.await {
2073+
Ok(connection) => match handle_connection(connection, handler).await {
2074+
Err(err) => warn!("connection closed with error: {err:?}"),
2075+
Ok(()) => debug!("connection closed"),
2076+
},
20622077
Err(cause) => {
2063-
warn!("failed to accept connection {cause:?}");
2064-
return io::Result::Ok(());
2078+
warn!("failed to accept connection: {cause:?}");
20652079
}
20662080
};
2067-
handle_connection(connection, handler).await
20682081
};
2069-
let span = trace_span!("rpc", id = request_id);
2082+
let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
20702083
tasks.spawn(fut.instrument(span));
20712084
request_id += 1;
20722085
}
@@ -2077,6 +2090,11 @@ pub mod rpc {
20772090
connection: quinn::Connection,
20782091
handler: Handler<R>,
20792092
) -> io::Result<()> {
2093+
tracing::Span::current().record(
2094+
"remote",
2095+
tracing::field::display(connection.remote_address()),
2096+
);
2097+
debug!("connection accepted");
20802098
loop {
20812099
let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
20822100
return Ok(());

0 commit comments

Comments
 (0)