diff --git a/quinn-proto/src/address_discovery.rs b/quinn-proto/src/address_discovery.rs new file mode 100644 index 0000000000..2d9e0261cf --- /dev/null +++ b/quinn-proto/src/address_discovery.rs @@ -0,0 +1,71 @@ +//! Address discovery types from +//! + +use crate::coding::BufExt; +use crate::{transport_parameters::Error, VarInt}; + +/// The role of each participant. +/// +/// When enabled, this is reported as a transport parameter. +#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)] +pub(crate) struct Role { + pub(crate) send_reports: bool, + pub(crate) receive_reports: bool, +} + +impl TryFrom for Role { + type Error = Error; + + fn try_from(value: VarInt) -> Result { + let mut role = Self::default(); + match value.0 { + 0 => role.send_reports = true, + 1 => role.receive_reports = true, + 2 => { + role.send_reports = true; + role.receive_reports = true; + } + _ => return Err(Error::IllegalValue), + } + + Ok(role) + } +} + +impl Role { + pub(crate) fn from_transport_parameter( + len: usize, + role: &Role, + r: &mut impl bytes::Buf, + ) -> Result { + if !role.is_disabled() { + // duplicate parameter + return Err(Error::Malformed); + } + let value: VarInt = r.get()?; + if len != value.size() { + return Err(Error::Malformed); + } + + value.try_into() + } + /// Whether address discovery is disabled. + pub(crate) fn is_disabled(&self) -> bool { + !self.receive_reports && !self.send_reports + } + + /// Whether this peer should report observed addresses to the other peer. + pub(crate) fn should_report(&self, other: &Self) -> bool { + self.send_reports && other.receive_reports + } + + /// Gives the [`VarInt`] representing this [`Role`] as a transport parameter. + pub(crate) fn as_transport_parameter(&self) -> Option { + match (self.send_reports, self.receive_reports) { + (true, true) => Some(VarInt(2)), + (true, false) => Some(VarInt(0)), + (false, true) => Some(VarInt(1)), + (false, false) => None, + } + } +} diff --git a/quinn-proto/src/config/transport.rs b/quinn-proto/src/config/transport.rs index 82ddd1d19a..d7e1abc212 100644 --- a/quinn-proto/src/config/transport.rs +++ b/quinn-proto/src/config/transport.rs @@ -1,6 +1,9 @@ use std::{fmt, sync::Arc}; -use crate::{congestion, Duration, VarInt, VarIntBoundsExceeded, INITIAL_MTU, MAX_UDP_PAYLOAD}; +use crate::{ + address_discovery, congestion, Duration, VarInt, VarIntBoundsExceeded, INITIAL_MTU, + MAX_UDP_PAYLOAD, +}; /// Parameters governing the core QUIC state machine /// @@ -43,6 +46,8 @@ pub struct TransportConfig { pub(crate) congestion_controller_factory: Arc, pub(crate) enable_segmentation_offload: bool, + + pub(crate) address_discovery_role: crate::address_discovery::Role, } impl TransportConfig { @@ -314,6 +319,26 @@ impl TransportConfig { self.enable_segmentation_offload = enabled; self } + + /// Whether to send observed address reports to peers. + /// + /// This will aid peers in inferring their reachable address, which in most NATd networks + /// will not be easily available to them. + pub fn send_observed_address_reports(&mut self, enabled: bool) -> &mut Self { + self.address_discovery_role.send_reports = enabled; + self + } + + /// Whether to receive observed address reports from other peers. + /// + /// Peers with the address discovery extension enabled that are willing to provide observed + /// address reports will do so if this transport parameter is set. In general, observed address + /// reports cannot be trusted. This, however, can aid the current endpoint in inferring its + /// reachable address, which in most NATd networks will not be easily available. + pub fn receive_observed_address_reports(&mut self, enabled: bool) -> &mut Self { + self.address_discovery_role.receive_reports = enabled; + self + } } impl Default for TransportConfig { @@ -354,6 +379,8 @@ impl Default for TransportConfig { congestion_controller_factory: Arc::new(congestion::CubicConfig::default()), enable_segmentation_offload: true, + + address_discovery_role: address_discovery::Role::default(), } } } @@ -385,6 +412,7 @@ impl fmt::Debug for TransportConfig { deterministic_packet_numbers: _, congestion_controller_factory: _, enable_segmentation_offload, + address_discovery_role, } = self; fmt.debug_struct("TransportConfig") .field("max_concurrent_bidi_streams", max_concurrent_bidi_streams) @@ -412,6 +440,7 @@ impl fmt::Debug for TransportConfig { .field("datagram_send_buffer_size", datagram_send_buffer_size) // congestion_controller_factory not debug .field("enable_segmentation_offload", enable_segmentation_offload) + .field("address_discovery_role", address_discovery_role) .finish_non_exhaustive() } } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 0364fe87bf..1f73fd6094 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -19,7 +19,7 @@ use crate::{ coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame::{self, Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct, ObservedAddr}, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -222,6 +222,12 @@ pub struct Connection { /// no outgoing application data. app_limited: bool, + // + // ObservedAddr + // + /// Sequence number for the next observed address frame sent to the peer. + next_observed_addr_seq_no: VarInt, + streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, @@ -336,6 +342,8 @@ impl Connection { receiving_ecn: false, total_authed_packets: 0, + next_observed_addr_seq_no: 0u32.into(), + streams: StreamsState::new( side, config.max_concurrent_uni_streams, @@ -2633,6 +2641,9 @@ impl Connection { let mut close = None; let payload_len = payload.len(); let mut ack_eliciting = false; + // if this packet triggers a path migration and includes a observed address frame, it's + // stored here + let mut migration_observed_addr = None; for result in frame::Iter::new(payload)? { let frame = result?; let span = match frame { @@ -2676,7 +2687,8 @@ impl Connection { Frame::Padding | Frame::PathChallenge(_) | Frame::PathResponse(_) - | Frame::NewConnectionId(_) => {} + | Frame::NewConnectionId(_) + | Frame::ObservedAddr(_) => {} _ => { is_probing_packet = false; } @@ -2904,6 +2916,33 @@ impl Connection { self.discard_space(now, SpaceId::Handshake); } } + Frame::ObservedAddr(observed) => { + // check if params allows the peer to send report and this node to receive it + if !self + .peer_params + .address_discovery_role + .should_report(&self.config.address_discovery_role) + { + return Err(TransportError::PROTOCOL_VIOLATION( + "received OBSERVED_ADDRESS frame when not negotiated", + )); + } + // must only be sent in data space + if packet.header.space() != SpaceId::Data { + return Err(TransportError::PROTOCOL_VIOLATION( + "OBSERVED_ADDRESS frame outside data space", + )); + } + + if remote == self.path.remote { + if let Some(updated) = self.path.update_observed_addr_report(observed) { + self.events.push_back(Event::ObservedAddr(updated)); + } + } else { + // include in migration + migration_observed_addr = Some(observed) + } + } } } @@ -2940,7 +2979,7 @@ impl Connection { server_config.migration, "migration-initiating packets should have been dropped immediately" ); - self.migrate(now, remote); + self.migrate(now, remote, migration_observed_addr); // Break linkability, if possible self.update_rem_cid(); self.spin = false; @@ -2949,7 +2988,7 @@ impl Connection { Ok(()) } - fn migrate(&mut self, now: Instant, remote: SocketAddr) { + fn migrate(&mut self, now: Instant, remote: SocketAddr, observed_addr: Option) { trace!(%remote, "migration initiated"); // Reset rtt/congestion state for new path unless it looks like a NAT rebinding. // Note that the congestion window will not grow until validation terminates. Helps mitigate @@ -2969,6 +3008,12 @@ impl Connection { &self.config, ) }; + new_path.last_observed_addr_report = self.path.last_observed_addr_report; + if let Some(report) = observed_addr { + if let Some(updated) = new_path.update_observed_addr_report(report) { + self.events.push_back(Event::ObservedAddr(updated)); + } + } new_path.challenge = Some(self.rng.gen()); new_path.challenge_pending = true; let prev_pto = self.pto(SpaceId::Data); @@ -3053,6 +3098,53 @@ impl Connection { self.stats.frame_tx.handshake_done.saturating_add(1); } + // OBSERVED_ADDR + let mut send_observed_address = + |space_id: SpaceId, + buf: &mut Vec, + max_size: usize, + space: &mut PacketSpace, + sent: &mut SentFrames, + stats: &mut ConnectionStats, + skip_sent_check: bool| { + // should only be sent within Data space and only if allowed by extension + // negotiation + // send is also skipped if the path has already sent an observed address + let send_allowed = self + .config + .address_discovery_role + .should_report(&self.peer_params.address_discovery_role); + let send_required = + space.pending.observed_addr || !self.path.observed_addr_sent || skip_sent_check; + if space_id != SpaceId::Data || !send_allowed || !send_required { + return; + } + + let observed = + frame::ObservedAddr::new(self.path.remote, self.next_observed_addr_seq_no); + + if buf.len() + observed.size() < max_size { + observed.write(buf); + + self.next_observed_addr_seq_no = + self.next_observed_addr_seq_no.saturating_add(1u8); + self.path.observed_addr_sent = true; + + stats.frame_tx.observed_addr += 1; + sent.retransmits.get_or_create().observed_addr = true; + space.pending.observed_addr = false; + } + }; + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + false, + ); + // PING if mem::replace(&mut space.ping_pending, false) { trace!("PING"); @@ -3122,7 +3214,16 @@ impl Connection { trace!("PATH_CHALLENGE {:08x}", token); buf.write(frame::FrameType::PATH_CHALLENGE); buf.write(token); - self.stats.frame_tx.path_challenge += 1; + + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + true, + ); } } @@ -3135,6 +3236,19 @@ impl Connection { buf.write(frame::FrameType::PATH_RESPONSE); buf.write(token); self.stats.frame_tx.path_response += 1; + + // NOTE: this is technically not required but might be useful to ride the + // request/response nature of path challenges to refresh an observation + // Since PATH_RESPONSE is a probing frame, this is allowed by the spec. + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + true, + ); } } @@ -3838,6 +3952,8 @@ pub enum Event { DatagramReceived, /// One or more application datagrams have been sent after blocking DatagramsUnblocked, + /// Received an observation of our external address from the peer. + ObservedAddr(SocketAddr), } fn instant_saturating_sub(x: Instant, y: Instant) -> Duration { diff --git a/quinn-proto/src/connection/paths.rs b/quinn-proto/src/connection/paths.rs index d9621cc4ea..b8b7d6bdd7 100644 --- a/quinn-proto/src/connection/paths.rs +++ b/quinn-proto/src/connection/paths.rs @@ -7,7 +7,10 @@ use super::{ pacing::Pacer, spaces::{PacketSpace, SentPacket}, }; -use crate::{congestion, packet::SpaceId, Duration, Instant, TransportConfig, TIMER_GRANULARITY}; +use crate::{ + congestion, frame::ObservedAddr, packet::SpaceId, Duration, Instant, TransportConfig, + TIMER_GRANULARITY, +}; /// Description of a particular network path pub(super) struct PathData { @@ -37,6 +40,11 @@ pub(super) struct PathData { /// Used in persistent congestion determination. pub(super) first_packet_after_rtt_sample: Option<(SpaceId, u64)>, pub(super) in_flight: InFlight, + /// Whether this path has had it's remote address reported back to the peer. This only happens + /// if both peers agree to so based on their transport parameters. + pub(super) observed_addr_sent: bool, + /// Observed address frame with the largest sequence number received from the peer on this path. + pub(super) last_observed_addr_report: Option, /// Number of the first packet sent on this path /// /// Used to determine whether a packet was sent on an earlier path. Insufficient to determine if @@ -90,10 +98,15 @@ impl PathData { ), first_packet_after_rtt_sample: None, in_flight: InFlight::new(), + observed_addr_sent: false, + last_observed_addr_report: None, first_packet: None, } } + /// Create a new path from a previous one. + /// + /// This should only be called when migrating paths. pub(super) fn from_previous(remote: SocketAddr, prev: &Self, now: Instant) -> Self { let congestion = prev.congestion.clone_box(); let smoothed_rtt = prev.rtt.get(); @@ -111,6 +124,8 @@ impl PathData { mtud: prev.mtud.clone(), first_packet_after_rtt_sample: prev.first_packet_after_rtt_sample, in_flight: InFlight::new(), + observed_addr_sent: false, + last_observed_addr_report: None, first_packet: None, } } @@ -156,6 +171,37 @@ impl PathData { self.in_flight.remove(packet); true } + + /// Updates the last observed address report received on this path. + /// + /// If the address was updated, it's returned to be informed to the application. + #[must_use = "updated observed address must be reported to the application"] + pub(super) fn update_observed_addr_report( + &mut self, + observed: ObservedAddr, + ) -> Option { + match self.last_observed_addr_report.as_mut() { + Some(prev) => { + if prev.seq_no >= observed.seq_no { + // frames that do not increase the sequence number on this path are ignored + None + } else if prev.ip == observed.ip && prev.port == observed.port { + // keep track of the last seq_no but do not report the address as updated + prev.seq_no = observed.seq_no; + None + } else { + let addr = observed.socket_addr(); + self.last_observed_addr_report = Some(observed); + Some(addr) + } + } + None => { + let addr = observed.socket_addr(); + self.last_observed_addr_report = Some(observed); + Some(addr) + } + } + } } /// RTT estimation for a particular network path diff --git a/quinn-proto/src/connection/spaces.rs b/quinn-proto/src/connection/spaces.rs index ed58b51c1e..0d0edad68d 100644 --- a/quinn-proto/src/connection/spaces.rs +++ b/quinn-proto/src/connection/spaces.rs @@ -309,6 +309,7 @@ pub struct Retransmits { pub(super) retire_cids: Vec, pub(super) ack_frequency: bool, pub(super) handshake_done: bool, + pub(super) observed_addr: bool, } impl Retransmits { @@ -326,6 +327,7 @@ impl Retransmits { && self.retire_cids.is_empty() && !self.ack_frequency && !self.handshake_done + && !self.observed_addr } } @@ -347,6 +349,7 @@ impl ::std::ops::BitOrAssign for Retransmits { self.retire_cids.extend(rhs.retire_cids); self.ack_frequency |= rhs.ack_frequency; self.handshake_done |= rhs.handshake_done; + self.observed_addr |= rhs.observed_addr; } } diff --git a/quinn-proto/src/connection/stats.rs b/quinn-proto/src/connection/stats.rs index 9ddb42d1a0..31f5f1d142 100644 --- a/quinn-proto/src/connection/stats.rs +++ b/quinn-proto/src/connection/stats.rs @@ -53,6 +53,7 @@ pub struct FrameStats { pub streams_blocked_uni: u64, pub stop_sending: u64, pub stream: u64, + pub observed_addr: u64, } impl FrameStats { @@ -93,6 +94,7 @@ impl FrameStats { Frame::AckFrequency(_) => self.ack_frequency += 1, Frame::ImmediateAck => self.immediate_ack += 1, Frame::HandshakeDone => self.handshake_done = self.handshake_done.saturating_add(1), + Frame::ObservedAddr(_) => self.observed_addr += 1, } } } diff --git a/quinn-proto/src/frame.rs b/quinn-proto/src/frame.rs index da5a53af5a..6a3edd9c49 100644 --- a/quinn-proto/src/frame.rs +++ b/quinn-proto/src/frame.rs @@ -1,6 +1,7 @@ use std::{ fmt::{self, Write}, mem, + net::{IpAddr, SocketAddr}, ops::{Range, RangeInclusive}, }; @@ -134,6 +135,9 @@ frame_types! { ACK_FREQUENCY = 0xaf, IMMEDIATE_ACK = 0x1f, // DATAGRAM + // ADDRESS DISCOVERY REPORT + OBSERVED_IPV4_ADDR = 0x9f81a6, + OBSERVED_IPV6_ADDR = 0x9f81a7, } const STREAM_TYS: RangeInclusive = RangeInclusive::new(0x08, 0x0f); @@ -164,6 +168,7 @@ pub(crate) enum Frame { AckFrequency(AckFrequency), ImmediateAck, HandshakeDone, + ObservedAddr(ObservedAddr), } impl Frame { @@ -205,6 +210,7 @@ impl Frame { AckFrequency(_) => FrameType::ACK_FREQUENCY, ImmediateAck => FrameType::IMMEDIATE_ACK, HandshakeDone => FrameType::HANDSHAKE_DONE, + ObservedAddr(ref observed) => observed.get_type(), } } @@ -682,6 +688,11 @@ impl Iter { reordering_threshold: self.bytes.get()?, }), FrameType::IMMEDIATE_ACK => Frame::ImmediateAck, + FrameType::OBSERVED_IPV4_ADDR | FrameType::OBSERVED_IPV6_ADDR => { + let is_ipv6 = ty == FrameType::OBSERVED_IPV6_ADDR; + let observed = ObservedAddr::read(&mut self.bytes, is_ipv6)?; + Frame::ObservedAddr(observed) + } _ => { if let Some(s) = ty.stream() { Frame::Stream(Stream { @@ -922,8 +933,87 @@ impl AckFrequency { } } +/* Address Discovery https://datatracker.ietf.org/doc/draft-seemann-quic-address-discovery/ */ + +/// Conjuction of the information contained in the address discovery frames +/// ([`FrameType::OBSERVED_IPV4_ADDR`], [`FrameType::OBSERVED_IPV6_ADDR`]). +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub(crate) struct ObservedAddr { + /// Monotonically increasing integer within the same connection. + pub(crate) seq_no: VarInt, + /// Reported observed address. + pub(crate) ip: IpAddr, + /// Reported observed port. + pub(crate) port: u16, +} + +impl ObservedAddr { + pub(crate) fn new(remote: SocketAddr, seq_no: VarInt) -> Self { + Self { + ip: remote.ip(), + port: remote.port(), + seq_no, + } + } + + /// Get the [`FrameType`] for this frame. + pub(crate) fn get_type(&self) -> FrameType { + if self.ip.is_ipv6() { + FrameType::OBSERVED_IPV6_ADDR + } else { + FrameType::OBSERVED_IPV4_ADDR + } + } + + /// Compute the number of bytes needed to encode the frame. + pub(crate) fn size(&self) -> usize { + let type_size = VarInt(self.get_type().0).size(); + let req_id_bytes = self.seq_no.size(); + let ip_bytes = if self.ip.is_ipv6() { 16 } else { 4 }; + let port_bytes = 2; + type_size + req_id_bytes + ip_bytes + port_bytes + } + + /// Unconditionally write this frame to `buf`. + pub(crate) fn write(&self, buf: &mut W) { + buf.write(self.get_type()); + buf.write(self.seq_no); + match self.ip { + IpAddr::V4(ipv4_addr) => { + buf.write(ipv4_addr); + } + IpAddr::V6(ipv6_addr) => { + buf.write(ipv6_addr); + } + } + buf.write::(self.port); + } + + /// Reads the frame contents from the buffer. + /// + /// Should only be called when the fram type has been identified as + /// [`FrameType::OBSERVED_IPV4_ADDR`] or [`FrameType::OBSERVED_IPV6_ADDR`]. + pub(crate) fn read(bytes: &mut R, is_ipv6: bool) -> coding::Result { + Ok(Self { + seq_no: bytes.get()?, + ip: if is_ipv6 { + IpAddr::V6(bytes.get()?) + } else { + IpAddr::V4(bytes.get()?) + }, + port: bytes.get()?, + }) + } + + /// Gives the [`SocketAddr`] reported in the frame. + pub(crate) fn socket_addr(&self) -> SocketAddr { + (self.ip, self.port).into() + } +} + #[cfg(test)] mod test { + use super::*; use crate::coding::Codec; use assert_matches::assert_matches; @@ -988,4 +1078,29 @@ mod test { assert_eq!(frames.len(), 1); assert_matches!(&frames[0], Frame::ImmediateAck); } + + /// Test that encoding and decoding [`ObservedAddr`] produces the same result. + #[test] + fn test_observed_addr_roundrip() { + let observed_addr = ObservedAddr { + seq_no: VarInt(42), + ip: std::net::Ipv4Addr::LOCALHOST.into(), + port: 4242, + }; + let mut buf = Vec::with_capacity(observed_addr.size()); + observed_addr.write(&mut buf); + + assert_eq!( + observed_addr.size(), + buf.len(), + "expected written bytes and actual size differ" + ); + + let mut decoded = frames(buf); + assert_eq!(decoded.len(), 1); + match decoded.pop().expect("non empty") { + Frame::ObservedAddr(decoded) => assert_eq!(decoded, observed_addr), + x => panic!("incorrect frame {x:?}"), + } + } } diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 12051a62f5..270f31ae4f 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -87,6 +87,8 @@ pub use crate::cid_generator::{ mod token; use token::ResetToken; +mod address_discovery; + #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index ac254d555c..fcde373725 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -3184,6 +3184,306 @@ fn voluntary_ack_with_large_datagrams() { ); } +/// Test the address discovery extension on a normal setup. +#[test] +fn address_discovery() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..client_config() + }; + let conn_handle = pair.begin_connect(client_config); + + // wait for idle connections + pair.drive(); + + // check that the client received the correct address + let expected_addr = pair.client.addr; + let conn = pair.client_conn_mut(conn_handle); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), Some(Event::ObservedAddr(addr)) if addr == expected_addr); + assert_matches!(conn.poll(), None); + + // check that the server received the correct address + let conn_handle = pair.server.assert_accept(); + let expected_addr = pair.server.addr; + let conn = pair.server_conn_mut(conn_handle); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), Some(Event::ObservedAddr(addr)) if addr == expected_addr); + assert_matches!(conn.poll(), None); +} + +/// Test that a different address discovery configuration on 0rtt used by the client is accepted by +/// the server. +/// NOTE: this test is the same as zero_rtt_happypath, changing client transport parameters on +/// resumption. +#[test] +fn address_discovery_zero_rtt_accepted() { + let _guard = subscribe(); + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + + pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + let client_cfg = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..client_config() + }; + let alt_client_cfg = ClientConfig { + transport: Arc::new(TransportConfig::default()), + ..client_cfg.clone() + }; + + // Establish normal connection + let client_ch = pair.begin_connect(client_cfg); + pair.drive(); + pair.server.assert_accept(); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(0), [][..].into()); + pair.drive(); + + pair.client.addr = SocketAddr::new( + Ipv6Addr::LOCALHOST.into(), + CLIENT_PORTS.lock().unwrap().next().unwrap(), + ); + info!("resuming session"); + let client_ch = pair.begin_connect(alt_client_cfg); + assert!(pair.client_conn_mut(client_ch).has_0rtt()); + let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); + const MSG: &[u8] = b"Hello, 0-RTT!"; + pair.client_send(client_ch, s).write(MSG).unwrap(); + pair.drive(); + + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + + assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); + let server_ch = pair.server.assert_accept(); + + let conn = pair.server_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + // We don't currently preserve stream event order wrt. connection events + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!( + conn.poll(), + Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) + ); + + let mut recv = pair.server_recv(server_ch, s); + let mut chunks = recv.read(false).unwrap(); + assert_matches!( + chunks.next(usize::MAX), + Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG + ); + let _ = chunks.finalize(); + assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); +} + +/// Test that a different address discovery configuration on 0rtt used by the server is rejected by +/// the client. +/// NOTE: the server MUST not change configuration on resumption. However, there is no designed +/// behaviour when this is encountered. Quinn chooses to accept and then close the connection, +/// which is what this test checks. +#[test] +fn address_discovery_zero_rtt_rejection() { + let _guard = subscribe(); + let server_cfg = ServerConfig { + transport: Default::default(), + ..server_config() + }; + let alt_server_cfg = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + ..Default::default() + }, + ..TransportConfig::default() + }), + ..server_cfg.clone() + }; + let mut pair = Pair::new(Default::default(), server_cfg); + let client_cfg = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..client_config() + }; + + // Establish normal connection + let client_ch = pair.begin_connect(client_cfg.clone()); + pair.drive(); + let server_ch = pair.server.assert_accept(); + let conn = pair.server_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(0), [][..].into()); + pair.drive(); + assert_matches!( + pair.server_conn_mut(server_ch).poll(), + Some(Event::ConnectionLost { .. }) + ); + assert_matches!(pair.server_conn_mut(server_ch).poll(), None); + pair.client.connections.clear(); + pair.server.connections.clear(); + + // Changing address discovery configurations makes the client close the connection + pair.server + .set_server_config(Some(Arc::new(alt_server_cfg))); + info!("resuming session"); + let client_ch = pair.begin_connect(client_cfg); + assert!(pair.client_conn_mut(client_ch).has_0rtt()); + let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); + const MSG: &[u8] = b"Hello, 0-RTT!"; + pair.client_send(client_ch, s).write(MSG).unwrap(); + pair.drive(); + let conn = pair.client_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!( + conn.poll(), + Some(Event::ConnectionLost { + reason: ConnectionError::TransportError(_) + }) + ); +} + +#[test] +fn address_discovery_retransmission() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..client_config() + }; + let client_ch = pair.begin_connect(client_config); + pair.step(); + + // lose the last packet + pair.client.inbound.pop_back().unwrap(); + pair.step(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + + pair.drive(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), + Some(Event::ObservedAddr(addr)) if addr == pair.client.addr); +} + +#[test] +fn address_discovery_rebind_retransmission() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role { + send_reports: true, + receive_reports: true, + }, + ..TransportConfig::default() + }), + ..client_config() + }; + let client_ch = pair.begin_connect(client_config); + pair.step(); + + // lose the last packet + pair.client.inbound.pop_back().unwrap(); + pair.step(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + + // simulate a rebind to ensure we will get an updated address instead of retransmitting + // outdated info + pair.client_conn_mut(client_ch).local_address_changed(); + pair.client + .addr + .set_port(pair.client.addr.port().overflowing_add(1).0); + + pair.drive(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), + Some(Event::ObservedAddr(addr)) if addr == pair.client.addr); +} + #[test] fn reject_short_idcid() { let _guard = subscribe(); diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index ccf378d4a7..7129132edd 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -16,6 +16,7 @@ use rand::{seq::SliceRandom as _, Rng as _, RngCore}; use thiserror::Error; use crate::{ + address_discovery, cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, @@ -110,6 +111,9 @@ macro_rules! make_struct { /// This field is initialized only for outgoing `TransportParameters` instances and /// is set to `None` for `TransportParameters` received from a peer. pub(crate) write_order: Option<[u8; TransportParameterId::SUPPORTED.len()]>, + + /// The role of this peer in address discovery, if any. + pub(crate) address_discovery_role: address_discovery::Role, } // We deliberately don't implement the `Default` trait, since that would be public, and @@ -133,6 +137,8 @@ macro_rules! make_struct { preferred_address: None, grease_transport_parameter: None, write_order: None, + + address_discovery_role: address_discovery::Role::default(), } } } @@ -180,6 +186,7 @@ impl TransportParameters { order.shuffle(rng); order }), + address_discovery_role: config.address_discovery_role, ..Self::default() } } @@ -196,6 +203,7 @@ impl TransportParameters { || cached.initial_max_streams_uni > self.initial_max_streams_uni || cached.max_datagram_frame_size > self.max_datagram_frame_size || cached.grease_quic_bit && !self.grease_quic_bit + || cached.address_discovery_role != self.address_discovery_role { return Err(TransportError::PROTOCOL_VIOLATION( "0-RTT accepted with incompatible transport parameters", @@ -380,6 +388,14 @@ impl TransportParameters { w.write(x); } } + TransportParameterId::ObservedAddr => { + if let Some(varint_role) = self.address_discovery_role.as_transport_parameter() + { + w.write_var(id as u64); + w.write_var(varint_role.size() as u64); + w.write(varint_role); + } + } id => { macro_rules! write_params { {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => { @@ -478,6 +494,15 @@ impl TransportParameters { TransportParameterId::MinAckDelayDraft07 => { params.min_ack_delay = Some(r.get().unwrap()) } + TransportParameterId::ObservedAddr => { + let prev_role = ¶ms.address_discovery_role; + params.address_discovery_role = + address_discovery::Role::from_transport_parameter(len, prev_role, r)?; + tracing::debug!( + role = ?params.address_discovery_role, + "address discovery enabled for peer" + ); + } _ => { macro_rules! parse { {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => { @@ -640,11 +665,14 @@ pub(crate) enum TransportParameterId { // https://datatracker.ietf.org/doc/html/draft-ietf-quic-ack-frequency#section-10.1 MinAckDelayDraft07 = 0xFF04DE1B, + + // + ObservedAddr = 0x9f81a176, } impl TransportParameterId { /// Array with all supported transport parameter IDs - const SUPPORTED: [Self; 21] = [ + const SUPPORTED: [Self; 22] = [ Self::MaxIdleTimeout, Self::MaxUdpPayloadSize, Self::InitialMaxData, @@ -666,6 +694,7 @@ impl TransportParameterId { Self::RetrySourceConnectionId, Self::GreaseQuicBit, Self::MinAckDelayDraft07, + Self::ObservedAddr, ]; } @@ -705,6 +734,7 @@ impl TryFrom for TransportParameterId { id if Self::RetrySourceConnectionId == id => Self::RetrySourceConnectionId, id if Self::GreaseQuicBit == id => Self::GreaseQuicBit, id if Self::MinAckDelayDraft07 == id => Self::MinAckDelayDraft07, + id if Self::ObservedAddr == id => Self::ObservedAddr, _ => return Err(()), }; Ok(param) @@ -742,6 +772,10 @@ mod test { }), grease_quic_bit: true, min_ack_delay: Some(2_000u32.into()), + address_discovery_role: address_discovery::Role { + send_reports: true, + ..Default::default() + }, ..TransportParameters::default() }; params.write(&mut buf); diff --git a/quinn-proto/src/varint.rs b/quinn-proto/src/varint.rs index a72fb3431f..08022d9db3 100644 --- a/quinn-proto/src/varint.rs +++ b/quinn-proto/src/varint.rs @@ -50,6 +50,14 @@ impl VarInt { self.0 } + /// Saturating integer addition. Computes self + rhs, saturating at the numeric bounds instead + /// of overflowing. + pub fn saturating_add(self, rhs: impl Into) -> Self { + let rhs = rhs.into(); + let inner = self.0.saturating_add(rhs.0).min(Self::MAX.0); + Self(inner) + } + /// Compute the number of bytes needed to encode this value pub(crate) const fn size(self) -> usize { let x = self.0; @@ -191,3 +199,19 @@ impl Codec for VarInt { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_saturating_add() { + // add within range behaves normally + let large: VarInt = u32::MAX.into(); + let next = u64::from(u32::MAX) + 1; + assert_eq!(large.saturating_add(1u8), VarInt::from_u64(next).unwrap()); + + // outside range saturates + assert_eq!(VarInt::MAX.saturating_add(1u8), VarInt::MAX) + } +} diff --git a/quinn/examples/client.rs b/quinn/examples/client.rs index 0ace61f957..80fc3562d7 100644 --- a/quinn/examples/client.rs +++ b/quinn/examples/client.rs @@ -13,7 +13,7 @@ use std::{ use anyhow::{anyhow, Result}; use clap::Parser; -use proto::crypto::rustls::QuicClientConfig; +use proto::{crypto::rustls::QuicClientConfig, TransportConfig}; use rustls::pki_types::CertificateDer; use tracing::{error, info}; use url::Url; @@ -101,8 +101,13 @@ async fn run(options: Opt) -> Result<()> { client_crypto.key_log = Arc::new(rustls::KeyLogFile::new()); } - let client_config = + let mut transport = TransportConfig::default(); + transport + .send_observed_address_reports(true) + .receive_observed_address_reports(true); + let mut client_config = quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto)?)); + client_config.transport_config(Arc::new(transport)); let mut endpoint = quinn::Endpoint::client(options.bind)?; endpoint.set_default_client_config(client_config); @@ -117,6 +122,18 @@ async fn run(options: Opt) -> Result<()> { .await .map_err(|e| anyhow!("failed to connect: {}", e))?; eprintln!("connected at {:?}", start.elapsed()); + let mut external_addresses = conn.observed_external_addr(); + tokio::spawn(async move { + loop { + if let Some(new_addr) = *external_addresses.borrow_and_update() { + info!(%new_addr, "new external address report"); + } + if external_addresses.changed().await.is_err() { + break; + } + } + }); + let (mut send, mut recv) = conn .open_bi() .await diff --git a/quinn/examples/server.rs b/quinn/examples/server.rs index b6f63160e6..b65d739bee 100644 --- a/quinn/examples/server.rs +++ b/quinn/examples/server.rs @@ -127,7 +127,10 @@ async fn run(options: Opt) -> Result<()> { let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(server_crypto)?)); let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); - transport_config.max_concurrent_uni_streams(0_u8.into()); + transport_config + .max_concurrent_uni_streams(0_u8.into()) + .send_observed_address_reports(true) + .receive_observed_address_reports(true); let root = Arc::::from(options.root.clone()); if !root.exists() { @@ -176,6 +179,21 @@ async fn handle_connection(root: Arc, conn: quinn::Incoming) -> Result<()> .protocol .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) ); + + let mut external_addresses = connection.observed_external_addr(); + tokio::spawn( + async move { + loop { + if let Some(new_addr) = *external_addresses.borrow_and_update() { + info!(%new_addr, "new external address report"); + } + if external_addresses.changed().await.is_err() { + break; + } + } + } + .instrument(span.clone()), + ); async { info!("established"); diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index e08a4fdc0c..b92efae116 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -14,7 +14,7 @@ use bytes::Bytes; use pin_project_lite::pin_project; use rustc_hash::FxHashMap; use thiserror::Error; -use tokio::sync::{futures::Notified, mpsc, oneshot, Notify}; +use tokio::sync::{futures::Notified, mpsc, oneshot, watch, Notify}; use tracing::{debug_span, Instrument, Span}; use crate::{ @@ -636,6 +636,12 @@ impl Connection { // May need to send MAX_STREAMS to make progress conn.wake(); } + + /// Track changed on our external address as reported by the peer. + pub fn observed_external_addr(&self) -> watch::Receiver> { + let conn = self.0.state.lock("external_addr"); + conn.observed_external_addr.subscribe() + } } pin_project! { @@ -892,6 +898,7 @@ impl ConnectionRef { runtime, send_buffer: Vec::new(), buffered_transmit: None, + observed_external_addr: watch::Sender::new(None), }), shared: Shared::default(), })) @@ -974,6 +981,8 @@ pub(crate) struct State { send_buffer: Vec, /// We buffer a transmit when the underlying I/O would block buffered_transmit: Option, + /// Our last external address reported by the peer. + observed_external_addr: watch::Sender>, } impl State { @@ -1131,6 +1140,12 @@ impl State { wake_stream(id, &mut self.stopped); wake_stream(id, &mut self.blocked_writers); } + ObservedAddr(observed) => { + self.observed_external_addr.send_if_modified(|addr| { + let old = addr.replace(observed); + old != *addr + }); + } } } }