Skip to content

Commit 2137a06

Browse files
committed
Abstract some code
Signed-off-by: Tomasz Pietrek <[email protected]>
1 parent 84c6fa2 commit 2137a06

File tree

4 files changed

+66
-232
lines changed

4 files changed

+66
-232
lines changed

.config/nats.dic

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,4 @@ publish_message
196196
untagged
197197
deserialization
198198
serde's
199+
ServerInfo

async-nats/src/connector.rs

Lines changed: 25 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ pub(crate) struct ConnectorOptions {
6969
pub(crate) max_reconnects: Option<usize>,
7070
pub(crate) server_info_callback: Option<CallbackArg1<ServerInfo, ()>>,
7171
pub(crate) reconnect_server_callback:
72-
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo), ServerAddr>>,
72+
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo, usize), ServerAddr>>,
7373
}
7474

7575
/// Maintains a list of servers and establishes connections.
@@ -182,7 +182,9 @@ impl Connector {
182182
// Call the user's callback to select the server for this attempt
183183
let server_addr = if let Some(ref callback) = self.options.reconnect_server_callback
184184
{
185-
callback.call((server_addrs, server_info)).await
185+
callback
186+
.call((server_addrs, server_info, self.attempts))
187+
.await
186188
} else {
187189
// This shouldn't happen due to the outer check, but provide a fallback
188190
server_addrs
@@ -256,196 +258,28 @@ impl Connector {
256258

257259
sleep(duration).await;
258260

259-
let socket_addrs = server_addr
260-
.socket_addrs()
261-
.await
262-
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
263-
for socket_addr in socket_addrs {
264-
match self
265-
.try_connect_to(
266-
&socket_addr,
267-
server_addr.tls_required(),
268-
server_addr.clone(),
269-
)
270-
.await
271-
{
272-
Ok((server_info, mut connection)) => {
273-
// Call server_info_callback if provided
274-
if let Some(ref callback) = self.options.server_info_callback {
275-
callback.call(server_info.clone()).await;
276-
}
277-
278-
if !self.options.ignore_discovered_servers {
279-
for url in &server_info.connect_urls {
280-
let server_addr = url.parse::<ServerAddr>().map_err(|err| {
281-
ConnectError::with_source(
282-
crate::ConnectErrorKind::ServerParse,
283-
err,
284-
)
285-
})?;
286-
if !self.servers.iter().any(|(addr, _)| addr == &server_addr) {
287-
tracing::debug!(
288-
discovered_url = %url,
289-
"adding discovered server"
290-
);
291-
self.servers.push((server_addr, 0));
292-
}
293-
}
294-
}
295-
296-
let tls_required = self.options.tls_required || server_addr.tls_required();
297-
let mut connect_info = ConnectInfo {
298-
tls_required,
299-
name: self.options.name.clone(),
300-
pedantic: false,
301-
verbose: false,
302-
lang: LANG.to_string(),
303-
version: VERSION.to_string(),
304-
protocol: Protocol::Dynamic,
305-
user: self.options.auth.username.to_owned(),
306-
pass: self.options.auth.password.to_owned(),
307-
auth_token: self.options.auth.token.to_owned(),
308-
user_jwt: None,
309-
nkey: None,
310-
signature: None,
311-
echo: !self.options.no_echo,
312-
headers: true,
313-
no_responders: true,
314-
};
315-
316-
if let Some(nkey) = self.options.auth.nkey.as_ref() {
317-
match nkeys::KeyPair::from_seed(nkey.as_str()) {
318-
Ok(key_pair) => {
319-
let nonce = server_info.nonce.clone();
320-
match key_pair.sign(nonce.as_bytes()) {
321-
Ok(signed) => {
322-
connect_info.nkey = Some(key_pair.public_key());
323-
connect_info.signature =
324-
Some(URL_SAFE_NO_PAD.encode(signed));
325-
}
326-
Err(_) => {
327-
tracing::error!("failed to sign nonce with nkey");
328-
return Err(ConnectError::new(
329-
crate::ConnectErrorKind::Authentication,
330-
));
331-
}
332-
};
333-
}
334-
Err(_) => {
335-
tracing::error!("failed to create key pair from nkey seed");
336-
return Err(ConnectError::new(
337-
crate::ConnectErrorKind::Authentication,
338-
));
339-
}
340-
}
341-
}
342-
343-
if let Some(jwt) = self.options.auth.jwt.as_ref() {
344-
if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
345-
match sign_fn.call(server_info.nonce.clone()).await {
346-
Ok(sig) => {
347-
connect_info.user_jwt = Some(jwt.clone());
348-
connect_info.signature = Some(sig);
349-
}
350-
Err(_) => {
351-
tracing::error!("failed to sign nonce with JWT callback");
352-
return Err(ConnectError::new(
353-
crate::ConnectErrorKind::Authentication,
354-
));
355-
}
356-
}
357-
}
358-
}
359-
360-
if let Some(callback) = self.options.auth_callback.as_ref() {
361-
let auth = callback
362-
.call(server_info.nonce.as_bytes().to_vec())
363-
.await
364-
.map_err(|err| {
365-
tracing::error!(error = %err, "auth callback failed");
366-
ConnectError::with_source(
367-
crate::ConnectErrorKind::Authentication,
368-
err,
369-
)
370-
})?;
371-
connect_info.user = auth.username;
372-
connect_info.pass = auth.password;
373-
connect_info.user_jwt = auth.jwt;
374-
connect_info.signature = auth
375-
.signature
376-
.map(|signature| URL_SAFE_NO_PAD.encode(signature));
377-
connect_info.auth_token = auth.token;
378-
connect_info.nkey = auth.nkey;
379-
}
380-
381-
connection
382-
.easy_write_and_flush(
383-
[ClientOp::Connect(connect_info), ClientOp::Ping].iter(),
384-
)
385-
.await?;
386-
387-
match connection.read_op().await? {
388-
Some(ServerOp::Error(err)) => match err {
389-
ServerError::AuthorizationViolation => {
390-
tracing::error!(error = %err, "authorization violation");
391-
return Err(ConnectError::with_source(
392-
crate::ConnectErrorKind::AuthorizationViolation,
393-
err,
394-
));
395-
}
396-
err => {
397-
tracing::error!(error = %err, "server error during connection");
398-
return Err(ConnectError::with_source(
399-
crate::ConnectErrorKind::Io,
400-
err,
401-
));
402-
}
403-
},
404-
Some(_) => {
405-
tracing::info!(
406-
server = %server_info.port,
407-
max_payload = %server_info.max_payload,
408-
"connected successfully"
409-
);
410-
self.attempts = 0;
411-
self.connect_stats.connects.add(1, Ordering::Relaxed);
412-
self.events_tx.try_send(Event::Connected).ok();
413-
self.state_tx.send(State::Connected).ok();
414-
self.max_payload.store(
415-
server_info.max_payload,
416-
std::sync::atomic::Ordering::Relaxed,
417-
);
418-
// Save server info for next reconnect callback
419-
self.last_server_info = Some(server_info.clone());
420-
return Ok((server_info, connection));
421-
}
422-
None => {
423-
tracing::error!("connection closed unexpectedly");
424-
return Err(ConnectError::with_source(
425-
crate::ConnectErrorKind::Io,
426-
"broken pipe",
427-
));
428-
}
429-
}
430-
}
431-
432-
Err(inner) => {
433-
tracing::debug!(
434-
server = ?server_addr,
435-
error = %inner,
436-
"connection attempt failed"
437-
);
438-
error.replace(inner)
439-
}
440-
};
261+
// Use the same helper method as the callback path
262+
match self.try_connect_to_server_addr(&server_addr).await {
263+
Ok((server_info, connection)) => {
264+
return Ok((server_info, connection));
265+
}
266+
Err(err) => {
267+
tracing::debug!(
268+
server = ?server_addr,
269+
error = %err,
270+
"connection attempt failed"
271+
);
272+
error.replace(err);
273+
}
441274
}
442275
}
443276

444277
Err(error.unwrap())
445278
}
446279

447280
/// Helper method to attempt connection to a specific server address.
448-
/// Used by the reconnect_server_callback path.
281+
/// Handles DNS resolution, authentication, server info callbacks, and server discovery.
282+
/// Used by both the callback-based and default reconnection paths.
449283
async fn try_connect_to_server_addr(
450284
&mut self,
451285
server_addr: &ServerAddr,
@@ -455,6 +289,8 @@ impl Connector {
455289
.await
456290
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
457291

292+
let mut last_error = None;
293+
458294
for socket_addr in socket_addrs {
459295
match self
460296
.try_connect_to(
@@ -630,14 +466,15 @@ impl Connector {
630466
error = %err,
631467
"socket connection attempt failed"
632468
);
633-
// Continue trying other socket addresses
469+
// Save the error and continue trying other socket addresses
470+
last_error = Some(err);
634471
continue;
635472
}
636473
}
637474
}
638475

639-
// If we get here, all socket addresses failed
640-
Err(ConnectError::new(crate::ConnectErrorKind::Io))
476+
// If we get here, all socket addresses failed - return the last error
477+
Err(last_error.unwrap_or_else(|| ConnectError::new(crate::ConnectErrorKind::Io)))
641478
}
642479

643480
pub(crate) async fn try_connect_to(

async-nats/src/options.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub struct ConnectOptions {
6767
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
6868
pub(crate) server_info_callback: Option<CallbackArg1<ServerInfo, ()>>,
6969
pub(crate) reconnect_server_callback:
70-
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo), ServerAddr>>,
70+
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo, usize), ServerAddr>>,
7171
}
7272

