Skip to content

Commit 3df0eab

Browse files
RUST-229 Parse IPv6 addresses in the connection string (#1242)
1 parent 777b98c commit 3df0eab

File tree

2 files changed

+143
-119
lines changed

2 files changed

+143
-119
lines changed

src/client/options.rs

Lines changed: 98 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::{
1111
convert::TryFrom,
1212
fmt::{self, Display, Formatter, Write},
1313
hash::{Hash, Hasher},
14+
net::Ipv6Addr,
1415
path::PathBuf,
1516
str::FromStr,
1617
time::Duration,
@@ -128,9 +129,29 @@ impl<'de> Deserialize<'de> for ServerAddress {
128129
where
129130
D: Deserializer<'de>,
130131
{
131-
let s: String = Deserialize::deserialize(deserializer)?;
132-
Self::parse(s.as_str())
133-
.map_err(|e| <D::Error as serde::de::Error>::custom(format!("{}", e)))
132+
#[derive(Deserialize)]
133+
#[serde(untagged)]
134+
enum ServerAddressHelper {
135+
String(String),
136+
Object { host: String, port: Option<u16> },
137+
}
138+
139+
let helper = ServerAddressHelper::deserialize(deserializer)?;
140+
match helper {
141+
ServerAddressHelper::String(string) => {
142+
Self::parse(string).map_err(serde::de::Error::custom)
143+
}
144+
ServerAddressHelper::Object { host, port } => {
145+
#[cfg(unix)]
146+
if host.ends_with("sock") {
147+
return Ok(Self::Unix {
148+
path: PathBuf::from(host),
149+
});
150+
}
151+
152+
Ok(Self::Tcp { host, port })
153+
}
154+
}
134155
}
135156
}
136157

@@ -185,74 +206,95 @@ impl FromStr for ServerAddress {
185206
}
186207

187208
impl ServerAddress {
188-
/// Parses an address string into a `ServerAddress`.
209+
/// Parses an address string into a [`ServerAddress`].
189210
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
190211
let address = address.as_ref();
191-
// checks if the address is a unix domain socket
192-
#[cfg(unix)]
193-
{
194-
if address.ends_with(".sock") {
195-
return Ok(ServerAddress::Unix {
212+
213+
if address.ends_with(".sock") {
214+
#[cfg(unix)]
215+
{
216+
let address = percent_decode(address, "unix domain sockets must be URL-encoded")?;
217+
return Ok(Self::Unix {
196218
path: PathBuf::from(address),
197219
});
198220
}
221+
#[cfg(not(unix))]
222+
return Err(ErrorKind::InvalidArgument {
223+
message: "unix domain sockets are not supported on this platform".to_string(),
224+
}
225+
.into());
199226
}
200-
let mut parts = address.split(':');
201-
let hostname = match parts.next() {
202-
Some(part) => {
203-
if part.is_empty() {
204-
return Err(ErrorKind::InvalidArgument {
205-
message: format!(
206-
"invalid server address: \"{}\"; hostname cannot be empty",
207-
address
208-
),
209-
}
210-
.into());
227+
228+
let (hostname, port) = if let Some(ip_literal) = address.strip_prefix("[") {
229+
let Some((hostname, port)) = ip_literal.split_once("]") else {
230+
return Err(ErrorKind::InvalidArgument {
231+
message: format!(
232+
"invalid server address {}: missing closing ']' in IP literal hostname",
233+
address
234+
),
211235
}
212-
part
213-
}
214-
None => {
236+
.into());
237+
};
238+
239+
if let Err(parse_error) = Ipv6Addr::from_str(hostname) {
215240
return Err(ErrorKind::InvalidArgument {
216-
message: format!("invalid server address: \"{}\"", address),
241+
message: format!("invalid server address {}: {}", address, parse_error),
217242
}
218-
.into())
243+
.into());
219244
}
220-
};
221245

222-
let port = match parts.next() {
223-
Some(part) => {
224-
let port = u16::from_str(part).map_err(|_| ErrorKind::InvalidArgument {
246+
let port = if port.is_empty() {
247+
None
248+
} else if let Some(port) = port.strip_prefix(":") {
249+
Some(port)
250+
} else {
251+
return Err(ErrorKind::InvalidArgument {
225252
message: format!(
226-
"port must be valid 16-bit unsigned integer, instead got: {}",
227-
part
253+
"invalid server address {}: the hostname can only be followed by a port \
254+
prefixed with ':', got {}",
255+
address, port
228256
),
229-
})?;
230-
231-
if port == 0 {
232-
return Err(ErrorKind::InvalidArgument {
233-
message: format!(
234-
"invalid server address: \"{}\"; port must be non-zero",
235-
address
236-
),
237-
}
238-
.into());
239257
}
240-
if parts.next().is_some() {
258+
.into());
259+
};
260+
261+
(hostname, port)
262+
} else {
263+
match address.split_once(":") {
264+
Some((hostname, port)) => (hostname, Some(port)),
265+
None => (address, None),
266+
}
267+
};
268+
269+
if hostname.is_empty() {
270+
return Err(ErrorKind::InvalidArgument {
271+
message: format!(
272+
"invalid server address {}: the hostname cannot be empty",
273+
address
274+
),
275+
}
276+
.into());
277+
}
278+
279+
let port = if let Some(port) = port {
280+
match u16::from_str(port) {
281+
Ok(0) | Err(_) => {
241282
return Err(ErrorKind::InvalidArgument {
242283
message: format!(
243-
"address \"{}\" contains more than one unescaped ':'",
244-
address
284+
"invalid server address {}: the port must be an integer between 1 and \
285+
65535, got {}",
286+
address, port
245287
),
246288
}
247-
.into());
289+
.into())
248290
}
249-
250-
Some(port)
291+
Ok(port) => Some(port),
251292
}
252-
None => None,
293+
} else {
294+
None
253295
};
254296

255-
Ok(ServerAddress::Tcp {
297+
Ok(Self::Tcp {
256298
host: hostname.to_lowercase(),
257299
port,
258300
})
@@ -1165,6 +1207,7 @@ impl ClientOptions {
11651207
.iter()
11661208
.filter_map(|addr| match addr {
11671209
ServerAddress::Tcp { host, .. } => Some(host.to_ascii_lowercase()),
1210+
#[cfg(unix)]
11681211
_ => None,
11691212
})
11701213
.collect()
@@ -1440,31 +1483,15 @@ impl ConnectionString {
14401483
None => (None, None),
14411484
};
14421485

1443-
let mut host_list = Vec::with_capacity(hosts_section.len());
1444-
for host in hosts_section.split(',') {
1445-
let address = if host.ends_with(".sock") {
1446-
#[cfg(unix)]
1447-
{
1448-
ServerAddress::parse(percent_decode(
1449-
host,
1450-
"Unix domain sockets must be URL-encoded",
1451-
)?)
1452-
}
1453-
#[cfg(not(unix))]
1454-
return Err(ErrorKind::InvalidArgument {
1455-
message: "Unix domain sockets are not supported on this platform".to_string(),
1456-
}
1457-
.into());
1458-
} else {
1459-
ServerAddress::parse(host)
1460-
}?;
1461-
host_list.push(address);
1462-
}
1486+
let hosts = hosts_section
1487+
.split(',')
1488+
.map(ServerAddress::parse)
1489+
.collect::<Result<Vec<ServerAddress>>>()?;
14631490

14641491
let host_info = if !srv {
1465-
HostInfo::HostIdentifiers(host_list)
1492+
HostInfo::HostIdentifiers(hosts)
14661493
} else {
1467-
match &host_list[..] {
1494+
match &hosts[..] {
14681495
[ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()),
14691496
[ServerAddress::Tcp {
14701497
host: _,

src/client/options/test.rs

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
bson::{Bson, Document},
1010
bson_util::get_int,
1111
client::options::{ClientOptions, ConnectionString, ServerAddress},
12-
error::{Error, ErrorKind, Result},
12+
error::ErrorKind,
1313
test::spec::deserialize_spec_tests,
1414
Client,
1515
};
@@ -22,13 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
2222
"maxPoolSize=0 does not error",
2323
// TODO RUST-226: unskip this test
2424
"Valid tlsCertificateKeyFilePassword is parsed correctly",
25-
// TODO RUST-229: unskip the following tests
26-
"Single IP literal host without port",
27-
"Single IP literal host with port",
28-
"Multiple hosts (mixed formats)",
29-
"User info for single IP literal host without database",
30-
"User info for single IP literal host with database",
31-
"User info for multiple hosts with database",
3225
];
3326

3427
// TODO RUST-1896: unskip this test when openssl-tls is enabled
@@ -65,43 +58,11 @@ struct TestCase {
6558
uri: String,
6659
valid: bool,
6760
warning: Option<bool>,
68-
hosts: Option<Vec<TestServerAddress>>,
61+
hosts: Option<Vec<ServerAddress>>,
6962
auth: Option<TestAuth>,
7063
options: Option<Document>,
7164
}
7265

73-
// The connection string tests' representation of a server address. We use this indirection to avoid
74-
// deserialization failures when the tests specify an IPv6 address.
75-
//
76-
// TODO RUST-229: remove this struct and deserialize directly into ServerAddress
77-
#[derive(Debug, Deserialize)]
78-
struct TestServerAddress {
79-
#[serde(rename = "type")]
80-
host_type: String,
81-
host: String,
82-
port: Option<u16>,
83-
}
84-
85-
impl TryFrom<&TestServerAddress> for ServerAddress {
86-
type Error = Error;
87-
88-
fn try_from(test_server_address: &TestServerAddress) -> Result<Self> {
89-
if test_server_address.host_type.as_str() == "ip_literal" {
90-
return Err(ErrorKind::Internal {
91-
message: "test using ip_literal host type should be skipped".to_string(),
92-
}
93-
.into());
94-
}
95-
96-
let mut address = Self::parse(&test_server_address.host)?;
97-
if let ServerAddress::Tcp { ref mut port, .. } = address {
98-
*port = test_server_address.port;
99-
}
100-
101-
Ok(address)
102-
}
103-
}
104-
10566
#[derive(Debug, Deserialize)]
10667
#[serde(rename_all = "camelCase", deny_unknown_fields)]
10768
struct TestAuth {
@@ -138,14 +99,8 @@ async fn run_tests(path: &[&str], skipped_files: &[&str]) {
13899
let client_options = client_options_result.expect(&test_case.description);
139100

140101
if let Some(ref expected_hosts) = test_case.hosts {
141-
let expected_hosts = expected_hosts
142-
.iter()
143-
.map(TryFrom::try_from)
144-
.collect::<Result<Vec<ServerAddress>>>()
145-
.expect(&test_case.description);
146-
147102
assert_eq!(
148-
client_options.hosts, expected_hosts,
103+
&client_options.hosts, expected_hosts,
149104
"{}",
150105
test_case.description
151106
);
@@ -364,3 +319,45 @@ async fn options_enforce_min_heartbeat_frequency() {
364319

365320
Client::with_options(options).unwrap_err();
366321
}
322+
323+
#[test]
324+
fn invalid_ipv6() {
325+
// invalid hostname for ipv6
326+
let address = "[localhost]:27017";
327+
let error = ServerAddress::parse(address).unwrap_err();
328+
let message = error.message().unwrap();
329+
assert!(message.contains("invalid IPv6 address syntax"), "{message}");
330+
331+
// invalid character after hostname
332+
let address = "[::1]a";
333+
let error = ServerAddress::parse(address).unwrap_err();
334+
let message = error.message().unwrap();
335+
assert!(
336+
message.contains("the hostname can only be followed by a port"),
337+
"{message}"
338+
);
339+
340+
// missing bracket
341+
let address = "[::1:27017";
342+
let error = ServerAddress::parse(address).unwrap_err();
343+
let message = error.message().unwrap();
344+
assert!(message.contains("missing closing ']'"), "{message}");
345+
346+
// extraneous bracket
347+
let address = "[::1]:27017]";
348+
let error = ServerAddress::parse(address).unwrap_err();
349+
let message = error.message().unwrap();
350+
assert!(message.contains("the port must be an integer"), "{message}");
351+
}
352+
353+
#[cfg(not(unix))]
354+
#[test]
355+
fn unix_domain_socket_not_allowed() {
356+
let address = "address.sock";
357+
let error = ServerAddress::parse(address).unwrap_err();
358+
let message = error.message().unwrap();
359+
assert!(
360+
message.contains("not supported on this platform"),
361+
"{message}"
362+
);
363+
}

0 commit comments

Comments
 (0)