Skip to content

Commit 83e0f83

Browse files
authored
Merge pull request #30 from n0-computer/Frando/max-message-size
feat: add a max message size restriction
2 parents cc886f3 + d7ec2b3 commit 83e0f83

File tree

5 files changed

+523
-143
lines changed

5 files changed

+523
-143
lines changed

src/lib.rs

Lines changed: 132 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub mod channel {
141141

142142
/// Oneshot channel, similar to tokio's oneshot channel
143143
pub mod oneshot {
144-
use std::{fmt::Debug, future::Future, io, pin::Pin, task};
144+
use std::{fmt::Debug, future::Future, pin::Pin, task};
145145

146146
use n0_future::future::Boxed as BoxFuture;
147147

@@ -162,7 +162,7 @@ pub mod channel {
162162
/// overhead is negligible. However, boxing can also be used for local communication,
163163
/// e.g. when applying a transform or filter to the message before sending it.
164164
pub type BoxedSender<T> =
165-
Box<dyn FnOnce(T) -> BoxFuture<io::Result<()>> + Send + Sync + 'static>;
165+
Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
166166

167167
/// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
168168
///
@@ -172,7 +172,9 @@ pub mod channel {
172172
/// Remote receivers are always boxed, since for remote communication the boxing
173173
/// overhead is negligible. However, boxing can also be used for local communication,
174174
/// e.g. when applying a transform or filter to the message before receiving it.
175-
pub trait DynSender<T>: Future<Output = io::Result<()>> + Send + Sync + 'static {
175+
pub trait DynSender<T>:
176+
Future<Output = Result<(), SendError>> + Send + Sync + 'static
177+
{
176178
fn is_rpc(&self) -> bool;
177179
}
178180

@@ -181,7 +183,7 @@ pub mod channel {
181183
/// Remote receivers are always boxed, since for remote communication the boxing
182184
/// overhead is negligible. However, boxing can also be used for local communication,
183185
/// e.g. when applying a transform or filter to the message before receiving it.
184-
pub type BoxedReceiver<T> = BoxFuture<io::Result<T>>;
186+
pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
185187

186188
/// A oneshot sender.
187189
///
@@ -230,7 +232,7 @@ pub mod channel {
230232
pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
231233
match self {
232234
Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed),
233-
Sender::Boxed(f) => f(value).await.map_err(SendError::from),
235+
Sender::Boxed(f) => f(value).await,
234236
}
235237
}
236238
}
@@ -266,7 +268,7 @@ pub mod channel {
266268
fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
267269
match self.get_mut() {
268270
Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed),
269-
Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io),
271+
Self::Boxed(rx) => Pin::new(rx).poll(cx),
270272
}
271273
}
272274
}
@@ -293,7 +295,7 @@ pub mod channel {
293295
impl<T, F, Fut> From<F> for Receiver<T>
294296
where
295297
F: FnOnce() -> Fut,
296-
Fut: Future<Output = io::Result<T>> + Send + 'static,
298+
Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
297299
{
298300
fn from(f: F) -> Self {
299301
Self::Boxed(Box::pin(f()))
@@ -317,7 +319,7 @@ pub mod channel {
317319
///
318320
/// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc.
319321
pub mod mpsc {
320-
use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc};
322+
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
321323

322324
use super::{RecvError, SendError};
323325
use crate::RpcMessage;
@@ -398,7 +400,7 @@ pub mod channel {
398400
fn send(
399401
&self,
400402
value: T,
401-
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>>;
403+
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
402404

403405
/// Try to send a message, returning as fast as possible if sending
404406
/// is not currently possible.
@@ -408,7 +410,7 @@ pub mod channel {
408410
fn try_send(
409411
&self,
410412
value: T,
411-
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>>;
413+
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
412414

413415
/// Await the sender close
414416
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
@@ -458,7 +460,7 @@ pub mod channel {
458460
Sender::Tokio(tx) => {
459461
tx.send(value).await.map_err(|_| SendError::ReceiverClosed)
460462
}
461-
Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from),
463+
Sender::Boxed(sink) => sink.send(value).await,
462464
}
463465
}
464466

@@ -492,7 +494,7 @@ pub mod channel {
492494
}
493495
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
494496
},
495-
Sender::Boxed(sink) => sink.try_send(value).await.map_err(SendError::from),
497+
Sender::Boxed(sink) => sink.try_send(value).await,
496498
}
497499
}
498500
}
@@ -593,6 +595,9 @@ pub mod channel {
593595
/// for local communication.
594596
#[error("receiver closed")]
595597
ReceiverClosed,
598+
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
599+
#[error("maximum message size exceeded")]
600+
MaxMessageSizeExceeded,
596601
/// The underlying io error. This can occur for remote communication,
597602
/// due to a network error or serialization error.
598603
#[error("io error: {0}")]
@@ -603,6 +608,7 @@ pub mod channel {
603608
fn from(e: SendError) -> Self {
604609
match e {
605610
SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
611+
SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
606612
SendError::Io(e) => e,
607613
}
608614
}
@@ -619,6 +625,9 @@ pub mod channel {
619625
/// for local communication.
620626
#[error("sender closed")]
621627
SenderClosed,
628+
/// The message exceeded the maximum allowed message size [`MAX_MESSAGE_SIZE`].
629+
#[error("maximum message size exceeded")]
630+
MaxMessageSizeExceeded,
622631
/// An io error occurred. This can occur for remote communication,
623632
/// due to a network error or deserialization error.
624633
#[error("io error: {0}")]
@@ -630,6 +639,7 @@ pub mod channel {
630639
match e {
631640
RecvError::Io(e) => e,
632641
RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
642+
RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
633643
}
634644
}
635645
}
@@ -1126,28 +1136,66 @@ pub mod rpc {
11261136
RequestError, RpcMessage,
11271137
};
11281138

1139+
/// Default max message size (16 MiB).
1140+
const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1141+
1142+
/// Error code on streams if the max message size was exceeded.
1143+
const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1144+
1145+
/// Error code on streams if the sender tried to send an message that could not be postcard serialized.
1146+
const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1147+
11291148
/// Error that can occur when writing the initial message when doing a
11301149
/// cross-process RPC.
11311150
#[derive(Debug, thiserror::Error)]
11321151
pub enum WriteError {
11331152
/// Error writing to the stream with quinn
11341153
#[error("error writing to stream: {0}")]
11351154
Quinn(#[from] quinn::WriteError),
1155+
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1156+
#[error("maximum message size exceeded")]
1157+
MaxMessageSizeExceeded,
11361158
/// Generic IO error, e.g. when serializing the message or when using
11371159
/// other transports.
11381160
#[error("error serializing: {0}")]
11391161
Io(#[from] io::Error),
11401162
}
11411163

1164+
impl From<postcard::Error> for WriteError {
1165+
fn from(value: postcard::Error) -> Self {
1166+
Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1167+
}
1168+
}
1169+
1170+
impl From<postcard::Error> for SendError {
1171+
fn from(value: postcard::Error) -> Self {
1172+
Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1173+
}
1174+
}
1175+
11421176
impl From<WriteError> for io::Error {
11431177
fn from(e: WriteError) -> Self {
11441178
match e {
11451179
WriteError::Io(e) => e,
1180+
WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
11461181
WriteError::Quinn(e) => e.into(),
11471182
}
11481183
}
11491184
}
11501185

1186+
impl From<quinn::WriteError> for SendError {
1187+
fn from(err: quinn::WriteError) -> Self {
1188+
match err {
1189+
quinn::WriteError::Stopped(code)
1190+
if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
1191+
{
1192+
SendError::MaxMessageSizeExceeded
1193+
}
1194+
_ => SendError::Io(io::Error::from(err)),
1195+
}
1196+
}
1197+
}
1198+
11511199
/// Trait to abstract over a client connection to a remote service.
11521200
///
11531201
/// This isn't really that much abstracted, since the result of open_bi must
@@ -1256,6 +1304,9 @@ pub mod rpc {
12561304
{
12571305
let RemoteSender(mut send, recv, _) = self;
12581306
let msg = msg.into();
1307+
if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
1308+
return Err(WriteError::MaxMessageSizeExceeded);
1309+
}
12591310
let mut buf = SmallVec::<[u8; 128]>::new();
12601311
buf.write_length_prefixed(msg)?;
12611312
send.write_all(&buf).await?;
@@ -1266,17 +1317,24 @@ pub mod rpc {
12661317
impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
12671318
fn from(mut read: quinn::RecvStream) -> Self {
12681319
let fut = async move {
1269-
let size = read.read_varint_u64().await?.ok_or(io::Error::new(
1270-
io::ErrorKind::UnexpectedEof,
1271-
"failed to read size",
1272-
))?;
1320+
let size = read
1321+
.read_varint_u64()
1322+
.await?
1323+
.ok_or(RecvError::Io(io::Error::new(
1324+
io::ErrorKind::UnexpectedEof,
1325+
"failed to read size",
1326+
)))?;
1327+
if size > MAX_MESSAGE_SIZE {
1328+
read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
1329+
return Err(RecvError::MaxMessageSizeExceeded);
1330+
}
12731331
let rest = read
12741332
.read_to_end(size as usize)
12751333
.await
12761334
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
12771335
let msg: T = postcard::from_bytes(&rest)
12781336
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1279-
io::Result::Ok(msg)
1337+
Ok(msg)
12801338
};
12811339
oneshot::Receiver::from(|| fut)
12821340
}
@@ -1309,11 +1367,30 @@ pub mod rpc {
13091367
fn from(mut writer: quinn::SendStream) -> Self {
13101368
oneshot::Sender::Boxed(Box::new(move |value| {
13111369
Box::pin(async move {
1370+
let size = match postcard::experimental::serialized_size(&value) {
1371+
Ok(size) => size,
1372+
Err(e) => {
1373+
writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
1374+
return Err(SendError::Io(io::Error::new(
1375+
io::ErrorKind::InvalidData,
1376+
e,
1377+
)));
1378+
}
1379+
};
1380+
if size as u64 > MAX_MESSAGE_SIZE {
1381+
writer
1382+
.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
1383+
.ok();
1384+
return Err(SendError::MaxMessageSizeExceeded);
1385+
}
13121386
// write via a small buffer to avoid allocation for small values
13131387
let mut buf = SmallVec::<[u8; 128]>::new();
1314-
buf.write_length_prefixed(value)?;
1388+
if let Err(e) = buf.write_length_prefixed(value) {
1389+
writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
1390+
return Err(e.into());
1391+
}
13151392
writer.write_all(&buf).await?;
1316-
io::Result::Ok(())
1393+
Ok(())
13171394
})
13181395
}))
13191396
}
@@ -1353,6 +1430,12 @@ pub mod rpc {
13531430
let Some(size) = read.read_varint_u64().await? else {
13541431
return Ok(None);
13551432
};
1433+
if size > MAX_MESSAGE_SIZE {
1434+
self.recv
1435+
.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
1436+
.ok();
1437+
return Err(RecvError::MaxMessageSizeExceeded);
1438+
}
13561439
let mut buf = vec![0; size as usize];
13571440
read.read_exact(&mut buf)
13581441
.await
@@ -1378,11 +1461,27 @@ pub mod rpc {
13781461
fn send(
13791462
&mut self,
13801463
value: T,
1381-
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
1464+
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
13821465
Box::pin(async {
1466+
let size = match postcard::experimental::serialized_size(&value) {
1467+
Ok(size) => size,
1468+
Err(e) => {
1469+
self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
1470+
return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
1471+
}
1472+
};
1473+
if size as u64 > MAX_MESSAGE_SIZE {
1474+
self.send
1475+
.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
1476+
.ok();
1477+
return Err(SendError::MaxMessageSizeExceeded);
1478+
}
13831479
let value = value;
13841480
self.buffer.clear();
1385-
self.buffer.write_length_prefixed(value)?;
1481+
if let Err(e) = self.buffer.write_length_prefixed(value) {
1482+
self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
1483+
return Err(e.into());
1484+
}
13861485
self.send.write_all(&self.buffer).await?;
13871486
self.buffer.clear();
13881487
Ok(())
@@ -1392,8 +1491,11 @@ pub mod rpc {
13921491
fn try_send(
13931492
&mut self,
13941493
value: T,
1395-
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
1494+
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
13961495
Box::pin(async {
1496+
if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
1497+
return Err(SendError::MaxMessageSizeExceeded);
1498+
}
13971499
// todo: move the non-async part out of the box. Will require a new return type.
13981500
let value = value;
13991501
self.buffer.clear();
@@ -1434,7 +1536,7 @@ pub mod rpc {
14341536
fn send(
14351537
&self,
14361538
value: T,
1437-
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
1539+
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
14381540
Box::pin(async {
14391541
let mut guard = self.0.lock().await;
14401542
let sender = std::mem::take(guard.deref_mut());
@@ -1446,15 +1548,17 @@ pub mod rpc {
14461548
}
14471549
res
14481550
}
1449-
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
1551+
QuinnSenderState::Closed => {
1552+
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
1553+
}
14501554
}
14511555
})
14521556
}
14531557

14541558
fn try_send(
14551559
&self,
14561560
value: T,
1457-
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
1561+
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
14581562
Box::pin(async {
14591563
let mut guard = self.0.lock().await;
14601564
let sender = std::mem::take(guard.deref_mut());
@@ -1466,7 +1570,9 @@ pub mod rpc {
14661570
}
14671571
res
14681572
}
1469-
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
1573+
QuinnSenderState::Closed => {
1574+
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
1575+
}
14701576
}
14711577
})
14721578
}

0 commit comments

Comments
 (0)