7373
impl fmt::Debug for ConnectOptions {
@@ -943,20 +943,26 @@ impl ConnectOptions {
943943
}
944944

945945
/// Registers a callback for customizing server selection during reconnection.
946-
/// The callback receives the list of available servers and the most recent ServerInfo,
947-
/// and should return the ServerAddr to connect to.
946+
/// The callback receives:
947+
/// - List of available servers
948+
/// - Most recent ServerInfo
949+
/// - Current reconnection attempt number (starts at 1)
950+
///
951+
/// The callback should return the ServerAddr to connect to.
948952
///
949953
/// This allows fine-grained control over reconnection logic, such as:
950954
/// - Custom server selection algorithms
951955
/// - Geographic server preference
952956
/// - Dynamic server list manipulation
957+
/// - Attempt-based retry strategies
953958
///
954959
/// # Examples
955960
/// ```no_run
956961
/// # #[tokio::main]
957962
/// # async fn main() -> Result<(), async_nats::ConnectError> {
958963
/// async_nats::ConnectOptions::new()
959-
/// .reconnect_server_callback(|(servers, server_info)| async move {
964+
/// .reconnect_server_callback(|(servers, server_info, attempt)| async move {
965+
/// println!("Reconnection attempt {}", attempt);
960966
/// // Always prefer the first server in the list
961967
/// servers
962968
/// .first()
@@ -970,11 +976,11 @@ impl ConnectOptions {
970976
/// ```
971977
pub fn reconnect_server_callback<F, Fut>(mut self, cb: F) -> ConnectOptions
972978
where
973-
F: Fn((Vec<ServerAddr>, ServerInfo)) -> Fut + Send + Sync + 'static,
979+
F: Fn((Vec<ServerAddr>, ServerInfo, usize)) -> Fut + Send + Sync + 'static,
974980
Fut: Future<Output = ServerAddr> + 'static + Send + Sync,
975981
{
976982
self.reconnect_server_callback = Some(CallbackArg1::<
977-
(Vec<ServerAddr>, ServerInfo),
983+
(Vec<ServerAddr>, ServerInfo, usize),
978984
ServerAddr,
979985
>(Box::new(move |args| Box::pin(cb(args)))));
980986
self

0 commit comments

Comments
 (0)