Skip to content

Commit 84c6fa2

Browse files
committed
Improve
Signed-off-by: Tomasz Pietrek <[email protected]>
1 parent c137591 commit 84c6fa2

File tree

1 file changed

+268
-26
lines changed

1 file changed

+268
-26
lines changed

async-nats/src/connector.rs

Lines changed: 268 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -145,40 +145,86 @@ impl Connector {
145145
tracing::debug!(attempt = %self.attempts, "connecting to server");
146146
let mut error = None;
147147

148-
let mut servers = self.servers.clone();
148+
// If reconnect_server_callback is provided, use it to select server on each attempt
149+
if self.options.reconnect_server_callback.is_some() {
150+
loop {
151+
self.attempts += 1;
152+
if let Some(max_reconnects) = self.options.max_reconnects {
153+
if self.attempts > max_reconnects {
154+
tracing::error!(
155+
attempts = %self.attempts,
156+
max_reconnects = %max_reconnects,
157+
"max reconnection attempts reached"
158+
);
159+
self.events_tx
160+
.try_send(Event::ClientError(ClientError::MaxReconnects))
161+
.ok();
162+
return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
163+
}
164+
}
149165

150-
// If reconnect_server_callback is provided, let the user select the server
151-
if let Some(ref callback) = self.options.reconnect_server_callback {
152-
// Get the latest server info (use default if this is the first connection)
153-
let server_info = self.last_server_info.clone().unwrap_or_default();
166+
let duration = (self.options.reconnect_delay_callback)(self.attempts);
167+
tracing::debug!(
168+
attempt = %self.attempts,
169+
delay_ms = %duration.as_millis(),
170+
"attempting connection with user callback"
171+
);
154172

155-
// Extract just the ServerAddr from the tuple
156-
let server_addrs: Vec<ServerAddr> =
157-
servers.iter().map(|(addr, _)| addr.clone()).collect();
173+
sleep(duration).await;
158174

159-
// Call the user's callback
160-
let selected_server = callback.call((server_addrs, server_info)).await;
175+
// Get the latest server info (use default if this is the first connection)
176+
let server_info = self.last_server_info.clone().unwrap_or_default();
177+
178+
// Extract just the ServerAddr from the tuple
179+
let server_addrs: Vec<ServerAddr> =
180+
self.servers.iter().map(|(addr, _)| addr.clone()).collect();
181+
182+
// Call the user's callback to select the server for this attempt
183+
let server_addr = if let Some(ref callback) = self.options.reconnect_server_callback
184+
{
185+
callback.call((server_addrs, server_info)).await
186+
} else {
187+
// This shouldn't happen due to the outer check, but provide a fallback
188+
server_addrs
189+
.first()
190+
.cloned()
191+
.unwrap_or_else(|| "nats://localhost:4222".parse().unwrap())
192+
};
161193

162-
// Find the selected server in our list and put it first
163-
if let Some(pos) = servers
164-
.iter()
165-
.position(|(addr, _)| addr == &selected_server)
166-
{
167-
let selected = servers.remove(pos);
168-
servers.insert(0, selected);
169194
tracing::debug!(
170-
server = ?selected_server,
195+
server = ?server_addr,
171196
"user selected server via reconnect_server_callback"
172197
);
173-
} else {
174-
// Server not in our list, add it
175-
tracing::debug!(
176-
server = ?selected_server,
177-
"user selected new server via reconnect_server_callback, adding to list"
178-
);
179-
servers.insert(0, (selected_server, 0));
198+
199+
// Ensure the selected server is in our list
200+
if !self.servers.iter().any(|(addr, _)| addr == &server_addr) {
201+
tracing::debug!(
202+
server = ?server_addr,
203+
"user selected new server, adding to list"
204+
);
205+
self.servers.push((server_addr.clone(), 0));
206+
}
207+
208+
// Try to connect to the selected server
209+
match self.try_connect_to_server_addr(&server_addr).await {
210+
Ok((server_info, connection)) => {
211+
return Ok((server_info, connection));
212+
}
213+
Err(err) => {
214+
tracing::debug!(
215+
server = ?server_addr,
216+
error = %err,
217+
"connection attempt failed, will retry"
218+
);
219+
error.replace(err);
220+
}
221+
}
180222
}
181-
} else if !self.options.retain_servers_order {
223+
}
224+
225+
// Default behavior: shuffle and iterate through servers
226+
let mut servers = self.servers.clone();
227+
if !self.options.retain_servers_order {
182228
servers.shuffle(&mut thread_rng());
183229
// sort_by is stable, meaning it will retain the order for equal elements.
184230
servers.sort_by(|a, b| a.1.cmp(&b.1));
@@ -398,6 +444,202 @@ impl Connector {
398444
Err(error.unwrap())
399445
}
400446

447+
/// Helper method to attempt connection to a specific server address.
448+
/// Used by the reconnect_server_callback path.
449+
async fn try_connect_to_server_addr(
450+
&mut self,
451+
server_addr: &ServerAddr,
452+
) -> Result<(ServerInfo, Connection), ConnectError> {
453+
let socket_addrs = server_addr
454+
.socket_addrs()
455+
.await
456+
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
457+
458+
for socket_addr in socket_addrs {
459+
match self
460+
.try_connect_to(
461+
&socket_addr,
462+
server_addr.tls_required(),
463+
server_addr.clone(),
464+
)
465+
.await
466+
{
467+
Ok((server_info, mut connection)) => {
468+
// Call server_info_callback if provided
469+
if let Some(ref callback) = self.options.server_info_callback {
470+
callback.call(server_info.clone()).await;
471+
}
472+
473+
if !self.options.ignore_discovered_servers {
474+
for url in &server_info.connect_urls {
475+
let discovered_addr = url.parse::<ServerAddr>().map_err(|err| {
476+
ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
477+
})?;
478+
if !self
479+
.servers
480+
.iter()
481+
.any(|(addr, _)| addr == &discovered_addr)
482+
{
483+
tracing::debug!(
484+
discovered_url = %url,
485+
"adding discovered server"
486+
);
487+
self.servers.push((discovered_addr, 0));
488+
}
489+
}
490+
}
491+
492+
let tls_required = self.options.tls_required || server_addr.tls_required();
493+
let mut connect_info = ConnectInfo {
494+
tls_required,
495+
name: self.options.name.clone(),
496+
pedantic: false,
497+
verbose: false,
498+
lang: LANG.to_string(),
499+
version: VERSION.to_string(),
500+
protocol: Protocol::Dynamic,
501+
user: self.options.auth.username.to_owned(),
502+
pass: self.options.auth.password.to_owned(),
503+
auth_token: self.options.auth.token.to_owned(),
504+
user_jwt: None,
505+
nkey: None,
506+
signature: None,
507+
echo: !self.options.no_echo,
508+
headers: true,
509+
no_responders: true,
510+
};
511+
512+
if let Some(nkey) = self.options.auth.nkey.as_ref() {
513+
match nkeys::KeyPair::from_seed(nkey.as_str()) {
514+
Ok(key_pair) => {
515+
let nonce = server_info.nonce.clone();
516+
match key_pair.sign(nonce.as_bytes()) {
517+
Ok(signed) => {
518+
connect_info.nkey = Some(key_pair.public_key());
519+
connect_info.signature =
520+
Some(URL_SAFE_NO_PAD.encode(signed));
521+
}
522+
Err(_) => {
523+
tracing::error!("failed to sign nonce with nkey");
524+
return Err(ConnectError::new(
525+
crate::ConnectErrorKind::Authentication,
526+
));
527+
}
528+
};
529+
}
530+
Err(_) => {
531+
tracing::error!("failed to create key pair from nkey seed");
532+
return Err(ConnectError::new(
533+
crate::ConnectErrorKind::Authentication,
534+
));
535+
}
536+
}
537+
}
538+
539+
if let Some(jwt) = self.options.auth.jwt.as_ref() {
540+
if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
541+
match sign_fn.call(server_info.nonce.clone()).await {
542+
Ok(sig) => {
543+
connect_info.user_jwt = Some(jwt.clone());
544+
connect_info.signature = Some(sig);
545+
}
546+
Err(_) => {
547+
tracing::error!("failed to sign nonce with JWT callback");
548+
return Err(ConnectError::new(
549+
crate::ConnectErrorKind::Authentication,
550+
));
551+
}
552+
}
553+
}
554+
}
555+
556+
if let Some(callback) = self.options.auth_callback.as_ref() {
557+
let auth = callback
558+
.call(server_info.nonce.as_bytes().to_vec())
559+
.await
560+
.map_err(|err| {
561+
tracing::error!(error = %err, "auth callback failed");
562+
ConnectError::with_source(
563+
crate::ConnectErrorKind::Authentication,
564+
err,
565+
)
566+
})?;
567+
connect_info.user = auth.username;
568+
connect_info.pass = auth.password;
569+
connect_info.user_jwt = auth.jwt;
570+
connect_info.signature = auth
571+
.signature
572+
.map(|signature| URL_SAFE_NO_PAD.encode(signature));
573+
connect_info.auth_token = auth.token;
574+
connect_info.nkey = auth.nkey;
575+
}
576+
577+
connection
578+
.easy_write_and_flush(
579+
[ClientOp::Connect(connect_info), ClientOp::Ping].iter(),
580+
)
581+
.await?;
582+
583+
match connection.read_op().await? {
584+
Some(ServerOp::Error(err)) => match err {
585+
ServerError::AuthorizationViolation => {
586+
tracing::error!(error = %err, "authorization violation");
587+
return Err(ConnectError::with_source(
588+
crate::ConnectErrorKind::AuthorizationViolation,
589+
err,
590+
));
591+
}
592+
err => {
593+
tracing::error!(error = %err, "server error during connection");
594+
return Err(ConnectError::with_source(
595+
crate::ConnectErrorKind::Io,
596+
err,
597+
));
598+
}
599+
},
600+
Some(_) => {
601+
tracing::info!(
602+
server = %server_info.port,
603+
max_payload = %server_info.max_payload,
604+
"connected successfully"
605+
);
606+
self.attempts = 0;
607+
self.connect_stats.connects.add(1, Ordering::Relaxed);
608+
self.events_tx.try_send(Event::Connected).ok();
609+
self.state_tx.send(State::Connected).ok();
610+
self.max_payload.store(
611+
server_info.max_payload,
612+
std::sync::atomic::Ordering::Relaxed,
613+
);
614+
// Save server info for next reconnect callback
615+
self.last_server_info = Some(server_info.clone());
616+
return Ok((server_info, connection));
617+
}
618+
None => {
619+
tracing::error!("connection closed unexpectedly");
620+
return Err(ConnectError::with_source(
621+
crate::ConnectErrorKind::Io,
622+
"broken pipe",
623+
));
624+
}
625+
}
626+
}
627+
Err(err) => {
628+
tracing::debug!(
629+
server = ?server_addr,
630+
error = %err,
631+
"socket connection attempt failed"
632+
);
633+
// Continue trying other socket addresses
634+
continue;
635+
}
636+
}
637+
}
638+
639+
// If we get here, all socket addresses failed
640+
Err(ConnectError::new(crate::ConnectErrorKind::Io))
641+
}
642+
401643
pub(crate) async fn try_connect_to(
402644
&self,
403645
socket_addr: &SocketAddr,

0 commit comments

Comments
 (0)