Skip to content

Commit 2e27966

Browse files
Add write buffer to Noise stream
1 parent 5a99532 commit 2e27966

File tree

2 files changed

+184
-32
lines changed

2 files changed

+184
-32
lines changed

rust/attest/src/client_connection.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ pub const NOISE_PATTERN_HFS: &str = "Noise_NKhfs_25519+Kyber1024_ChaChaPoly_SHA2
1515

1616
pub(crate) const NOISE_HANDSHAKE_OVERHEAD: usize = 64 + /* post-quantum kyber1024: */ 1568;
1717

18-
pub(crate) const NOISE_TRANSPORT_PER_PACKET_MAX: usize = 65535;
18+
pub const NOISE_TRANSPORT_PER_PACKET_MAX: usize = 65535;
1919
pub(crate) const NOISE_TRANSPORT_PER_PAYLOAD_OVERHEAD: usize = 16;
20-
pub(crate) const NOISE_TRANSPORT_PER_PAYLOAD_MAX: usize =
20+
pub const NOISE_TRANSPORT_PER_PAYLOAD_MAX: usize =
2121
NOISE_TRANSPORT_PER_PACKET_MAX - NOISE_TRANSPORT_PER_PAYLOAD_OVERHEAD;
2222

2323
#[derive(Debug)]

rust/net/infra/src/noise/stream.rs

Lines changed: 182 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::io::Error as IoError;
77
use std::pin::Pin;
88
use std::task::{ready, Context, Poll};
99

10-
use attest::client_connection::ClientConnection;
10+
use attest::client_connection::{ClientConnection, NOISE_TRANSPORT_PER_PAYLOAD_MAX};
1111
use bytes::Bytes;
1212
use futures_util::stream::FusedStream;
1313
use futures_util::{SinkExt as _, StreamExt as _};
@@ -33,7 +33,7 @@ pub struct NoiseStream<S> {
3333

3434
#[derive(Debug, Default)]
3535
struct Write {
36-
buffer_policy: WriteBufferPolicy,
36+
buffer: WriteBuffer,
3737
}
3838

3939
#[derive(Debug, Default)]
@@ -43,10 +43,10 @@ enum Read {
4343
ReadFromBlock(Bytes),
4444
}
4545

46-
#[derive(Debug, Default)]
47-
enum WriteBufferPolicy {
48-
#[default]
49-
NoBuffering,
46+
#[derive(Debug)]
47+
struct WriteBuffer {
48+
length: u16,
49+
bytes: Box<[u8; NOISE_TRANSPORT_PER_PAYLOAD_MAX]>,
5050
}
5151

5252
impl<S> NoiseStream<S> {
@@ -132,43 +132,126 @@ impl<S: Transport + Unpin> AsyncWrite for NoiseStream<S> {
132132
let Self {
133133
transport,
134134
inner,
135-
write,
135+
write: Write { buffer },
136+
read: _,
137+
} = self.get_mut();
138+
139+
let bytes_remaining = buffer.bytes.len() - usize::from(buffer.length);
140+
141+
if bytes_remaining == 0 {
142+
// We need to make space by flushing the contents of the buffer.
143+
let () = ready!(buffer.poll_flush(ptr, cx, transport, inner))?;
144+
145+
debug_assert_eq!(buffer.length, 0);
146+
}
147+
148+
let count = buffer.copy_prefix(buf);
149+
log::trace!("{ptr:x?} buffered {count} bytes");
150+
Poll::Ready(Ok(count))
151+
}
152+
153+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
154+
let ptr = &*self as *const Self;
155+
156+
let Self {
157+
transport,
158+
inner,
159+
write: Write { buffer },
160+
read: _,
161+
} = self.get_mut();
162+
163+
if buffer.length != 0 {
164+
log::trace!("{ptr:x?} trying to flush write buffer");
165+
let () = ready!(buffer.poll_flush(ptr, cx, transport, inner))?;
166+
167+
debug_assert_eq!(buffer.length, 0);
168+
}
169+
170+
inner.poll_flush_unpin(cx)
171+
}
172+
173+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
174+
let ptr = &*self as *const Self;
175+
176+
let Self {
177+
transport,
178+
inner,
179+
write: Write { buffer },
136180
read: _,
137181
} = self.get_mut();
138182

183+
if buffer.length != 0 {
184+
log::trace!("{ptr:x?} flushing write buffer before shutdown");
185+
let () = ready!(buffer.poll_flush(ptr, cx, transport, inner))?;
186+
187+
debug_assert_eq!(buffer.length, 0);
188+
}
189+
190+
inner.poll_close_unpin(cx)
191+
}
192+
}
193+
194+
impl WriteBuffer {
195+
fn poll_flush<S: Transport + Unpin>(
196+
&mut self,
197+
ptr: *const NoiseStream<S>,
198+
cx: &mut Context<'_>,
199+
transport: &mut ClientConnection,
200+
inner: &mut S,
201+
) -> Poll<Result<(), IoError>> {
202+
// Check to see if the inner sink is ready before doing anything expensive
203+
// or destructive.
139204
let () = ready!(inner.poll_ready_unpin(cx))?;
140205

141-
let WriteBufferPolicy::NoBuffering = write.buffer_policy;
142-
log::trace!("{ptr:x?} encrypting {} bytes to send", buf.len());
143-
let ciphertext = transport.send(buf).map_err(IoError::other)?;
206+
let Self { length, bytes } = self;
207+
208+
log::trace!("{ptr:x?} encrypting {} bytes to send", length);
209+
let ciphertext = transport
210+
.send(&bytes[..usize::from(*length)])
211+
.map_err(IoError::other)?;
144212
log::trace!("{ptr:x?} encrypted to {} bytes", ciphertext.len());
145213

214+
*length = 0;
215+
146216
// Since the poll_ready above already succeeded, we can just send!
147217
inner.start_send_unpin((FrameType::Data, ciphertext.into()))?;
148218

149-
log::trace!("{ptr:x?} sent, waiting for next block");
150-
151-
Poll::Ready(Ok(buf.len()))
219+
log::trace!("{ptr:x?} flushed write buffer");
220+
Poll::Ready(Ok(()))
152221
}
153222

154-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
155-
self.get_mut().inner.poll_flush_unpin(cx)
223+
fn copy_prefix(&mut self, buf: &[u8]) -> usize {
224+
let Self { bytes, length } = self;
225+
let bytes_remaining = bytes.len() - usize::from(*length);
226+
227+
let to_copy = buf.len().min(bytes_remaining);
228+
bytes[(*length).into()..][..to_copy].copy_from_slice(&buf[..to_copy]);
229+
*length += u16::try_from(to_copy).expect("small buffer");
230+
231+
to_copy
156232
}
233+
}
157234

