diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 42f99175f..fc8bf6b22 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -63,6 +63,30 @@ impl Display for ConnectError { #[cfg(feature = "std")] impl std::error::Error for ConnectError {} +/// Error returned by set_* +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ArgumentError { + InvalidArgs, + InvalidState, + InsufficientResource, +} + +impl Display for crate::socket::tcp::ArgumentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + crate::socket::tcp::ArgumentError::InvalidArgs => write!(f, "invalid arguments by RFC"), + crate::socket::tcp::ArgumentError::InvalidState => write!(f, "invalid state"), + crate::socket::tcp::ArgumentError::InsufficientResource => { + write!(f, "insufficient runtime resource") + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for crate::socket::tcp::ArgumentError {} + /// Error returned by [`Socket::send`] #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -774,6 +798,41 @@ impl<'a> Socket<'a> { } } + /// Return the local receive window scaling factor defined in [RFC 1323]. + /// + /// The value will become constant after the connection is established. + /// It may be reset to 0 during the handshake if remote side does not support window scaling. + pub fn local_recv_win_scale(&self) -> u8 { + self.remote_win_shift + } + + /// Set the local receive window scaling factor defined in [RFC 1323]. + /// + /// The value will become constant after the connection is established. + /// It may be reset to 0 during the handshake if remote side does not support window scaling. + /// + /// # Errors + /// `Err(ArgumentError::InvalidArgs)` if the scale is greater than 14. + /// `Err(ArgumentError::InvalidState)` if the socket is not in the `Closed` or `Listen` state. + /// `Err(ArgumentError::InsufficientResource)` if the receive buffer is smaller than (1< Result<(), ArgumentError> { + if scale > 14 { + return Err(ArgumentError::InvalidArgs); + } + + if self.rx_buffer.capacity() < (1 << scale) as usize { + return Err(ArgumentError::InsufficientResource); + } + + match self.state { + State::Closed | State::Listen => { + self.remote_win_shift = scale; + Ok(()) + } + _ => Err(ArgumentError::InvalidState), + } + } + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. /// /// See also the [set_hop_limit](#method.set_hop_limit) method @@ -828,6 +887,7 @@ impl<'a> Socket<'a> { fn reset(&mut self) { let rx_cap_log2 = mem::size_of::() * 8 - self.rx_buffer.capacity().leading_zeros() as usize; + let new_rx_win_shift = rx_cap_log2.saturating_sub(16) as u8; self.state = State::Closed; self.timer = Timer::new(); @@ -845,7 +905,10 @@ impl<'a> Socket<'a> { self.remote_last_win = 0; self.remote_win_len = 0; self.remote_win_scale = None; - self.remote_win_shift = rx_cap_log2.saturating_sub(16) as u8; + // keep user-specified window scaling across connect()/listen() + if self.remote_win_shift < new_rx_win_shift { + self.remote_win_shift = new_rx_win_shift; + } self.remote_mss = DEFAULT_MSS; self.remote_last_ts = None; self.ack_delay_timer = AckDelayTimer::Idle; @@ -2329,6 +2392,7 @@ impl<'a> Socket<'a> { } else if self.timer.should_close(cx.now()) { // If we have spent enough time in the TIME-WAIT state, close the socket. tcp_trace!("TIME-WAIT timer expired"); + self.remote_win_shift = 0; self.reset(); return Ok(()); } else { @@ -2601,6 +2665,63 @@ impl<'a> Socket<'a> { .unwrap_or(&PollAt::Ingress) } } + + /// Replace the receive buffer with a new one. + /// + /// The requirements for the new buffer are: + /// 1. The new buffer must be larger than the length of remaining data in the current buffer + /// 2. The new buffer must be multiple of (1 << self.remote_win_shift) + /// + /// If the new buffer does not meet the requirements, the new buffer is returned as an error; + /// otherwise, the old buffer is returned as an Ok value. + /// + /// See also the [local_recv_win_scale](struct.Socket.html#method.local_recv_win_scale) methods. + pub fn replace_recv_buffer>>( + &mut self, + new_buffer: T, + ) -> Result, SocketBuffer<'a>> { + let mut replaced_buf = new_buffer.into(); + /* Check if the new buffer is valid + * Requirements: + * 1. The new buffer must be larger than the length of remaining data in the current buffer + * 2. The new buffer must be multiple of (1 << self.remote_win_shift) + */ + if replaced_buf.capacity() < self.rx_buffer.len() + || replaced_buf.capacity() % (1 << self.remote_win_shift) != 0 + { + return Err(replaced_buf); + } + replaced_buf.clear(); + + // We should copy both allocated data and unallocated data (for assembler) + let allocated1 = self.rx_buffer.get_allocated(0, self.rx_buffer.len()); + let l = replaced_buf.enqueue_slice(allocated1); + assert_eq!(l, allocated1.len()); + if allocated1.len() < self.rx_buffer.len() { + let allocated2 = self + .rx_buffer + .get_allocated(allocated1.len(), self.rx_buffer.len() - allocated1.len()); + let l = replaced_buf.enqueue_slice(allocated2); + assert_eq!(l, allocated2.len()); + } + + // make sure assembler can work properly + let unallocated1 = self.rx_buffer.get_unallocated(0, self.rx_buffer.window()); + let unallocated1_len = unallocated1.len(); + let l = replaced_buf.write_unallocated(0, unallocated1); + assert_eq!(l, unallocated1.len()); + if unallocated1_len < self.rx_buffer.window() { + let unallocated2 = self + .rx_buffer + .get_unallocated(unallocated1_len, self.rx_buffer.window() - unallocated1_len); + let l = replaced_buf.write_unallocated(unallocated1_len, unallocated2); + assert_eq!(l, unallocated2.len()); + } + assert_eq!(replaced_buf.len(), self.rx_buffer.len()); + + mem::swap(&mut self.rx_buffer, &mut replaced_buf); + Ok(replaced_buf) + } } impl<'a> fmt::Write for Socket<'a> { @@ -8151,4 +8272,70 @@ mod test { }] ); } + + // =========================================================================================// + // Tests for window scaling + // =========================================================================================// + + fn socket_established_with_window_scaling() -> TestSocket { + let mut s = socket_established(); + s.remote_win_shift = 10; + const BASE: usize = 1 << 10; + s.tx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]); + s.rx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]); + s + } + + #[test] + fn test_too_large_window_scale() { + let mut socket = Socket::new( + SocketBuffer::new(vec![0; 8 * (1 << 15)]), + SocketBuffer::new(vec![0; 8 * (1 << 15)]), + ); + assert!(socket.set_local_recv_win_scale(15).is_err()) + } + + #[test] + fn test_set_window_scale() { + let mut socket = Socket::new( + SocketBuffer::new(vec![0; 128]), + SocketBuffer::new(vec![0; 128]), + ); + assert!(matches!(socket.state, State::Closed)); + assert_eq!(socket.rx_buffer.capacity(), 128); + assert!(socket.set_local_recv_win_scale(6).is_ok()); + assert!(socket.set_local_recv_win_scale(14).is_err()); + assert_eq!(socket.local_recv_win_scale(), 6); + } + + #[test] + fn test_set_scale_with_tcp_state() { + let mut socket = socket(); + assert!(socket.set_local_recv_win_scale(1).is_ok()); + let mut socket = socket_established(); + assert!(socket.set_local_recv_win_scale(1).is_err()); + let mut socket = socket_listen(); + assert!(socket.set_local_recv_win_scale(1).is_ok()); + let mut socket = socket_syn_received(); + assert!(socket.set_local_recv_win_scale(1).is_err()); + } + + #[test] + fn test_resize_recv_buffer_invalid_size() { + let mut s = socket_established_with_window_scaling(); + assert_eq!(s.rx_buffer.enqueue_slice(&[42; 31 * 1024]), 31 * 1024); + assert_eq!(s.rx_buffer.len(), 31 * 1024); + assert!(s + .replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024 + 512])) + .is_err()); + assert!(s + .replace_recv_buffer(SocketBuffer::new(vec![7u8; 16 * 1024])) + .is_err()); + let old_buffer = s + .replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024])) + .unwrap(); + assert_eq!(old_buffer.capacity(), 64 * 1024); + assert_eq!(s.rx_buffer.len(), 31 * 1024); + assert_eq!(s.rx_buffer.capacity(), 32 * 1024); + } }