diff --git a/Cargo.lock b/Cargo.lock index f815e1e..600d725 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,9 +43,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "anymap2" @@ -655,6 +655,12 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "http" version = "0.2.12" @@ -1101,6 +1107,18 @@ dependencies = [ "web-time", ] +[[package]] +name = "nested_enum_utils" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aa9a338d2f55df2c5f4ddd2789115e8a16ba9f363b2c551a8f9b0695d23bdc2" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "netdev" version = "0.31.0" @@ -1208,6 +1226,7 @@ dependencies = [ "js-sys", "libc", "n0-future", + "nested_enum_utils", "netdev", "netlink-packet-core", "netlink-packet-route 0.19.0", @@ -1215,9 +1234,9 @@ dependencies = [ "rtnetlink 0.13.1", "rtnetlink 0.14.1", "serde", + "snafu", "socket2", "testresult", - "thiserror 2.0.11", "time", "tokio", "tokio-util", @@ -1438,6 +1457,7 @@ dependencies = [ "igd-next", "iroh-metrics", "libc", + "nested_enum_utils", "netwatch", "ntest", "num_enum", @@ -1445,8 +1465,8 @@ dependencies = [ "rand_chacha", "serde", "smallvec", + "snafu", "socket2", - "thiserror 2.0.11", "time", "tokio", "tokio-util", @@ -1777,6 +1797,27 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "snafu" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "socket2" version = "0.5.8" diff --git a/netwatch/Cargo.toml b/netwatch/Cargo.toml index fec8af7..84fece8 100644 --- a/netwatch/Cargo.toml +++ b/netwatch/Cargo.toml @@ -19,7 +19,8 @@ workspace = true atomic-waker = "1.1.2" bytes = "1.7" n0-future = "0.1.1" -thiserror = "2" +nested_enum_utils = "0.2.0" +snafu = "0.8.5" time = "0.3.20" tokio = { version = "1", features = [ "io-util", @@ -60,6 +61,7 @@ netlink-packet-core = "0.7.0" netlink-packet-route = "0.19" # 0.20/21 is blocked on rtnetlink bumping its dependency netlink-sys = "0.8.6" rtnetlink = "=0.13.1" # pinned because of https://github.com/rust-netlink/rtnetlink/issues/83 +derive_more = { version = "1.0.0", features = ["display"] } [target.'cfg(target_os = "windows")'.dependencies] wmi = "0.14" @@ -68,7 +70,9 @@ windows-result = "0.3" serde = { version = "1", features = ["derive"] } derive_more = { version = "1.0.0", features = ["debug"] } +# wasm-in-browser dependencies [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies] +derive_more = { version = "1.0.0", features = ["display"] } js-sys = "0.3" web-sys = { version = "0.3.70", features = ["EventListener", "EventTarget"] } diff --git a/netwatch/src/interfaces/bsd.rs b/netwatch/src/interfaces/bsd.rs index 5826f9a..5097b86 100644 --- a/netwatch/src/interfaces/bsd.rs +++ b/netwatch/src/interfaces/bsd.rs @@ -13,6 +13,8 @@ use libc::{c_int, uintptr_t, AF_INET, AF_INET6, AF_LINK, AF_ROUTE, AF_UNSPEC, CT use libc::{ NET_RT_DUMP, RTAX_BRD, RTAX_DST, RTAX_GATEWAY, RTAX_MAX, RTAX_NETMASK, RTA_IFP, RTF_GATEWAY, }; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, IntoError, OptionExt, Snafu}; use tracing::warn; use super::DefaultRouteDetails; @@ -247,7 +249,7 @@ fn u16_from_ne_range( data.get(range) .and_then(|s| TryInto::<[u8; 2]>::try_into(s).ok()) .map(u16::from_ne_bytes) - .ok_or(RouteError::MessageTooShort) + .context(MessageTooShortSnafu) } /// Safely convert some bytes from a slice into a u32. @@ -258,7 +260,7 @@ fn u32_from_ne_range( data.get(range) .and_then(|s| TryInto::<[u8; 4]>::try_into(s).ok()) .map(u32::from_ne_bytes) - .ok_or(RouteError::MessageTooShort) + .context(MessageTooShortSnafu) } impl WireFormat { @@ -271,16 +273,12 @@ impl WireFormat { target_os = "ios" ))] MessageType::Route => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, ..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); let attrs: i32 = u32_from_ne_range(data, 12..16)? .try_into() - .map_err(|_| RouteError::InvalidMessage)?; + .map_err(|_| InvalidMessageSnafu.build())?; let addrs = parse_addrs(attrs, parse_kernel_inet_addr, &data[self.body_off..])?; let mut m = RouteMessage { version: data[2] as _, @@ -302,17 +300,11 @@ impl WireFormat { } #[cfg(target_os = "openbsd")] MessageType::Route => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, ..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); let ll = u16_from_ne_range(data, 4..6)? as usize; - if data.len() < ll { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= ll as usize, InvalidMessageSnafu); let addrs = parse_addrs( u32_from_ne_range(data, 12..16)? as _, @@ -339,13 +331,9 @@ impl WireFormat { Ok(Some(WireMessage::Route(m))) } MessageType::Interface => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, 0..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); let attrs = u32_from_ne_range(data, 4..8)?; if attrs as c_int & RTA_IFP == 0 { @@ -366,13 +354,9 @@ impl WireFormat { Ok(Some(WireMessage::Interface(m))) } MessageType::InterfaceAddr => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, ..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); #[cfg(target_os = "netbsd")] let index = u16_from_ne_range(data, 16..18)?; @@ -395,13 +379,10 @@ impl WireFormat { Ok(Some(WireMessage::InterfaceAddr(m))) } MessageType::InterfaceMulticastAddr => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, ..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); + let addrs = parse_addrs( u32_from_ne_range(data, 4..8)? as _, parse_kernel_inet_addr, @@ -417,13 +398,9 @@ impl WireFormat { Ok(Some(WireMessage::InterfaceMulticastAddr(m))) } MessageType::InterfaceAnnounce => { - if data.len() < self.body_off { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(data.len() >= self.body_off, MessageTooShortSnafu); let l = u16_from_ne_range(data, ..2)?; - if data.len() < l as usize { - return Err(RouteError::InvalidMessage); - } + snafu::ensure!(data.len() >= l as usize, InvalidMessageSnafu); let mut name = String::new(); for i in 0..16 { @@ -431,7 +408,7 @@ impl WireFormat { continue; } name = std::str::from_utf8(&data[6..6 + i]) - .map_err(|_| RouteError::InvalidAddress)? + .map_err(|_| InvalidAddressSnafu.build())? .to_string(); break; } @@ -469,9 +446,10 @@ struct RoutingStack { /// Parses b as a routing information base and returns a list of routing messages. pub fn parse_rib(typ: RIBType, data: &[u8]) -> Result, RouteError> { - if !is_valid_rib_type(typ) { - return Err(RouteError::InvalidRibType(typ)); - } + snafu::ensure!( + is_valid_rib_type(typ), + InvalidRibTypeSnafu { rib_type: typ } + ); let mut msgs = Vec::new(); let mut nmsgs = 0; @@ -481,12 +459,8 @@ pub fn parse_rib(typ: RIBType, data: &[u8]) -> Result, RouteErr while b.len() > 4 { nmsgs += 1; let l = u16_from_ne_range(b, ..2)?; - if l == 0 { - return Err(RouteError::InvalidMessage); - } - if b.len() < l as usize { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(l != 0, InvalidMessageSnafu); + snafu::ensure!(b.len() >= l as usize, MessageTooShortSnafu); if b[2] as i32 != ROUTING_STACK.rtm_version { // b = b[l:]; continue; @@ -511,9 +485,7 @@ pub fn parse_rib(typ: RIBType, data: &[u8]) -> Result, RouteErr } // We failed to parse any of the messages - version mismatch? - if nmsgs != msgs.len() + nskips { - return Err(RouteError::MessageMismatch); - } + snafu::ensure!(nmsgs == msgs.len() + nskips, MessageMismatchSnafu); Ok(msgs) } @@ -627,20 +599,27 @@ pub struct InterfaceAnnounceMessage { /// Represents a type of routing information base. type RIBType = i32; -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum RouteError { - #[error("message mismatch")] - MessageMismatch, - #[error("message too short")] - MessageTooShort, - #[error("invalid message")] - InvalidMessage, - #[error("invalid address")] - InvalidAddress, - #[error("invalid rib type: {0}")] - InvalidRibType(RIBType), - #[error("io error calling: '{0}': {1:?}")] - Io(&'static str, std::io::Error), + #[snafu(display("message mismatch"))] + MessageMismatch {}, + #[snafu(display("message too short"))] + MessageTooShort {}, + #[snafu(display("invalid message"))] + InvalidMessage {}, + #[snafu(display("invalid address"))] + InvalidAddress {}, + #[snafu(display("invalid rib type {rib_type}"))] + InvalidRibType { rib_type: RIBType }, + #[snafu(display("io error calling '{name}'"))] + Io { + source: std::io::Error, + name: &'static str, + }, } /// FetchRIB fetches a routing information base from the operating system. @@ -670,7 +649,7 @@ fn fetch_rib(af: i32, typ: RIBType, arg: i32) -> Result, RouteError> { ) }; if err != 0 { - return Err(RouteError::Io("sysctl", std::io::Error::last_os_error())); + return Err(IoSnafu { name: "sysctl" }.into_error(std::io::Error::last_os_error())); } if n == 0 { // nothing available @@ -696,7 +675,7 @@ fn fetch_rib(af: i32, typ: RIBType, arg: i32) -> Result, RouteError> { if io_err.raw_os_error().unwrap_or_default() == libc::ENOMEM && round < MAX_TRIES { continue; } - return Err(RouteError::Io("sysctl", io_err)); + return Err(IoSnafu { name: "sysctl" }.into_error(io_err)); } // Truncate b, to the new length b.truncate(n); @@ -789,9 +768,7 @@ where let a = parse_link_addr(b)?; addrs.push(a); let l = roundup(b[0] as usize); - if b.len() < l { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(b.len() >= l, MessageTooShortSnafu); b = &b[l..]; } AF_INET | AF_INET6 => { @@ -799,9 +776,7 @@ where let a = parse_inet_addr(af, b)?; addrs.push(a); let l = roundup(b[0] as usize); - if b.len() < l { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(b.len() >= l, MessageTooShortSnafu); b = &b[l..]; } _ => { @@ -819,9 +794,7 @@ where let a = parse_default_addr(b)?; addrs.push(a); let l = roundup(b[0] as usize); - if b.len() < l { - return Err(RouteError::MessageTooShort); - } + snafu::ensure!(b.len() >= l, MessageTooShortSnafu); b = &b[l..]; } } @@ -836,23 +809,19 @@ where fn parse_inet_addr(af: i32, b: &[u8]) -> Result { match af { AF_INET => { - if b.len() < SIZEOF_SOCKADDR_INET { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!(b.len() >= SIZEOF_SOCKADDR_INET, InvalidAddressSnafu); let ip = Ipv4Addr::new(b[4], b[5], b[6], b[7]); Ok(Addr::Inet4 { ip }) } AF_INET6 => { - if b.len() < SIZEOF_SOCKADDR_INET6 { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!(b.len() >= SIZEOF_SOCKADDR_INET6, InvalidAddressSnafu); let mut zone = u32_from_ne_range(b, 24..28)?; let mut oc: [u8; 16] = b .get(8..24) .and_then(|s| TryInto::<[u8; 16]>::try_into(s).ok()) - .ok_or(RouteError::InvalidMessage)?; + .context(InvalidMessageSnafu)?; if oc[0] == 0xfe && oc[1] & 0xc0 == 0x80 || oc[0] == 0xff && (oc[1] & 0x0f == 0x01 || oc[1] & 0x0f == 0x02) { @@ -865,7 +834,7 @@ fn parse_inet_addr(af: i32, b: &[u8]) -> Result { .get(2..4) .and_then(|s| TryInto::<[u8; 2]>::try_into(s).ok()) .map(u16::from_be_bytes) - .ok_or(RouteError::InvalidMessage)? as u32; + .context(InvalidMessageSnafu)? as u32; if id != 0 { zone = id; oc[2] = 0; @@ -877,7 +846,7 @@ fn parse_inet_addr(af: i32, b: &[u8]) -> Result { zone, }) } - _ => Err(RouteError::InvalidAddress), + _ => Err(InvalidAddressSnafu.build()), } } @@ -916,9 +885,7 @@ fn parse_kernel_inet_addr(af: i32, b: &[u8]) -> Result<(i32, Addr), RouteError> l = roundup(l); } - if b.len() < l { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!(b.len() >= l, InvalidAddressSnafu); // Don't reorder case expressions. // The case expressions for IPv6 must come first. const OFF4: usize = 4; // offset of in_addr @@ -928,7 +895,7 @@ fn parse_kernel_inet_addr(af: i32, b: &[u8]) -> Result<(i32, Addr), RouteError> let octets: [u8; 16] = b .get(OFF6..OFF6 + 16) .and_then(|s| TryInto::try_into(s).ok()) - .ok_or(RouteError::InvalidMessage)?; + .context(InvalidMessageSnafu)?; let ip = Ipv6Addr::from(octets); Addr::Inet6 { ip, zone: 0 } } else if af == AF_INET6 { @@ -944,7 +911,7 @@ fn parse_kernel_inet_addr(af: i32, b: &[u8]) -> Result<(i32, Addr), RouteError> let octets: [u8; 4] = b .get(OFF4..OFF4 + 4) .and_then(|s| TryInto::try_into(s).ok()) - .ok_or(RouteError::InvalidMessage)?; + .context(InvalidMessageSnafu)?; let ip = Ipv4Addr::from(octets); Addr::Inet4 { ip } } else { @@ -963,9 +930,7 @@ fn parse_kernel_inet_addr(af: i32, b: &[u8]) -> Result<(i32, Addr), RouteError> } fn parse_link_addr(b: &[u8]) -> Result { - if b.len() < 8 { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!(b.len() >= 8, InvalidAddressSnafu); let (_, mut a) = parse_kernel_link_addr(AF_LINK, &b[4..])?; if let Addr::Link { index, .. } = &mut a { @@ -1007,14 +972,12 @@ fn parse_kernel_link_addr(_: i32, b: &[u8]) -> Result<(usize, Addr), RouteError> } let l = 4 + nlen + alen + slen; - if b.len() < l { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!(b.len() >= l, InvalidAddressSnafu); let mut data = &b[4..]; let name = if nlen > 0 { let name = std::str::from_utf8(&data[..nlen]) - .map_err(|_| RouteError::InvalidAddress)? + .map_err(|_| InvalidAddressSnafu.build())? .to_string(); data = &data[nlen..]; Some(name) @@ -1038,9 +1001,10 @@ fn parse_kernel_link_addr(_: i32, b: &[u8]) -> Result<(usize, Addr), RouteError> } fn parse_default_addr(b: &[u8]) -> Result { - if b.len() < 2 || b.len() < b[0] as usize { - return Err(RouteError::InvalidAddress); - } + snafu::ensure!( + b.len() >= 2 && b.len() >= b[0] as usize, + InvalidAddressSnafu + ); Ok(Addr::Default { af: b[1] as _, raw: b[..b[0] as usize].to_vec().into_boxed_slice(), diff --git a/netwatch/src/interfaces/linux.rs b/netwatch/src/interfaces/linux.rs index 712e727..b68e83d 100644 --- a/netwatch/src/interfaces/linux.rs +++ b/netwatch/src/interfaces/linux.rs @@ -2,6 +2,8 @@ #[cfg(not(target_os = "android"))] use n0_future::TryStreamExt; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use tokio::{ fs::File, io::{AsyncBufReadExt, BufReader}, @@ -9,25 +11,29 @@ use tokio::{ use super::DefaultRouteDetails; -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("IO {0}")] - Io(#[from] std::io::Error), + #[snafu(display("IO"))] + Io { source: std::io::Error }, #[cfg(not(target_os = "android"))] - #[error("no netlink response")] - NoResponse, + #[snafu(display("no netlink response"))] + NoResponse {}, #[cfg(not(target_os = "android"))] - #[error("interface not found")] - InterfaceNotFound, - #[error("iface field is missing")] - MissingIfaceField, - #[error("destination field is missing")] - MissingDestinationField, - #[error("mask field is missing")] - MissingMaskField, + #[snafu(display("interface not found"))] + InterfaceNotFound {}, + #[snafu(display("iface field is missing"))] + MissingIfaceField {}, + #[snafu(display("destination field is missing"))] + MissingDestinationField {}, + #[snafu(display("mask field is missing"))] + MissingMaskField {}, #[cfg(not(target_os = "android"))] - #[error("netlink")] - Netlink(#[from] rtnetlink::Error), + #[snafu(display("netlink"))] + Netlink { source: rtnetlink::Error }, } pub async fn default_route() -> Option { @@ -49,7 +55,7 @@ const PROC_NET_ROUTE_PATH: &str = "/proc/net/route"; async fn default_route_proc() -> Result, Error> { const ZERO_ADDR: &str = "00000000"; - let file = File::open(PROC_NET_ROUTE_PATH).await?; + let file = File::open(PROC_NET_ROUTE_PATH).await.context(IoSnafu)?; // Explicitly set capacity, this is min(4096, DEFAULT_BUF_SIZE): // https://github.com/google/gvisor/issues/5732 @@ -65,14 +71,14 @@ async fn default_route_proc() -> Result, Error> { // read it all in one call. let reader = BufReader::with_capacity(8 * 1024, file); let mut lines_iter = reader.lines(); - while let Some(line) = lines_iter.next_line().await? { + while let Some(line) = lines_iter.next_line().await.context(IoSnafu)? { if !line.contains(ZERO_ADDR) { continue; } let mut fields = line.split_ascii_whitespace(); - let iface = fields.next().ok_or(Error::MissingIfaceField)?; - let destination = fields.next().ok_or(Error::MissingDestinationField)?; - let mask = fields.nth(5).ok_or(Error::MissingMaskField)?; + let iface = fields.next().context(MissingIfaceFieldSnafu)?; + let destination = fields.next().context(MissingDestinationFieldSnafu)?; + let mask = fields.nth(5).context(MissingMaskFieldSnafu)?; // if iface.starts_with("tailscale") || iface.starts_with("wg") { // continue; // } @@ -97,7 +103,8 @@ pub async fn default_route_android_ip_route() -> Result Option<&str> { async fn default_route_netlink() -> Result, Error> { use tracing::{info_span, Instrument}; - let (connection, handle, _receiver) = rtnetlink::new_connection()?; + let (connection, handle, _receiver) = rtnetlink::new_connection().context(IoSnafu)?; let task = tokio::spawn(connection.instrument(info_span!("rtnetlink.conn"))); let default = default_route_netlink_family(&handle, rtnetlink::IpVersion::V4).await?; @@ -151,7 +158,7 @@ async fn default_route_netlink_family( use netlink_packet_route::route::RouteAttribute; let mut routes = handle.route().get(family).execute(); - while let Some(route) = routes.try_next().await? { + while let Some(route) = routes.try_next().await.context(NetlinkSnafu)? { let route_attrs = route.attributes; if !route_attrs @@ -187,16 +194,21 @@ async fn default_route_netlink_family( #[cfg(not(target_os = "android"))] async fn iface_by_index(handle: &rtnetlink::Handle, index: u32) -> Result { use netlink_packet_route::link::LinkAttribute; + use snafu::OptionExt; let mut links = handle.link().get().match_index(index).execute(); - let msg = links.try_next().await?.ok_or(Error::NoResponse)?; + let msg = links + .try_next() + .await + .context(NetlinkSnafu)? + .context(NoResponseSnafu)?; for nla in msg.attributes { if let LinkAttribute::IfName(name) = nla { return Ok(name); } } - Err(Error::InterfaceNotFound) + Err(InterfaceNotFoundSnafu.build()) } #[cfg(test)] diff --git a/netwatch/src/interfaces/windows.rs b/netwatch/src/interfaces/windows.rs index e796824..8e14048 100644 --- a/netwatch/src/interfaces/windows.rs +++ b/netwatch/src/interfaces/windows.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; +use nested_enum_utils::common_fields; use serde::Deserialize; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use tracing::warn; use wmi::{query::FilterValue, COMLibrary, WMIConnection}; @@ -13,26 +15,32 @@ struct Win32_IP4RouteTable { Name: String, } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("IO {0}")] - Io(#[from] std::io::Error), - #[error("not route found")] - NoRoute, - #[error("WMI {0}")] - Wmi(#[from] wmi::WMIError), + #[allow(dead_code)] // not sure why we have this here? + #[snafu(display("IO"))] + Io { source: std::io::Error }, + #[snafu(display("not route found"))] + NoRoute {}, + #[snafu(display("WMI"))] + Wmi { source: wmi::WMIError }, } fn get_default_route() -> Result { - let com_con = COMLibrary::new()?; - let wmi_con = WMIConnection::new(com_con)?; + let com_con = COMLibrary::new().context(WmiSnafu)?; + let wmi_con = WMIConnection::new(com_con).context(WmiSnafu)?; let query: HashMap<_, _> = [("Destination".into(), FilterValue::Str("0.0.0.0"))].into(); let route: Win32_IP4RouteTable = wmi_con - .filtered_query(&query)? + .filtered_query(&query) + .context(WmiSnafu)? .drain(..) .next() - .ok_or(Error::NoRoute)?; + .context(NoRouteSnafu)?; Ok(DefaultRouteDetails { interface_name: route.Name, diff --git a/netwatch/src/netmon.rs b/netwatch/src/netmon.rs index ab7031b..246fe2a 100644 --- a/netwatch/src/netmon.rs +++ b/netwatch/src/netmon.rs @@ -4,6 +4,8 @@ use n0_future::{ boxed::BoxFuture, task::{self, AbortOnDropHandle}, }; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::sync::{mpsc, oneshot}; mod actor; @@ -35,30 +37,34 @@ pub struct Monitor { actor_tx: mpsc::Sender, } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("channel closed")] - ChannelClosed, - #[error("actor {0}")] - Actor(#[from] actor::Error), + #[snafu(display("channel closed"))] + ChannelClosed {}, + #[snafu(display("actor error"))] + Actor { source: actor::Error }, } impl From> for Error { fn from(_value: mpsc::error::SendError) -> Self { - Self::ChannelClosed + ChannelClosedSnafu.build() } } impl From for Error { fn from(_value: oneshot::error::RecvError) -> Self { - Self::ChannelClosed + ChannelClosedSnafu.build() } } impl Monitor { /// Create a new monitor. pub async fn new() -> Result { - let actor = Actor::new().await?; + let actor = Actor::new().await.context(ActorSnafu)?; let actor_tx = actor.subscribe(); let handle = task::spawn(async move { diff --git a/netwatch/src/netmon/android.rs b/netwatch/src/netmon/android.rs index c9b48b3..14189bf 100644 --- a/netwatch/src/netmon/android.rs +++ b/netwatch/src/netmon/android.rs @@ -2,10 +2,12 @@ use tokio::sync::mpsc; use super::actor::NetworkMessage; -#[derive(Debug, thiserror::Error)] -#[error("error")] +#[derive(Debug, derive_more::Display)] +#[display("error")] pub struct Error; +impl std::error::Error for Error {} + #[derive(Debug)] pub(super) struct RouteMonitor { _sender: mpsc::Sender, diff --git a/netwatch/src/netmon/bsd.rs b/netwatch/src/netmon/bsd.rs index 21b9676..4e09921 100644 --- a/netwatch/src/netmon/bsd.rs +++ b/netwatch/src/netmon/bsd.rs @@ -1,5 +1,6 @@ #[cfg(any(target_os = "macos", target_os = "ios"))] use libc::{RTAX_DST, RTAX_IFP}; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::{io::AsyncReadExt, sync::mpsc}; use tokio_util::task::AbortOnDropHandle; use tracing::{trace, warn}; @@ -14,10 +15,14 @@ pub(super) struct RouteMonitor { _handle: AbortOnDropHandle<()>, } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("IO {0}")] - Io(#[from] std::io::Error), + #[snafu(display("IO"))] + Io { + source: std::io::Error, + backtrace: Option, + }, } fn create_socket() -> std::io::Result { @@ -33,7 +38,7 @@ fn create_socket() -> std::io::Result { impl RouteMonitor { pub(super) fn new(sender: mpsc::Sender) -> Result { - let mut socket = create_socket()?; + let mut socket = create_socket().context(IoSnafu)?; let handle = tokio::task::spawn(async move { trace!("AF_ROUTE monitor started"); diff --git a/netwatch/src/netmon/linux.rs b/netwatch/src/netmon/linux.rs index d460cec..944a6c1 100644 --- a/netwatch/src/netmon/linux.rs +++ b/netwatch/src/netmon/linux.rs @@ -12,6 +12,7 @@ use netlink_packet_core::NetlinkPayload; use netlink_packet_route::{address, route, RouteNetlinkMessage}; use netlink_sys::{AsyncSocket, SocketAddr}; use rtnetlink::new_connection; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::{sync::mpsc, task::JoinHandle}; use tracing::{trace, warn}; @@ -31,10 +32,14 @@ impl Drop for RouteMonitor { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("IO {0}")] - Io(#[from] std::io::Error), + #[snafu(display("IO"))] + Io { + source: std::io::Error, + backtrace: Option, + }, } const fn nl_mgrp(group: u32) -> u32 { @@ -58,7 +63,7 @@ macro_rules! get_nla { impl RouteMonitor { pub(super) fn new(sender: mpsc::Sender) -> Result { - let (mut conn, mut _handle, mut messages) = new_connection()?; + let (mut conn, mut _handle, mut messages) = new_connection().context(IoSnafu)?; // Specify flags to listen on. let groups = nl_mgrp(RTNLGRP_IPV4_IFADDR) @@ -69,7 +74,10 @@ impl RouteMonitor { | nl_mgrp(RTNLGRP_IPV6_RULE); let addr = SocketAddr::new(0, groups); - conn.socket_mut().socket_mut().bind(&addr)?; + conn.socket_mut() + .socket_mut() + .bind(&addr) + .context(IoSnafu)?; let conn_handle = tokio::task::spawn(conn); diff --git a/netwatch/src/netmon/wasm_browser.rs b/netwatch/src/netmon/wasm_browser.rs index dc2edfe..86da37e 100644 --- a/netwatch/src/netmon/wasm_browser.rs +++ b/netwatch/src/netmon/wasm_browser.rs @@ -8,10 +8,12 @@ use web_sys::{EventListener, EventTarget}; use super::actor::NetworkMessage; -#[derive(Debug, thiserror::Error)] -#[error("error")] +#[derive(Debug, derive_more::Display)] +#[display("error")] pub struct Error; +impl std::error::Error for Error {} + #[derive(Debug)] pub(super) struct RouteMonitor { _listeners: Option, diff --git a/netwatch/src/netmon/windows.rs b/netwatch/src/netmon/windows.rs index b2d06c1..5703774 100644 --- a/netwatch/src/netmon/windows.rs +++ b/netwatch/src/netmon/windows.rs @@ -1,6 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use libc::c_void; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::sync::mpsc; use tracing::{trace, warn}; use windows::Win32::{ @@ -18,12 +20,16 @@ pub(super) struct RouteMonitor { cb_handler: CallbackHandler, } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { - #[error("IO {0}")] - Io(#[from] std::io::Error), - #[error("win32: {0}")] - Win32(#[from] windows_result::Error), + #[snafu(display("IO"))] + Io { source: std::io::Error }, + #[snafu(display("win32"))] + Win32 { source: windows_result::Error }, } impl RouteMonitor { @@ -114,7 +120,8 @@ impl CallbackHandler { false, // initial notification, &mut handle, ) - .ok()?; + .ok() + .context(Win32Snafu)?; } self.unicast_callbacks.insert(handle.0 as isize, cb); @@ -134,7 +141,8 @@ impl CallbackHandler { { unsafe { windows::Win32::NetworkManagement::IpHelper::CancelMibChangeNotify2(handle.0) - .ok()?; + .ok() + .context(Win32Snafu)?; } } @@ -156,7 +164,8 @@ impl CallbackHandler { false, // initial notification, &mut handle, ) - .ok()?; + .ok() + .context(Win32Snafu)?; } self.route_callbacks.insert(handle.0 as isize, cb); @@ -176,7 +185,8 @@ impl CallbackHandler { { unsafe { windows::Win32::NetworkManagement::IpHelper::CancelMibChangeNotify2(handle.0) - .ok()?; + .ok() + .context(Win32Snafu)?; } } diff --git a/portmapper/Cargo.toml b/portmapper/Cargo.toml index 6d1d4d4..d06f462 100644 --- a/portmapper/Cargo.toml +++ b/portmapper/Cargo.toml @@ -24,13 +24,14 @@ futures-util = "0.3.25" igd-next = { version = "0.15.1", features = ["aio_tokio"] } iroh-metrics = { version = "0.32", default-features = false } libc = "0.2.139" +nested_enum_utils = "0.2.0" netwatch = { version = "0.4.0", path = "../netwatch" } num_enum = "0.7" rand = "0.8" serde = { version = "1", features = ["derive", "rc"] } smallvec = "1.11.1" +snafu = { version = "0.8.5", features = ["rust_1_81"] } socket2 = "0.5.3" -thiserror = "2" time = "0.3.20" tokio = { version = "1", features = ["io-util", "macros", "sync", "rt", "net", "fs", "io-std", "signal", "process"] } tokio-util = { version = "0.7", features = ["io-util", "io", "codec", "rt"] } diff --git a/portmapper/src/lib.rs b/portmapper/src/lib.rs index 767f2c3..85c7a72 100644 --- a/portmapper/src/lib.rs +++ b/portmapper/src/lib.rs @@ -10,6 +10,7 @@ use current_mapping::CurrentMapping; use futures_lite::StreamExt; use iroh_metrics::inc; use netwatch::interfaces::HomeRouter; +use snafu::Snafu; use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::task::AbortOnDropHandle; use tracing::{debug, info_span, trace, Instrument}; @@ -66,29 +67,26 @@ impl ProbeOutput { } } -#[derive(Debug, thiserror::Error, Clone)] +// Cannot have backtrace due to Clone bound +// #[nested_enum_utils::common_fields({ +// backtrace: Option, +// })] +#[allow(missing_docs)] +#[derive(Debug, Clone, Snafu)] +#[non_exhaustive] pub enum ProbeError { - #[error("Mapping channel is full")] + #[snafu(display("Mapping channel is full"))] ChannelFull, - #[error("Mapping channel is closed")] + #[snafu(display("Mapping channel is closed"))] ChannelClosed, - #[error("No gateway found for probe")] + #[snafu(display("No gateway found for probe"))] NoGateway, - #[error("gateway found is ipv6, ignoring")] + #[snafu(display("gateway found is ipv6, ignoring"))] Ipv6Gateway, - #[error("Join is_panic: {is_panic}, is_cancelled: {is_cancelled}")] + #[snafu(display("Probe task stopped. is_panic: {is_panic}, is_cancelled: {is_cancelled}"))] Join { is_panic: bool, is_cancelled: bool }, } -impl From for ProbeError { - fn from(value: tokio::task::JoinError) -> Self { - Self::Join { - is_panic: value.is_panic(), - is_cancelled: value.is_cancelled(), - } - } -} - #[derive(derive_more::Debug)] enum Message { /// Attempt to get a mapping if the local port is set but there is no mapping. @@ -180,8 +178,8 @@ impl Client { // recover the sender and return the error there let (result_tx, e) = match e { - Full(Message::Probe { result_tx }) => (result_tx, ProbeError::ChannelFull), - Closed(Message::Probe { result_tx }) => (result_tx, ProbeError::ChannelClosed), + Full(Message::Probe { result_tx }) => (result_tx, ChannelFullSnafu.build()), + Closed(Message::Probe { result_tx }) => (result_tx, ChannelClosedSnafu.build()), Full(_) | Closed(_) => unreachable!("Sent value is a probe."), }; @@ -496,7 +494,7 @@ impl Service { trace!("tick: probe ready"); // retrieve the receivers and clear the task let receivers = self.probing_task.take().expect("is some").1; - let probe_result = probe_result.map_err(Into::into); + let probe_result = probe_result.map_err(|e| JoinSnafu { is_panic: e.is_panic(), is_cancelled: e.is_cancelled() }.build()); self.on_probe_result(probe_result, receivers); } Some(event) = self.current_mapping.next() => { @@ -714,7 +712,7 @@ impl Service { /// Gets the local ip and gateway address for port mapping. fn ip_and_gateway() -> Result<(Ipv4Addr, Ipv4Addr), ProbeError> { let Some(HomeRouter { gateway, my_ip }) = HomeRouter::new() else { - return Err(ProbeError::NoGateway); + return Err(NoGatewaySnafu.build()); }; let local_ip = match my_ip { @@ -730,7 +728,7 @@ fn ip_and_gateway() -> Result<(Ipv4Addr, Ipv4Addr), ProbeError> { }; let std::net::IpAddr::V4(gateway) = gateway else { - return Err(ProbeError::Ipv6Gateway); + return Err(Ipv6GatewaySnafu.build()); }; Ok((local_ip, gateway)) diff --git a/portmapper/src/mapping.rs b/portmapper/src/mapping.rs index 7ced912..e90c4e6 100644 --- a/portmapper/src/mapping.rs +++ b/portmapper/src/mapping.rs @@ -2,6 +2,9 @@ use std::{net::Ipv4Addr, num::NonZeroU16, time::Duration}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; + use super::{nat_pmp, pcp, upnp}; pub(super) trait PortMapped: std::fmt::Debug + Unpin { @@ -22,15 +25,19 @@ pub enum Mapping { } /// Mapping error. -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] #[non_exhaustive] pub enum Error { - #[error("PCP mapping failed: {0}")] - Pcp(#[from] pcp::Error), - #[error("NAT-PMP mapping failed: {0}")] - NatPmp(#[from] nat_pmp::Error), - #[error("UPnP mapping failed: {0}")] - Upnp(#[from] upnp::Error), + #[snafu(display("PCP mapping failed"))] + Pcp { source: pcp::Error }, + #[snafu(display("NAT-PMP mapping failed"))] + NatPmp { source: nat_pmp::Error }, + #[snafu(display("UPnP mapping failed"))] + Upnp { source: upnp::Error }, } impl Mapping { @@ -44,7 +51,7 @@ impl Mapping { pcp::Mapping::new(local_ip, local_port, gateway, external_addr) .await .map(Self::Pcp) - .map_err(Into::into) + .context(PcpSnafu) } /// Create a new NAT-PMP mapping. @@ -62,7 +69,7 @@ impl Mapping { ) .await .map(Self::NatPmp) - .map_err(Into::into) + .context(NatPmpSnafu) } /// Create a new UPnP mapping. @@ -75,15 +82,15 @@ impl Mapping { upnp::Mapping::new(local_ip, local_port, gateway, external_port) .await .map(Self::Upnp) - .map_err(Into::into) + .context(UpnpSnafu) } /// Release the mapping. pub(crate) async fn release(self) -> Result<(), Error> { match self { - Mapping::Upnp(m) => m.release().await?, - Mapping::Pcp(m) => m.release().await?, - Mapping::NatPmp(m) => m.release().await?, + Mapping::Upnp(m) => m.release().await.context(UpnpSnafu)?, + Mapping::Pcp(m) => m.release().await.context(PcpSnafu)?, + Mapping::NatPmp(m) => m.release().await.context(NatPmpSnafu)?, } Ok(()) } diff --git a/portmapper/src/nat_pmp.rs b/portmapper/src/nat_pmp.rs index 48813b5..a1dd6a2 100644 --- a/portmapper/src/nat_pmp.rs +++ b/portmapper/src/nat_pmp.rs @@ -2,7 +2,9 @@ use std::{net::Ipv4Addr, num::NonZeroU16, time::Duration}; +use nested_enum_utils::common_fields; use netwatch::UdpSocket; +use snafu::{Backtrace, Snafu}; use tracing::{debug, trace}; use self::protocol::{MapProtocol, Request, Response}; @@ -31,17 +33,21 @@ pub struct Mapping { lifetime_seconds: u32, } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] #[non_exhaustive] pub enum Error { - #[error("server returned unexpected response for mapping request")] - UnexpectedServerResponse, - #[error("received 0 port from server as external port")] - ZeroExternalPort, - #[error("IO: {0}")] - Io(#[from] std::io::Error), - #[error("Protocol: {0}")] - Protocol(#[from] protocol::Error), + #[snafu(display("server returned unexpected response for mapping request"))] + UnexpectedServerResponse {}, + #[snafu(display("received 0 port from server as external port"))] + ZeroExternalPort {}, + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(transparent)] + Protocol { source: protocol::Error }, } impl super::mapping::PortMapped for Mapping { @@ -92,12 +98,12 @@ impl Mapping { external_port, lifetime_seconds, } if private_port == Into::::into(local_port) => (external_port, lifetime_seconds), - _ => return Err(Error::UnexpectedServerResponse), + _ => return Err(UnexpectedServerResponseSnafu.build()), }; let external_port = external_port .try_into() - .map_err(|_| Error::ZeroExternalPort)?; + .map_err(|_| ZeroExternalPortSnafu.build())?; // now send the second request to get the external address let req = Request::ExternalAddress; @@ -117,7 +123,7 @@ impl Mapping { epoch_time: _, public_ip, } => public_ip, - _ => return Err(Error::UnexpectedServerResponse), + _ => return Err(UnexpectedServerResponseSnafu.build()), }; Ok(Mapping { diff --git a/portmapper/src/nat_pmp/protocol/response.rs b/portmapper/src/nat_pmp/protocol/response.rs index 940a465..4e0b8bf 100644 --- a/portmapper/src/nat_pmp/protocol/response.rs +++ b/portmapper/src/nat_pmp/protocol/response.rs @@ -2,7 +2,9 @@ use std::net::Ipv4Addr; +use nested_enum_utils::common_fields; use num_enum::{IntoPrimitive, TryFromPrimitive}; +use snafu::{Backtrace, Snafu}; use super::{MapProtocol, Opcode, Version}; @@ -51,38 +53,43 @@ pub enum ResultCode { } /// Errors that can occur when decoding a [`Response`] from a server. -#[derive(Debug, derive_more::Display, thiserror::Error, PartialEq, Eq)] +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum Error { /// Request is too short or is otherwise malformed. - #[display("Response is malformed")] - Malformed, + #[snafu(display("Response is malformed"))] + Malformed {}, /// The [`Response::RESPONSE_INDICATOR`] is not present. - #[display("Packet does not appear to be a response")] - NotAResponse, + #[snafu(display("Packet does not appear to be a response"))] + NotAResponse {}, /// The received opcode is not recognized. - #[display("Invalid Opcode received")] - InvalidOpcode, + #[snafu(display("Invalid Opcode received"))] + InvalidOpcode {}, /// The received version is not recognized. - #[display("Invalid version received")] - InvalidVersion, + #[snafu(display("Invalid version received"))] + InvalidVersion {}, /// The received result code is not recognized. - #[display("Invalid result code received")] - InvalidResultCode, + #[snafu(display("Invalid result code received"))] + InvalidResultCode {}, /// Received an error code indicating the server does not support the sent version. - #[display("Server does not support the version")] - UnsupportedVersion, + #[snafu(display("Server does not support the version"))] + UnsupportedVersion {}, /// Received an error code indicating the operation is supported but not authorized. - #[display("Operation is supported but not authorized")] - NotAuthorizedOrRefused, + #[snafu(display("Operation is supported but not authorized"))] + NotAuthorizedOrRefused {}, /// Received an error code indicating the server experienced a network failure - #[display("Server experienced a network failure")] - NetworkFailure, + #[snafu(display("Server experienced a network failure"))] + NetworkFailure {}, /// Received an error code indicating the server cannot create more mappings at this time. - #[display("Server is out of resources")] - OutOfResources, + #[snafu(display("Server is out of resources"))] + OutOfResources {}, /// Received an error code indicating the Opcode is not supported by the server. - #[display("Server does not support this opcode")] - UnsupportedOpcode, + #[snafu(display("Server does not support this opcode"))] + UnsupportedOpcode {}, } impl Response { @@ -110,30 +117,30 @@ impl Response { /// Decode a response. pub fn decode(buf: &[u8]) -> Result { if buf.len() < Self::MIN_SIZE || buf.len() > Self::MAX_SIZE { - return Err(Error::Malformed); + return Err(MalformedSnafu.build()); } - let _: Version = buf[0].try_into().map_err(|_| Error::InvalidVersion)?; + let _: Version = buf[0].try_into().map_err(|_| InvalidVersionSnafu.build())?; let opcode = buf[1]; if opcode & Self::RESPONSE_INDICATOR != Self::RESPONSE_INDICATOR { - return Err(Error::NotAResponse); + return Err(NotAResponseSnafu.build()); } let opcode: Opcode = (opcode & !Self::RESPONSE_INDICATOR) .try_into() - .map_err(|_| Error::InvalidOpcode)?; + .map_err(|_| InvalidOpcodeSnafu.build())?; let result_bytes = u16::from_be_bytes(buf[2..4].try_into().expect("slice has the right len")); let result_code = result_bytes .try_into() - .map_err(|_| Error::InvalidResultCode)?; + .map_err(|_| InvalidResultCodeSnafu.build())?; match result_code { ResultCode::Success => Ok(()), - ResultCode::UnsupportedVersion => Err(Error::UnsupportedVersion), - ResultCode::NotAuthorizedOrRefused => Err(Error::NotAuthorizedOrRefused), - ResultCode::NetworkFailure => Err(Error::NetworkFailure), - ResultCode::OutOfResources => Err(Error::OutOfResources), - ResultCode::UnsupportedOpcode => Err(Error::UnsupportedOpcode), + ResultCode::UnsupportedVersion => Err(UnsupportedVersionSnafu.build()), + ResultCode::NotAuthorizedOrRefused => Err(NotAuthorizedOrRefusedSnafu.build()), + ResultCode::NetworkFailure => Err(NetworkFailureSnafu.build()), + ResultCode::OutOfResources => Err(OutOfResourcesSnafu.build()), + ResultCode::UnsupportedOpcode => Err(UnsupportedOpcodeSnafu.build()), }?; let response = match opcode { @@ -273,7 +280,7 @@ mod tests { let response = Response::random(Opcode::DetermineExternalAddress, &mut gen); let encoded = response.encode(); - assert_eq!(Ok(response), Response::decode(&encoded)); + assert_eq!(response, Response::decode(&encoded).unwrap()); } #[test] @@ -282,6 +289,6 @@ mod tests { let response = Response::random(Opcode::MapUdp, &mut rng); let encoded = response.encode(); - assert_eq!(Ok(response), Response::decode(&encoded)); + assert_eq!(response, Response::decode(&encoded).unwrap()); } } diff --git a/portmapper/src/pcp.rs b/portmapper/src/pcp.rs index 5fb9d05..b4d1daf 100644 --- a/portmapper/src/pcp.rs +++ b/portmapper/src/pcp.rs @@ -2,8 +2,10 @@ use std::{net::Ipv4Addr, num::NonZeroU16, time::Duration}; +use nested_enum_utils::common_fields; use netwatch::UdpSocket; use rand::RngCore; +use snafu::{Backtrace, ResultExt, Snafu}; use tracing::{debug, trace}; use crate::defaults::PCP_RECV_TIMEOUT as RECV_TIMEOUT; @@ -34,27 +36,31 @@ pub struct Mapping { nonce: [u8; 12], } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] #[non_exhaustive] pub enum Error { - #[error("received nonce does not match sent request")] - NonceMissmatch, - #[error("received mapping is not for UDP")] - ProtocolMissmatch, - #[error("received mapping is for a local port that does not match the requested one")] - PortMissmatch, - #[error("received 0 external port for mapping")] - ZeroExternalPort, - #[error("received external address is not ipv4")] - NotIpv4, - #[error("received an announce response for a map request")] - InvalidAnnounce, - #[error("IO: {0}")] - Io(#[from] std::io::Error), - #[error("Protocol: {0}")] - Protocol(#[from] protocol::Error), - #[error("Protocol Decode: {0}")] - ProtocolDecode(#[from] protocol::DecodeError), + #[snafu(display("received nonce does not match sent request"))] + NonceMissmatch {}, + #[snafu(display("received mapping is not for UDP"))] + ProtocolMissmatch {}, + #[snafu(display( + "received mapping is for a local port that does not match the requested one" + ))] + PortMissmatch {}, + #[snafu(display("received 0 external port for mapping"))] + ZeroExternalPort {}, + #[snafu(display("received external address is not ipv4"))] + NotIpv4 {}, + #[snafu(display("received an announce response for a map request"))] + InvalidAnnounce {}, + #[snafu(display("IO error during PCP"))] + Io { source: std::io::Error }, + #[snafu(display("Protocol error during PCP"))] + Protocol { source: protocol::Error }, } impl super::mapping::PortMapped for Mapping { @@ -76,8 +82,10 @@ impl Mapping { preferred_external_address: Option<(Ipv4Addr, NonZeroU16)>, ) -> Result { // create the socket and send the request - let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT).into())?; + let socket = UdpSocket::bind_full((local_ip, 0)).context(IoSnafu)?; + socket + .connect((gateway, protocol::SERVER_PORT).into()) + .context(IoSnafu)?; let mut nonce = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce); @@ -96,7 +104,7 @@ impl Mapping { MAPPING_REQUESTED_LIFETIME_SECONDS, ); - socket.send(&req.encode()).await?; + socket.send(&req.encode()).await.context(IoSnafu)?; // wait for the response and decode it let mut buffer = vec![0; protocol::Response::MAX_SIZE]; @@ -104,8 +112,10 @@ impl Mapping { .await .map_err(|_| { std::io::Error::new(std::io::ErrorKind::TimedOut, "read timeout".to_string()) - })??; - let response = protocol::Response::decode(&buffer[..read])?; + }) + .context(IoSnafu)? + .context(IoSnafu)?; + let response = protocol::Response::decode(&buffer[..read]).context(ProtocolSnafu)?; // verify that the response is correct and matches the request let protocol::Response { @@ -125,22 +135,24 @@ impl Mapping { } = map_data; if nonce != received_nonce { - return Err(Error::NonceMissmatch); + return Err(NonceMissmatchSnafu.build()); } if protocol != protocol::MapProtocol::Udp { - return Err(Error::ProtocolMissmatch); + return Err(ProtocolMissmatchSnafu.build()); } let sent_port: u16 = local_port.into(); if received_local_port != sent_port { - return Err(Error::PortMissmatch); + return Err(PortMissmatchSnafu.build()); } let external_port = external_port .try_into() - .map_err(|_| Error::ZeroExternalPort)?; + .map_err(|_| ZeroExternalPortSnafu.build())?; - let external_address = external_address.to_ipv4_mapped().ok_or(Error::NotIpv4)?; + let external_address = external_address + .to_ipv4_mapped() + .ok_or(NotIpv4Snafu.build())?; Ok(Mapping { external_port, @@ -152,7 +164,7 @@ impl Mapping { gateway, }) } - protocol::OpcodeData::Announce => Err(Error::InvalidAnnounce), + protocol::OpcodeData::Announce => Err(InvalidAnnounceSnafu.build()), } } @@ -166,13 +178,15 @@ impl Mapping { } = self; // create the socket and send the request - let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT).into())?; + let socket = UdpSocket::bind_full((local_ip, 0)).context(IoSnafu)?; + socket + .connect((gateway, protocol::SERVER_PORT).into()) + .context(IoSnafu)?; let local_port = local_port.into(); let req = protocol::Request::mapping(nonce, local_port, local_ip, None, None, 0); - socket.send(&req.encode()).await?; + socket.send(&req.encode()).await.context(IoSnafu)?; // mapping deletion is a notification, no point in waiting for the response Ok(()) @@ -210,19 +224,21 @@ async fn probe_available_fallible( gateway: Ipv4Addr, ) -> Result { // create the socket and send the request - let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT).into())?; + let socket = UdpSocket::bind_full((local_ip, 0)).context(IoSnafu)?; + socket + .connect((gateway, protocol::SERVER_PORT).into()) + .context(IoSnafu)?; let req = protocol::Request::announce(local_ip.to_ipv6_mapped()); - socket.send(&req.encode()).await?; + socket.send(&req.encode()).await.context(IoSnafu)?; // wait for the response and decode it let mut buffer = vec![0; protocol::Response::MAX_SIZE]; let read = tokio::time::timeout(RECV_TIMEOUT, socket.recv(&mut buffer)) .await - .map_err(|_| { - std::io::Error::new(std::io::ErrorKind::TimedOut, "read timeout".to_string()) - })??; - let response = protocol::Response::decode(&buffer[..read])?; + .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "read timeout".to_string())) + .context(IoSnafu)? + .context(IoSnafu)?; + let response = protocol::Response::decode(&buffer[..read]).context(ProtocolSnafu)?; Ok(response) } diff --git a/portmapper/src/pcp/protocol/response.rs b/portmapper/src/pcp/protocol/response.rs index 6469b22..a0236be 100644 --- a/portmapper/src/pcp/protocol/response.rs +++ b/portmapper/src/pcp/protocol/response.rs @@ -1,7 +1,9 @@ //! A PCP response encoding and decoding. use derive_more::Display; +use nested_enum_utils::common_fields; use num_enum::{IntoPrimitive, TryFromPrimitive, TryFromPrimitiveError}; +use snafu::{Backtrace, Snafu}; use super::{opcode_data::OpcodeData, Opcode, Version}; @@ -18,9 +20,7 @@ pub enum SuccessCode { /// /// Refer to [RFC 6887 Result Codes](https://datatracker.ietf.org/doc/html/rfc6887#section-7.4) // NOTE: docs for each variant are largely adapted from the RFC's description of each code. -#[derive( - Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive, IntoPrimitive, Display, thiserror::Error, -)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive, IntoPrimitive, Display)] #[repr(u8)] pub enum ErrorCode { /// The version number at the start of the PCP Request header is not recognized by the PCP @@ -72,6 +72,8 @@ pub enum ErrorCode { ExcessiveRemotePeers = 13, } +impl std::error::Error for ErrorCode {} + /// Result code of a PCP response. #[derive(Debug)] pub enum ResultCode { @@ -120,32 +122,39 @@ pub struct Response { } /// Errors that can occur when decoding a [`Response`] from a server. -#[derive(Debug, derive_more::Display, thiserror::Error, PartialEq, Eq)] +#[allow(missing_docs)] +#[non_exhaustive] +#[common_fields({ + backtrace: Option, +})] +#[derive(Debug, Snafu)] pub enum DecodeError { /// Request is too short or is otherwise malformed. - #[display("Response is malformed")] - Malformed, + #[snafu(display("Response is malformed"))] + Malformed {}, /// The [`Response::RESPONSE_INDICATOR`] is not present. - #[display("Packet does not appear to be a response")] - NotAResponse, + #[snafu(display("Packet does not appear to be a response"))] + NotAResponse {}, /// The received opcode is not recognized. - #[display("Invalid Opcode received")] - InvalidOpcode, + #[snafu(display("Invalid Opcode received"))] + InvalidOpcode {}, /// The received version is not recognized. - #[display("Invalid version received")] - InvalidVersion, + #[snafu(display("Invalid version received"))] + InvalidVersion {}, /// The received result code is not recognized. - #[display("Invalid result code received")] - InvalidResultCode, + #[snafu(display("Invalid result code received"))] + InvalidResultCode {}, /// The received opcode data could not be decoded. - #[display("Invalid opcode data received")] - InvalidOpcodeData, + #[snafu(display("Invalid opcode data received"))] + InvalidOpcodeData {}, } -#[derive(Debug, derive_more::Display, thiserror::Error, PartialEq, Eq)] +#[derive(Debug, Snafu)] pub enum Error { - DecodeError(DecodeError), - ErrorCode(ErrorCode), + #[snafu(transparent)] + DecodeError { source: DecodeError }, + #[snafu(transparent)] + ErrorCode { source: ErrorCode }, } impl Response { @@ -168,31 +177,31 @@ impl Response { /// Decode a response. pub fn decode(buf: &[u8]) -> Result { - if buf.len() < Self::MIN_SIZE || buf.len() > Self::MAX_SIZE { - return Err(Error::DecodeError(DecodeError::Malformed)); - } + snafu::ensure!( + Self::MIN_SIZE <= buf.len() && buf.len() <= Self::MAX_SIZE, + MalformedSnafu + ); - let _version: Version = buf[0] - .try_into() - .map_err(|_| Error::DecodeError(DecodeError::InvalidVersion))?; + let _version: Version = buf[0].try_into().map_err(|_| InvalidVersionSnafu.build())?; let opcode = buf[1]; - if opcode & Self::RESPONSE_INDICATOR != Self::RESPONSE_INDICATOR { - return Err(Error::DecodeError(DecodeError::NotAResponse)); - } + snafu::ensure!( + opcode & Self::RESPONSE_INDICATOR == Self::RESPONSE_INDICATOR, + NotAResponseSnafu + ); let opcode: Opcode = (opcode & !Self::RESPONSE_INDICATOR) .try_into() - .map_err(|_| Error::DecodeError(DecodeError::InvalidOpcode))?; + .map_err(|_| InvalidOpcodeSnafu.build())?; // buf[2] reserved // return early if the result code is an error let result_code: ResultCode = buf[3] .try_into() - .map_err(|_| Error::DecodeError(DecodeError::InvalidResultCode))?; + .map_err(|_| InvalidResultCodeSnafu.build())?; match result_code { ResultCode::Success => {} - ResultCode::Error(error_code) => return Err(Error::ErrorCode(error_code)), + ResultCode::Error(error_code) => return Err(error_code.into()), } let lifetime_bytes = buf[4..8].try_into().expect("slice has the right len"); @@ -203,8 +212,8 @@ impl Response { // buf[12..24] reserved - let data = OpcodeData::decode(opcode, &buf[24..]) - .map_err(|_| Error::DecodeError(DecodeError::InvalidOpcodeData))?; + let data = + OpcodeData::decode(opcode, &buf[24..]).map_err(|_| InvalidOpcodeDataSnafu.build())?; Ok(Response { lifetime_seconds, @@ -269,7 +278,7 @@ mod tests { let response = Response::random(Opcode::Announce, &mut gen); let encoded = response.encode(); - assert_eq!(Ok(response), Response::decode(&encoded)); + assert_eq!(response, Response::decode(&encoded).unwrap()); } #[test] @@ -290,6 +299,6 @@ mod tests { let response = Response::random(Opcode::Map, &mut rng); let encoded = response.encode(); - assert_eq!(Ok(response), Response::decode(&encoded)); + assert_eq!(response, Response::decode(&encoded).unwrap()); } } diff --git a/portmapper/src/upnp.rs b/portmapper/src/upnp.rs index 9c6a87e..dd8bc3d 100644 --- a/portmapper/src/upnp.rs +++ b/portmapper/src/upnp.rs @@ -6,6 +6,8 @@ use std::{ use igd_next::{aio as aigd, AddAnyPortError, GetExternalIpError, RemovePortError, SearchError}; use iroh_metrics::inc; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; use tracing::debug; use super::Metrics; @@ -35,23 +37,27 @@ pub struct Mapping { external_port: NonZeroU16, } -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] #[non_exhaustive] pub enum Error { - #[error("Zero external port")] - ZeroExternalPort, - #[error("igd device's external ip is ipv6")] - NotIpv4, - #[error("Remove Port {0}")] - RemovePort(#[from] RemovePortError), - #[error("Search {0}")] - Search(#[from] SearchError), - #[error("Get external IP {0}")] - GetExternalIp(#[from] GetExternalIpError), - #[error("Add any port {0}")] - AddAnyPort(#[from] AddAnyPortError), - #[error("IO {0}")] - Io(#[from] std::io::Error), + #[snafu(display("Zero external port"))] + ZeroExternalPort {}, + #[snafu(display("igd device's external ip is ipv6"))] + NotIpv4 {}, + #[snafu(display("Remove Port"))] + RemovePort { source: RemovePortError }, + #[snafu(display("Search"))] + Search { source: SearchError }, + #[snafu(display("Get external IP"))] + GetExternalIp { source: GetExternalIpError }, + #[snafu(display("Add any port"))] + AddAnyPort { source: AddAnyPortError }, + #[snafu(display("IO"))] + Io { source: std::io::Error }, } impl Mapping { @@ -78,11 +84,17 @@ impl Mapping { .await .map_err(|_| { std::io::Error::new(std::io::ErrorKind::TimedOut, "read timeout".to_string()) - })?? + }) + .context(IoSnafu)? + .context(SearchSnafu)? }; - let std::net::IpAddr::V4(external_ip) = gateway.get_external_ip().await? else { - return Err(Error::NotIpv4); + let std::net::IpAddr::V4(external_ip) = gateway + .get_external_ip() + .await + .context(GetExternalIpSnafu)? + else { + return Err(NotIpv4Snafu.build()); }; // if we are trying to get a specific external port, try this first. If this fails, default @@ -114,9 +126,10 @@ impl Mapping { PORT_MAPPING_LEASE_DURATION_SECONDS, PORT_MAPPING_DESCRIPTION, ) - .await? + .await + .context(AddAnyPortSnafu)? .try_into() - .map_err(|_| Error::ZeroExternalPort)?; + .map_err(|_| ZeroExternalPortSnafu.build())?; Ok(Mapping { gateway, @@ -138,7 +151,8 @@ impl Mapping { } = self; gateway .remove_port(igd_next::PortMappingProtocol::UDP, external_port.into()) - .await?; + .await + .context(RemovePortSnafu)?; Ok(()) }