Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 132 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub mod channel {

/// Oneshot channel, similar to tokio's oneshot channel
pub mod oneshot {
use std::{fmt::Debug, future::Future, io, pin::Pin, task};
use std::{fmt::Debug, future::Future, pin::Pin, task};

use n0_future::future::Boxed as BoxFuture;

Expand All @@ -162,7 +162,7 @@ pub mod channel {
/// overhead is negligible. However, boxing can also be used for local communication,
/// e.g. when applying a transform or filter to the message before sending it.
pub type BoxedSender<T> =
Box<dyn FnOnce(T) -> BoxFuture<io::Result<()>> + Send + Sync + 'static>;
Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;

/// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
///
Expand All @@ -172,7 +172,9 @@ pub mod channel {
/// Remote receivers are always boxed, since for remote communication the boxing
/// overhead is negligible. However, boxing can also be used for local communication,
/// e.g. when applying a transform or filter to the message before receiving it.
pub trait DynSender<T>: Future<Output = io::Result<()>> + Send + Sync + 'static {
pub trait DynSender<T>:
Future<Output = Result<(), SendError>> + Send + Sync + 'static
{
fn is_rpc(&self) -> bool;
}

Expand All @@ -181,7 +183,7 @@ pub mod channel {
/// Remote receivers are always boxed, since for remote communication the boxing
/// overhead is negligible. However, boxing can also be used for local communication,
/// e.g. when applying a transform or filter to the message before receiving it.
pub type BoxedReceiver<T> = BoxFuture<io::Result<T>>;
pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;

/// A oneshot sender.
///
Expand Down Expand Up @@ -230,7 +232,7 @@ pub mod channel {
pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
match self {
Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed),
Sender::Boxed(f) => f(value).await.map_err(SendError::from),
Sender::Boxed(f) => f(value).await,
}
}
}
Expand Down Expand Up @@ -266,7 +268,7 @@ pub mod channel {
fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
match self.get_mut() {
Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed),
Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io),
Self::Boxed(rx) => Pin::new(rx).poll(cx),
}
}
}
Expand All @@ -293,7 +295,7 @@ pub mod channel {
impl<T, F, Fut> From<F> for Receiver<T>
where
F: FnOnce() -> Fut,
Fut: Future<Output = io::Result<T>> + Send + 'static,
Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
{
fn from(f: F) -> Self {
Self::Boxed(Box::pin(f()))
Expand All @@ -317,7 +319,7 @@ pub mod channel {
///
/// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc.
pub mod mpsc {
use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc};
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};

use super::{RecvError, SendError};
use crate::RpcMessage;
Expand Down Expand Up @@ -398,7 +400,7 @@ pub mod channel {
fn send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>>;
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;

/// Try to send a message, returning as fast as possible if sending
/// is not currently possible.
Expand All @@ -408,7 +410,7 @@ pub mod channel {
fn try_send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>>;
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;

/// Await the sender close
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
Expand Down Expand Up @@ -458,7 +460,7 @@ pub mod channel {
Sender::Tokio(tx) => {
tx.send(value).await.map_err(|_| SendError::ReceiverClosed)
}
Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from),
Sender::Boxed(sink) => sink.send(value).await,
}
}

Expand Down Expand Up @@ -492,7 +494,7 @@ pub mod channel {
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
},
Sender::Boxed(sink) => sink.try_send(value).await.map_err(SendError::from),
Sender::Boxed(sink) => sink.try_send(value).await,
}
}
}
Expand Down Expand Up @@ -593,6 +595,9 @@ pub mod channel {
/// for local communication.
#[error("receiver closed")]
ReceiverClosed,
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// The underlying io error. This can occur for remote communication,
/// due to a network error or serialization error.
#[error("io error: {0}")]
Expand All @@ -603,6 +608,7 @@ pub mod channel {
fn from(e: SendError) -> Self {
match e {
SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
SendError::Io(e) => e,
}
}
Expand All @@ -619,6 +625,9 @@ pub mod channel {
/// for local communication.
#[error("sender closed")]
SenderClosed,
/// The message exceeded the maximum allowed message size [`MAX_MESSAGE_SIZE`].
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An io error occurred. This can occur for remote communication,
/// due to a network error or deserialization error.
#[error("io error: {0}")]
Expand All @@ -630,6 +639,7 @@ pub mod channel {
match e {
RecvError::Io(e) => e,
RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
}
}
}
Expand Down Expand Up @@ -1126,28 +1136,66 @@ pub mod rpc {
RequestError, RpcMessage,
};

/// Default max message size (16 MiB).
const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;

/// Error code on streams if the max message size was exceeded.
const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;

/// Error code on streams if the sender tried to send an message that could not be postcard serialized.
const ERROR_CODE_INVALID_POSTCARD: u32 = 2;

/// Error that can occur when writing the initial message when doing a
/// cross-process RPC.
#[derive(Debug, thiserror::Error)]
pub enum WriteError {
/// Error writing to the stream with quinn
#[error("error writing to stream: {0}")]
Quinn(#[from] quinn::WriteError),
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// Generic IO error, e.g. when serializing the message or when using
/// other transports.
#[error("error serializing: {0}")]
Io(#[from] io::Error),
}

impl From<postcard::Error> for WriteError {
fn from(value: postcard::Error) -> Self {
Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
}
}

impl From<postcard::Error> for SendError {
fn from(value: postcard::Error) -> Self {
Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
}
}

impl From<WriteError> for io::Error {
fn from(e: WriteError) -> Self {
match e {
WriteError::Io(e) => e,
WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
WriteError::Quinn(e) => e.into(),
}
}
}

impl From<quinn::WriteError> for SendError {
fn from(err: quinn::WriteError) -> Self {
match err {
quinn::WriteError::Stopped(code)
if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
{
SendError::MaxMessageSizeExceeded
}
_ => SendError::Io(io::Error::from(err)),
}
}
}

/// Trait to abstract over a client connection to a remote service.
///
/// This isn't really that much abstracted, since the result of open_bi must
Expand Down Expand Up @@ -1256,6 +1304,9 @@ pub mod rpc {
{
let RemoteSender(mut send, recv, _) = self;
let msg = msg.into();
if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
return Err(WriteError::MaxMessageSizeExceeded);
}
let mut buf = SmallVec::<[u8; 128]>::new();
buf.write_length_prefixed(msg)?;
send.write_all(&buf).await?;
Expand All @@ -1266,17 +1317,24 @@ pub mod rpc {
impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
fn from(mut read: quinn::RecvStream) -> Self {
let fut = async move {
let size = read.read_varint_u64().await?.ok_or(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to read size",
))?;
let size = read
.read_varint_u64()
.await?
.ok_or(RecvError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to read size",
)))?;
if size > MAX_MESSAGE_SIZE {
read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
return Err(RecvError::MaxMessageSizeExceeded);
}
let rest = read
.read_to_end(size as usize)
.await
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let msg: T = postcard::from_bytes(&rest)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
io::Result::Ok(msg)
Ok(msg)
};
oneshot::Receiver::from(|| fut)
}
Expand Down Expand Up @@ -1309,11 +1367,30 @@ pub mod rpc {
fn from(mut writer: quinn::SendStream) -> Self {
oneshot::Sender::Boxed(Box::new(move |value| {
Box::pin(async move {
let size = match postcard::experimental::serialized_size(&value) {
Ok(size) => size,
Err(e) => {
writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
return Err(SendError::Io(io::Error::new(
io::ErrorKind::InvalidData,
e,
)));
}
};
if size as u64 > MAX_MESSAGE_SIZE {
writer
.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
.ok();
return Err(SendError::MaxMessageSizeExceeded);
}
// write via a small buffer to avoid allocation for small values
let mut buf = SmallVec::<[u8; 128]>::new();
buf.write_length_prefixed(value)?;
if let Err(e) = buf.write_length_prefixed(value) {
writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
return Err(e.into());
}
writer.write_all(&buf).await?;
io::Result::Ok(())
Ok(())
})
}))
}
Expand Down Expand Up @@ -1353,6 +1430,12 @@ pub mod rpc {
let Some(size) = read.read_varint_u64().await? else {
return Ok(None);
};
if size > MAX_MESSAGE_SIZE {
self.recv
.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
.ok();
return Err(RecvError::MaxMessageSizeExceeded);
}
let mut buf = vec![0; size as usize];
read.read_exact(&mut buf)
.await
Expand All @@ -1378,11 +1461,27 @@ pub mod rpc {
fn send(
&mut self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
Box::pin(async {
let size = match postcard::experimental::serialized_size(&value) {
Ok(size) => size,
Err(e) => {
self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
}
};
if size as u64 > MAX_MESSAGE_SIZE {
self.send
.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
.ok();
return Err(SendError::MaxMessageSizeExceeded);
}
let value = value;
self.buffer.clear();
self.buffer.write_length_prefixed(value)?;
if let Err(e) = self.buffer.write_length_prefixed(value) {
self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
return Err(e.into());
}
self.send.write_all(&self.buffer).await?;
self.buffer.clear();
Ok(())
Expand All @@ -1392,8 +1491,11 @@ pub mod rpc {
fn try_send(
&mut self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
Box::pin(async {
if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
return Err(SendError::MaxMessageSizeExceeded);
}
// todo: move the non-async part out of the box. Will require a new return type.
let value = value;
self.buffer.clear();
Expand Down Expand Up @@ -1434,7 +1536,7 @@ pub mod rpc {
fn send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
Box::pin(async {
let mut guard = self.0.lock().await;
let sender = std::mem::take(guard.deref_mut());
Expand All @@ -1446,15 +1548,17 @@ pub mod rpc {
}
res
}
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
QuinnSenderState::Closed => {
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
}
}
})
}

fn try_send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
Box::pin(async {
let mut guard = self.0.lock().await;
let sender = std::mem::take(guard.deref_mut());
Expand All @@ -1466,7 +1570,9 @@ pub mod rpc {
}
res
}
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
QuinnSenderState::Closed => {
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
}
}
})
}
Expand Down
Loading