158-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
159-
self.get_mut().inner.poll_close_unpin(cx)
235+
impl Default for WriteBuffer {
236+
fn default() -> Self {
237+
Self {
238+
length: 0,
239+
bytes: Box::new([0; NOISE_TRANSPORT_PER_PAYLOAD_MAX]),
240+
}
160241
}
161242
}
162243

163244
#[cfg(test)]
164245
mod test {
165246
use std::io::ErrorKind as IoErrorKind;
247+
use std::pin::pin;
166248
use std::sync::Arc;
167249

168250
use assert_matches::assert_matches;
251+
use attest::client_connection::NOISE_TRANSPORT_PER_PACKET_MAX;
169252
use const_str::concat;
170253
use futures_util::stream::FusedStream;
171-
use futures_util::{pin_mut, FutureExt, Sink, Stream};
254+
use futures_util::{FutureExt, Sink, Stream};
172255
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
173256

174257
use super::*;
@@ -189,17 +272,83 @@ mod test {
189272
let (mut a, mut b) = new_stream_pair();
190273

191274
a.write_all(b"abcde").await.unwrap();
275+
a.flush().await.unwrap();
192276
let mut buf = [0; 5];
193277
assert_eq!(buf.len(), b.read(&mut buf).await.unwrap());
194278
assert_eq!(&buf, b"abcde");
195279

196280
b.write_all(b"1234567890").await.unwrap();
281+
b.flush().await.unwrap();
197282
b.write_all(b"abcdefghij").await.unwrap();
283+
b.flush().await.unwrap();
198284
let mut buf = [0; 20];
199285
a.read_exact(&mut buf).await.unwrap();
200286
assert_eq!(&buf, b"1234567890abcdefghij");
201287
}
202288

289+
#[tokio::test]
290+
async fn write_is_buffered() {
291+
// MITM the two streams so we can see when blocks pass through.
292+
let (transport_a, transport_d) = new_handshaken_pair().unwrap();
293+
let (a, mut b) = new_transport_pair(2);
294+
let (mut c, d) = new_transport_pair(2);
295+
let mut a = NoiseStream::new(a, transport_a, vec![0; 32]);
296+
let mut d = NoiseStream::new(d, transport_d, vec![0; 32]);
297+
298+
a.write_all(&[b'a'; NOISE_TRANSPORT_PER_PAYLOAD_MAX - 1])
299+
.await
300+
.unwrap();
301+
assert_matches!(b.next().now_or_never(), None);
302+
303+
a.write_all(&[b'b'; NOISE_TRANSPORT_PER_PAYLOAD_MAX + 1])
304+
.await
305+
.unwrap();
306+
307+
// The second write should have spilled the buffer into the stream,
308+
// resulting in one block sent.
309+
let first_block = b.next().await.expect("received").expect("msg");
310+
assert_matches!(b.next().now_or_never(), None);
311+
312+
assert!(
313+
first_block.len() <= NOISE_TRANSPORT_PER_PACKET_MAX,
314+
"first_block.len() = {}",
315+
first_block.len()
316+
);
317+
318+
c.send((FrameType::Data, first_block)).await.unwrap();
319+
let mut buf = [0; NOISE_TRANSPORT_PER_PAYLOAD_MAX];
320+
d.read_exact(&mut buf).await.expect("can read");
321+
322+
assert_eq!(
323+
buf.split_last(),
324+
Some((
325+
&b'b',
326+
[b'a'; NOISE_TRANSPORT_PER_PAYLOAD_MAX - 1].as_slice()
327+
))
328+
);
329+
330+
a.flush().await.unwrap();
331+
c.send((FrameType::Data, b.next().await.unwrap().unwrap()))
332+
.await
333+
.unwrap();
334+
335+
let mut buf = [0; NOISE_TRANSPORT_PER_PAYLOAD_MAX];
336+
d.read_exact(&mut buf).await.expect("can read");
337+
assert_eq!(buf, [b'b'; NOISE_TRANSPORT_PER_PAYLOAD_MAX].as_slice());
338+
}
339+
340+
#[tokio::test]
341+
async fn write_flushes_on_shutdown() {
342+
let (mut a, mut b) = new_stream_pair();
343+
344+
a.write_all(b"abcdef").await.unwrap();
345+
a.shutdown().await.unwrap();
346+
347+
let mut buf = vec![];
348+
b.read_to_end(&mut buf).await.expect("can read");
349+
assert_eq!(&buf, b"abcdef");
350+
}
351+
203352
#[tokio::test]
204353
async fn graceful_close() {
205354
const MESSAGE: &[u8] = b"message";
@@ -257,10 +406,11 @@ mod test {
257406
let (inner, other) = new_transport_pair(100);
258407
let mut stream = NoiseStream::new(inner, transport, vec![0u8; 32]);
259408

260-
// Drop the read end. With nobody to receive sent bytes, the write to
261-
// the underlying channel should fail.
409+
// Drop the read end. With nobody to receive sent bytes, the write and
410+
// flush to the underlying channel should fail.
262411
drop(other);
263-
assert_matches!(stream.write_all(b"ababcdcdefef").await, Err(_));
412+
assert_matches!(stream.write_all(b"ababcdcdefef").await, Ok(()));
413+
assert_matches!(stream.flush().await, Err(_));
264414
}
265415

266416
#[tokio::test]
@@ -386,28 +536,30 @@ mod test {
386536
let mut b = NoiseStream::new(b, transport_b, vec![0u8; 32]);
387537

388538
assert_matches!(a.write(b"first message").now_or_never(), Some(Ok(13)));
539+
assert_matches!(a.flush().now_or_never(), Some(Ok(())));
389540
assert_matches!(a.write(b"second message").now_or_never(), Some(Ok(14)));
541+
assert_matches!(a.flush().now_or_never(), Some(Ok(())));
390542

391-
let a_write = a.write(b"third message");
392-
pin_mut!(a_write);
393-
let a_write_waker = Arc::new(TestWaker::default());
543+
assert_matches!(a.write(b"third message").now_or_never(), Some(Ok(13)));
544+
let mut a_flush = pin!(a.flush());
545+
let a_flush_waker = Arc::new(TestWaker::default());
394546
assert_matches!(
395-
a_write.poll_unpin(&mut std::task::Context::from_waker(
396-
&Arc::clone(&a_write_waker).into()
547+
a_flush.poll_unpin(&mut std::task::Context::from_waker(
548+
&Arc::clone(&a_flush_waker).into()
397549
)),
398550
Poll::Pending
399551
);
400-
assert!(!a_write_waker.was_woken());
552+
assert!(!a_flush_waker.was_woken());
401553

402554
let mut read_buf = vec![0; 64];
403555
assert_matches!(b.read(&mut read_buf).now_or_never(), Some(Ok(13)));
404556
assert_eq!(&read_buf[..13], b"first message");
405557

406558
// Reading a message from the stream should unblock the writer.
407-
assert!(a_write_waker.was_woken());
559+
assert!(a_flush_waker.was_woken());
408560
assert_matches!(
409-
a_write.poll_unpin(&mut std::task::Context::from_waker(&a_write_waker.into())),
410-
Poll::Ready(Ok(13))
561+
a_flush.poll_unpin(&mut std::task::Context::from_waker(&a_flush_waker.into())),
562+
Poll::Ready(Ok(()))
411563
);
412564

413565
drop(a);

0 commit comments

Comments
 (0)