@@ -7,7 +7,7 @@ use std::io::Error as IoError;
7
7
use std:: pin:: Pin ;
8
8
use std:: task:: { ready, Context , Poll } ;
9
9
10
- use attest:: client_connection:: ClientConnection ;
10
+ use attest:: client_connection:: { ClientConnection , NOISE_TRANSPORT_PER_PAYLOAD_MAX } ;
11
11
use bytes:: Bytes ;
12
12
use futures_util:: stream:: FusedStream ;
13
13
use futures_util:: { SinkExt as _, StreamExt as _} ;
@@ -33,7 +33,7 @@ pub struct NoiseStream<S> {
33
33
34
34
#[ derive( Debug , Default ) ]
35
35
struct Write {
36
- buffer_policy : WriteBufferPolicy ,
36
+ buffer : WriteBuffer ,
37
37
}
38
38
39
39
#[ derive( Debug , Default ) ]
@@ -43,10 +43,10 @@ enum Read {
43
43
ReadFromBlock ( Bytes ) ,
44
44
}
45
45
46
- #[ derive( Debug , Default ) ]
47
- enum WriteBufferPolicy {
48
- # [ default ]
49
- NoBuffering ,
46
+ #[ derive( Debug ) ]
47
+ struct WriteBuffer {
48
+ length : u16 ,
49
+ bytes : Box < [ u8 ; NOISE_TRANSPORT_PER_PAYLOAD_MAX ] > ,
50
50
}
51
51
52
52
impl < S > NoiseStream < S > {
@@ -132,43 +132,126 @@ impl<S: Transport + Unpin> AsyncWrite for NoiseStream<S> {
132
132
let Self {
133
133
transport,
134
134
inner,
135
- write,
135
+ write : Write { buffer } ,
136
+ read : _,
137
+ } = self . get_mut ( ) ;
138
+
139
+ let bytes_remaining = buffer. bytes . len ( ) - usize:: from ( buffer. length ) ;
140
+
141
+ if bytes_remaining == 0 {
142
+ // We need to make space by flushing the contents of the buffer.
143
+ let ( ) = ready ! ( buffer. poll_flush( ptr, cx, transport, inner) ) ?;
144
+
145
+ debug_assert_eq ! ( buffer. length, 0 ) ;
146
+ }
147
+
148
+ let count = buffer. copy_prefix ( buf) ;
149
+ log:: trace!( "{ptr:x?} buffered {count} bytes" ) ;
150
+ Poll :: Ready ( Ok ( count) )
151
+ }
152
+
153
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , IoError > > {
154
+ let ptr = & * self as * const Self ;
155
+
156
+ let Self {
157
+ transport,
158
+ inner,
159
+ write : Write { buffer } ,
160
+ read : _,
161
+ } = self . get_mut ( ) ;
162
+
163
+ if buffer. length != 0 {
164
+ log:: trace!( "{ptr:x?} trying to flush write buffer" ) ;
165
+ let ( ) = ready ! ( buffer. poll_flush( ptr, cx, transport, inner) ) ?;
166
+
167
+ debug_assert_eq ! ( buffer. length, 0 ) ;
168
+ }
169
+
170
+ inner. poll_flush_unpin ( cx)
171
+ }
172
+
173
+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , IoError > > {
174
+ let ptr = & * self as * const Self ;
175
+
176
+ let Self {
177
+ transport,
178
+ inner,
179
+ write : Write { buffer } ,
136
180
read : _,
137
181
} = self . get_mut ( ) ;
138
182
183
+ if buffer. length != 0 {
184
+ log:: trace!( "{ptr:x?} flushing write buffer before shutdown" ) ;
185
+ let ( ) = ready ! ( buffer. poll_flush( ptr, cx, transport, inner) ) ?;
186
+
187
+ debug_assert_eq ! ( buffer. length, 0 ) ;
188
+ }
189
+
190
+ inner. poll_close_unpin ( cx)
191
+ }
192
+ }
193
+
194
+ impl WriteBuffer {
195
+ fn poll_flush < S : Transport + Unpin > (
196
+ & mut self ,
197
+ ptr : * const NoiseStream < S > ,
198
+ cx : & mut Context < ' _ > ,
199
+ transport : & mut ClientConnection ,
200
+ inner : & mut S ,
201
+ ) -> Poll < Result < ( ) , IoError > > {
202
+ // Check to see if the inner sink is ready before doing anything expensive
203
+ // or destructive.
139
204
let ( ) = ready ! ( inner. poll_ready_unpin( cx) ) ?;
140
205
141
- let WriteBufferPolicy :: NoBuffering = write. buffer_policy ;
142
- log:: trace!( "{ptr:x?} encrypting {} bytes to send" , buf. len( ) ) ;
143
- let ciphertext = transport. send ( buf) . map_err ( IoError :: other) ?;
206
+ let Self { length, bytes } = self ;
207
+
208
+ log:: trace!( "{ptr:x?} encrypting {} bytes to send" , length) ;
209
+ let ciphertext = transport
210
+ . send ( & bytes[ ..usize:: from ( * length) ] )
211
+ . map_err ( IoError :: other) ?;
144
212
log:: trace!( "{ptr:x?} encrypted to {} bytes" , ciphertext. len( ) ) ;
145
213
214
+ * length = 0 ;
215
+
146
216
// Since the poll_ready above already succeeded, we can just send!
147
217
inner. start_send_unpin ( ( FrameType :: Data , ciphertext. into ( ) ) ) ?;
148
218
149
- log:: trace!( "{ptr:x?} sent, waiting for next block" ) ;
150
-
151
- Poll :: Ready ( Ok ( buf. len ( ) ) )
219
+ log:: trace!( "{ptr:x?} flushed write buffer" ) ;
220
+ Poll :: Ready ( Ok ( ( ) ) )
152
221
}
153
222
154
- fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , IoError > > {
155
- self . get_mut ( ) . inner . poll_flush_unpin ( cx)
223
+ fn copy_prefix ( & mut self , buf : & [ u8 ] ) -> usize {
224
+ let Self { bytes, length } = self ;
225
+ let bytes_remaining = bytes. len ( ) - usize:: from ( * length) ;
226
+
227
+ let to_copy = buf. len ( ) . min ( bytes_remaining) ;
228
+ bytes[ ( * length) . into ( ) ..] [ ..to_copy] . copy_from_slice ( & buf[ ..to_copy] ) ;
229
+ * length += u16:: try_from ( to_copy) . expect ( "small buffer" ) ;
230
+
231
+ to_copy
156
232
}
233
+ }
157
234
158
- fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , IoError > > {
159
- self . get_mut ( ) . inner . poll_close_unpin ( cx)
235
+ impl Default for WriteBuffer {
236
+ fn default ( ) -> Self {
237
+ Self {
238
+ length : 0 ,
239
+ bytes : Box :: new ( [ 0 ; NOISE_TRANSPORT_PER_PAYLOAD_MAX ] ) ,
240
+ }
160
241
}
161
242
}
162
243
163
244
#[ cfg( test) ]
164
245
mod test {
165
246
use std:: io:: ErrorKind as IoErrorKind ;
247
+ use std:: pin:: pin;
166
248
use std:: sync:: Arc ;
167
249
168
250
use assert_matches:: assert_matches;
251
+ use attest:: client_connection:: NOISE_TRANSPORT_PER_PACKET_MAX ;
169
252
use const_str:: concat;
170
253
use futures_util:: stream:: FusedStream ;
171
- use futures_util:: { pin_mut , FutureExt , Sink , Stream } ;
254
+ use futures_util:: { FutureExt , Sink , Stream } ;
172
255
use tokio:: io:: { AsyncReadExt as _, AsyncWriteExt as _} ;
173
256
174
257
use super :: * ;
@@ -189,17 +272,83 @@ mod test {
189
272
let ( mut a, mut b) = new_stream_pair ( ) ;
190
273
191
274
a. write_all ( b"abcde" ) . await . unwrap ( ) ;
275
+ a. flush ( ) . await . unwrap ( ) ;
192
276
let mut buf = [ 0 ; 5 ] ;
193
277
assert_eq ! ( buf. len( ) , b. read( & mut buf) . await . unwrap( ) ) ;
194
278
assert_eq ! ( & buf, b"abcde" ) ;
195
279
196
280
b. write_all ( b"1234567890" ) . await . unwrap ( ) ;
281
+ b. flush ( ) . await . unwrap ( ) ;
197
282
b. write_all ( b"abcdefghij" ) . await . unwrap ( ) ;
283
+ b. flush ( ) . await . unwrap ( ) ;
198
284
let mut buf = [ 0 ; 20 ] ;
199
285
a. read_exact ( & mut buf) . await . unwrap ( ) ;
200
286
assert_eq ! ( & buf, b"1234567890abcdefghij" ) ;
201
287
}
202
288
289
+ #[ tokio:: test]
290
+ async fn write_is_buffered ( ) {
291
+ // MITM the two streams so we can see when blocks pass through.
292
+ let ( transport_a, transport_d) = new_handshaken_pair ( ) . unwrap ( ) ;
293
+ let ( a, mut b) = new_transport_pair ( 2 ) ;
294
+ let ( mut c, d) = new_transport_pair ( 2 ) ;
295
+ let mut a = NoiseStream :: new ( a, transport_a, vec ! [ 0 ; 32 ] ) ;
296
+ let mut d = NoiseStream :: new ( d, transport_d, vec ! [ 0 ; 32 ] ) ;
297
+
298
+ a. write_all ( & [ b'a' ; NOISE_TRANSPORT_PER_PAYLOAD_MAX - 1 ] )
299
+ . await
300
+ . unwrap ( ) ;
301
+ assert_matches ! ( b. next( ) . now_or_never( ) , None ) ;
302
+
303
+ a. write_all ( & [ b'b' ; NOISE_TRANSPORT_PER_PAYLOAD_MAX + 1 ] )
304
+ . await
305
+ . unwrap ( ) ;
306
+
307
+ // The second write should have spilled the buffer into the stream,
308
+ // resulting in one block sent.
309
+ let first_block = b. next ( ) . await . expect ( "received" ) . expect ( "msg" ) ;
310
+ assert_matches ! ( b. next( ) . now_or_never( ) , None ) ;
311
+
312
+ assert ! (
313
+ first_block. len( ) <= NOISE_TRANSPORT_PER_PACKET_MAX ,
314
+ "first_block.len() = {}" ,
315
+ first_block. len( )
316
+ ) ;
317
+
318
+ c. send ( ( FrameType :: Data , first_block) ) . await . unwrap ( ) ;
319
+ let mut buf = [ 0 ; NOISE_TRANSPORT_PER_PAYLOAD_MAX ] ;
320
+ d. read_exact ( & mut buf) . await . expect ( "can read" ) ;
321
+
322
+ assert_eq ! (
323
+ buf. split_last( ) ,
324
+ Some ( (
325
+ & b'b' ,
326
+ [ b'a' ; NOISE_TRANSPORT_PER_PAYLOAD_MAX - 1 ] . as_slice( )
327
+ ) )
328
+ ) ;
329
+
330
+ a. flush ( ) . await . unwrap ( ) ;
331
+ c. send ( ( FrameType :: Data , b. next ( ) . await . unwrap ( ) . unwrap ( ) ) )
332
+ . await
333
+ . unwrap ( ) ;
334
+
335
+ let mut buf = [ 0 ; NOISE_TRANSPORT_PER_PAYLOAD_MAX ] ;
336
+ d. read_exact ( & mut buf) . await . expect ( "can read" ) ;
337
+ assert_eq ! ( buf, [ b'b' ; NOISE_TRANSPORT_PER_PAYLOAD_MAX ] . as_slice( ) ) ;
338
+ }
339
+
340
+ #[ tokio:: test]
341
+ async fn write_flushes_on_shutdown ( ) {
342
+ let ( mut a, mut b) = new_stream_pair ( ) ;
343
+
344
+ a. write_all ( b"abcdef" ) . await . unwrap ( ) ;
345
+ a. shutdown ( ) . await . unwrap ( ) ;
346
+
347
+ let mut buf = vec ! [ ] ;
348
+ b. read_to_end ( & mut buf) . await . expect ( "can read" ) ;
349
+ assert_eq ! ( & buf, b"abcdef" ) ;
350
+ }
351
+
203
352
#[ tokio:: test]
204
353
async fn graceful_close ( ) {
205
354
const MESSAGE : & [ u8 ] = b"message" ;
@@ -257,10 +406,11 @@ mod test {
257
406
let ( inner, other) = new_transport_pair ( 100 ) ;
258
407
let mut stream = NoiseStream :: new ( inner, transport, vec ! [ 0u8 ; 32 ] ) ;
259
408
260
- // Drop the read end. With nobody to receive sent bytes, the write to
261
- // the underlying channel should fail.
409
+ // Drop the read end. With nobody to receive sent bytes, the write and
410
+ // flush to the underlying channel should fail.
262
411
drop ( other) ;
263
- assert_matches ! ( stream. write_all( b"ababcdcdefef" ) . await , Err ( _) ) ;
412
+ assert_matches ! ( stream. write_all( b"ababcdcdefef" ) . await , Ok ( ( ) ) ) ;
413
+ assert_matches ! ( stream. flush( ) . await , Err ( _) ) ;
264
414
}
265
415
266
416
#[ tokio:: test]
@@ -386,28 +536,30 @@ mod test {
386
536
let mut b = NoiseStream :: new ( b, transport_b, vec ! [ 0u8 ; 32 ] ) ;
387
537
388
538
assert_matches ! ( a. write( b"first message" ) . now_or_never( ) , Some ( Ok ( 13 ) ) ) ;
539
+ assert_matches ! ( a. flush( ) . now_or_never( ) , Some ( Ok ( ( ) ) ) ) ;
389
540
assert_matches ! ( a. write( b"second message" ) . now_or_never( ) , Some ( Ok ( 14 ) ) ) ;
541
+ assert_matches ! ( a. flush( ) . now_or_never( ) , Some ( Ok ( ( ) ) ) ) ;
390
542
391
- let a_write = a. write ( b"third message" ) ;
392
- pin_mut ! ( a_write ) ;
393
- let a_write_waker = Arc :: new ( TestWaker :: default ( ) ) ;
543
+ assert_matches ! ( a. write( b"third message" ) . now_or_never ( ) , Some ( Ok ( 13 ) ) ) ;
544
+ let mut a_flush = pin ! ( a . flush ( ) ) ;
545
+ let a_flush_waker = Arc :: new ( TestWaker :: default ( ) ) ;
394
546
assert_matches ! (
395
- a_write . poll_unpin( & mut std:: task:: Context :: from_waker(
396
- & Arc :: clone( & a_write_waker ) . into( )
547
+ a_flush . poll_unpin( & mut std:: task:: Context :: from_waker(
548
+ & Arc :: clone( & a_flush_waker ) . into( )
397
549
) ) ,
398
550
Poll :: Pending
399
551
) ;
400
- assert ! ( !a_write_waker . was_woken( ) ) ;
552
+ assert ! ( !a_flush_waker . was_woken( ) ) ;
401
553
402
554
let mut read_buf = vec ! [ 0 ; 64 ] ;
403
555
assert_matches ! ( b. read( & mut read_buf) . now_or_never( ) , Some ( Ok ( 13 ) ) ) ;
404
556
assert_eq ! ( & read_buf[ ..13 ] , b"first message" ) ;
405
557
406
558
// Reading a message from the stream should unblock the writer.
407
- assert ! ( a_write_waker . was_woken( ) ) ;
559
+ assert ! ( a_flush_waker . was_woken( ) ) ;
408
560
assert_matches ! (
409
- a_write . poll_unpin( & mut std:: task:: Context :: from_waker( & a_write_waker . into( ) ) ) ,
410
- Poll :: Ready ( Ok ( 13 ) )
561
+ a_flush . poll_unpin( & mut std:: task:: Context :: from_waker( & a_flush_waker . into( ) ) ) ,
562
+ Poll :: Ready ( Ok ( ( ) ) )
411
563
) ;
412
564
413
565
drop ( a) ;
0 commit comments