Skip to content

Commit c137591

Browse files
committed
Add reconnect callbacks
Signed-off-by: Tomasz Pietrek <[email protected]>
1 parent 974720b commit c137591

File tree

4 files changed

+311
-2
lines changed

4 files changed

+311
-2
lines changed

async-nats/src/connector.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ pub(crate) struct ConnectorOptions {
6767
pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
6868
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
6969
pub(crate) max_reconnects: Option<usize>,
70+
pub(crate) server_info_callback: Option<CallbackArg1<ServerInfo, ()>>,
71+
pub(crate) reconnect_server_callback:
72+
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo), ServerAddr>>,
7073
}
7174

7275
/// Maintains a list of servers and establishes connections.
@@ -79,6 +82,7 @@ pub(crate) struct Connector {
7982
pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
8083
pub(crate) state_tx: tokio::sync::watch::Sender<State>,
8184
pub(crate) max_payload: Arc<AtomicUsize>,
85+
last_server_info: Option<ServerInfo>,
8286
}
8387

8488
pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
@@ -110,6 +114,7 @@ impl Connector {
110114
state_tx,
111115
max_payload,
112116
connect_stats,
117+
last_server_info: None,
113118
})
114119
}
115120

@@ -141,7 +146,39 @@ impl Connector {
141146
let mut error = None;
142147

143148
let mut servers = self.servers.clone();
144-
if !self.options.retain_servers_order {
149+
150+
// If reconnect_server_callback is provided, let the user select the server
151+
if let Some(ref callback) = self.options.reconnect_server_callback {
152+
// Get the latest server info (use default if this is the first connection)
153+
let server_info = self.last_server_info.clone().unwrap_or_default();
154+
155+
// Extract just the ServerAddr from the tuple
156+
let server_addrs: Vec<ServerAddr> =
157+
servers.iter().map(|(addr, _)| addr.clone()).collect();
158+
159+
// Call the user's callback
160+
let selected_server = callback.call((server_addrs, server_info)).await;
161+
162+
// Find the selected server in our list and put it first
163+
if let Some(pos) = servers
164+
.iter()
165+
.position(|(addr, _)| addr == &selected_server)
166+
{
167+
let selected = servers.remove(pos);
168+
servers.insert(0, selected);
169+
tracing::debug!(
170+
server = ?selected_server,
171+
"user selected server via reconnect_server_callback"
172+
);
173+
} else {
174+
// Server not in our list: put it at the front of the local copy used for this attempt
175+
tracing::debug!(
176+
server = ?selected_server,
177+
"user selected new server via reconnect_server_callback, adding to list"
178+
);
179+
servers.insert(0, (selected_server, 0));
180+
}
181+
} else if !self.options.retain_servers_order {
145182
servers.shuffle(&mut thread_rng());
146183
// sort_by is stable, meaning it will retain the order for equal elements.
147184
servers.sort_by(|a, b| a.1.cmp(&b.1));
@@ -187,6 +224,11 @@ impl Connector {
187224
.await
188225
{
189226
Ok((server_info, mut connection)) => {
227+
// Call server_info_callback if provided
228+
if let Some(ref callback) = self.options.server_info_callback {
229+
callback.call(server_info.clone()).await;
230+
}
231+
190232
if !self.options.ignore_discovered_servers {
191233
for url in &server_info.connect_urls {
192234
let server_addr = url.parse::<ServerAddr>().map_err(|err| {
@@ -327,6 +369,8 @@ impl Connector {
327369
server_info.max_payload,
328370
std::sync::atomic::Ordering::Relaxed,
329371
);
372+
// Save server info for next reconnect callback
373+
self.last_server_info = Some(server_info.clone());
330374
return Ok((server_info, connection));
331375
}
332376
None => {

async-nats/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,8 @@ pub async fn connect_with_options<A: ToServerAddrs>(
10251025
reconnect_delay_callback: options.reconnect_delay_callback,
10261026
auth_callback: options.auth_callback,
10271027
max_reconnects: options.max_reconnects,
1028+
server_info_callback: options.server_info_callback,
1029+
reconnect_server_callback: options.reconnect_server_callback,
10281030
},
10291031
events_tx,
10301032
state_tx,

async-nats/src/options.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
use crate::auth::Auth;
1515
use crate::connector;
16-
use crate::{Client, ConnectError, Event, ToServerAddrs};
16+
use crate::{Client, ConnectError, Event, ServerAddr, ServerInfo, ToServerAddrs};
1717
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
1818
use base64::engine::Engine;
1919
use futures_util::Future;
@@ -65,6 +65,9 @@ pub struct ConnectOptions {
6565
pub(crate) read_buffer_capacity: u16,
6666
pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
6767
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
68+
pub(crate) server_info_callback: Option<CallbackArg1<ServerInfo, ()>>,
69+
pub(crate) reconnect_server_callback:
70+
Option<CallbackArg1<(Vec<ServerAddr>, ServerInfo), ServerAddr>>,
6871
}
6972

7073
impl fmt::Debug for ConnectOptions {
@@ -117,6 +120,8 @@ impl Default for ConnectOptions {
117120
}),
118121
auth: Default::default(),
119122
auth_callback: None,
123+
server_info_callback: None,
124+
reconnect_server_callback: None,
120125
}
121126
}
122127
}
@@ -909,6 +914,71 @@ impl ConnectOptions {
909914
self.read_buffer_capacity = size;
910915
self
911916
}
917+
918+
/// Registers a callback invoked with the `ServerInfo` received from a server each time the client connects to it.
919+
/// The connector calls it during connection setup, for the initial connection as well as for reconnects.
920+
///
921+
/// # Examples
922+
/// ```no_run
923+
/// # #[tokio::main]
924+
/// # async fn main() -> Result<(), async_nats::ConnectError> {
925+
/// async_nats::ConnectOptions::new()
926+
/// .server_info_callback(|server_info| async move {
927+
/// println!("Server info updated: {:?}", server_info);
928+
/// })
929+
/// .connect("demo.nats.io")
930+
/// .await?;
931+
/// # Ok(())
932+
/// # }
933+
/// ```
934+
pub fn server_info_callback<F, Fut>(mut self, cb: F) -> ConnectOptions
935+
where
936+
F: Fn(ServerInfo) -> Fut + Send + Sync + 'static,
937+
Fut: Future<Output = ()> + 'static + Send + Sync,
938+
{
939+
self.server_info_callback = Some(CallbackArg1::<ServerInfo, ()>(Box::new(
940+
move |server_info| Box::pin(cb(server_info)),
941+
)));
942+
self
943+
}
944+
945+
/// Registers a callback that chooses which server is tried first each time the client connects or reconnects.
946+
/// The callback receives the list of known servers and the most recent `ServerInfo` (a default value before the first successful connection),
947+
/// and returns the `ServerAddr` to try first; an address that is not in the list is accepted and tried first as well.
948+
///
949+
/// This allows fine-grained control over reconnection logic, such as:
950+
/// - Custom server selection algorithms
951+
/// - Geographic server preference
952+
/// - Dynamic server list manipulation
953+
///
954+
/// # Examples
955+
/// ```no_run
956+
/// # #[tokio::main]
957+
/// # async fn main() -> Result<(), async_nats::ConnectError> {
958+
/// async_nats::ConnectOptions::new()
959+
/// .reconnect_server_callback(|(servers, server_info)| async move {
960+
/// // Always prefer the first server in the list
961+
/// servers
962+
/// .first()
963+
/// .cloned()
964+
/// .unwrap_or_else(|| "nats://localhost:4222".parse().unwrap())
965+
/// })
966+
/// .connect("demo.nats.io")
967+
/// .await?;
968+
/// # Ok(())
969+
/// # }
970+
/// ```
971+
pub fn reconnect_server_callback<F, Fut>(mut self, cb: F) -> ConnectOptions
972+
where
973+
F: Fn((Vec<ServerAddr>, ServerInfo)) -> Fut + Send + Sync + 'static,
974+
Fut: Future<Output = ServerAddr> + 'static + Send + Sync,
975+
{
976+
self.reconnect_server_callback = Some(CallbackArg1::<
977+
(Vec<ServerAddr>, ServerInfo),
978+
ServerAddr,
979+
>(Box::new(move |args| Box::pin(cb(args)))));
980+
self
981+
}
912982
}
913983

914984
pub(crate) type AsyncCallbackArg1<A, T> =

async-nats/tests/client_tests.rs

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,4 +1181,197 @@ mod client {
11811181
.await
11821182
.expect("Expected to be able to create a new client");
11831183
}
1184+
1185+
#[tokio::test]
1186+
async fn test_reconnect_server_callback() {
1187+
// This test validates that the reconnect_server_callback is invoked
1188+
// and that the client ends up connected when the callback selects the real server.
1189+
use std::sync::Arc;
1190+
use tokio::sync::Mutex;
1191+
1192+
let server = nats_server::run_basic_server();
1193+
let correct_addr = server.client_url();
1194+
1195+
// Parse the correct server address
1196+
let correct_server: ServerAddr = correct_addr.parse().unwrap();
1197+
let correct_server_for_callback = correct_server.clone();
1198+
1199+
// An address where no server is expected to listen; it is listed first in the connect call below
1200+
let fake_server: ServerAddr = "nats://localhost:9999".parse().unwrap();
1201+
1202+
// Track callback invocations
1203+
let callback_invoked = Arc::new(Mutex::new(false));
1204+
let callback_invoked_clone = callback_invoked.clone();
1205+
1206+
// Track connection events
1207+
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(10);
1208+
1209+
let client = ConnectOptions::new()
1210+
.reconnect_server_callback(move |(servers, _server_info)| {
1211+
let correct = correct_server_for_callback.clone();
1212+
let invoked = callback_invoked_clone.clone();
1213+
async move {
1214+
*invoked.lock().await = true;
1215+
1216+
// Whenever the unreachable address is among the offered servers, select the real server instead,
1217+
// so the client is steered to the running server rather than to the fake one.
1218+
let has_fake_server = servers.iter().any(|s| {
1219+
// Recognize the fake address by its port number in the Debug output
1220+
format!("{:?}", s).contains("9999")
1221+
});
1222+
1223+
if has_fake_server {
1224+
// Return the correct server
1225+
correct
1226+
} else {
1227+
// Return the first available server
1228+
servers.first().cloned().unwrap_or(correct)
1229+
}
1230+
}
1231+
})
1232+
.event_callback(move |event| {
1233+
let tx = event_tx.clone();
1234+
async move {
1235+
if let Event::Connected = event {
1236+
tx.send(event).await.ok();
1237+
}
1238+
}
1239+
})
1240+
.retry_on_initial_connect()
1241+
.connect(vec![fake_server, correct_server.clone()])
1242+
.await
1243+
.unwrap();
1244+
1245+
// Wait for connection to be established
1246+
tokio::time::timeout(Duration::from_secs(10), async {
1247+
while let Some(event) = event_rx.recv().await {
1248+
if matches!(event, Event::Connected) {
1249+
break;
1250+
}
1251+
}
1252+
})
1253+
.await
1254+
.expect("Client should connect within timeout");
1255+
1256+
// Verify callback was invoked
1257+
assert!(
1258+
*callback_invoked.lock().await,
1259+
"Callback should have been invoked"
1260+
);
1261+
1262+
// Verify we can publish and subscribe (connection is working)
1263+
let mut subscriber = client.subscribe("test").await.unwrap();
1264+
client.publish("test", "data".into()).await.unwrap();
1265+
client.flush().await.unwrap();
1266+
1267+
tokio::time::timeout(Duration::from_secs(5), subscriber.next())
1268+
.await
1269+
.expect("Should receive message within timeout")
1270+
.expect("Should receive a message");
1271+
}
1272+
1273+
#[tokio::test]
1274+
async fn test_server_info_callback() {
1275+
// This test validates that the server_info_callback is invoked
1276+
// with the server's INFO when the client connects.
1277+
use std::sync::Arc;
1278+
use tokio::sync::Mutex;
1279+
1280+
let server = nats_server::run_basic_server();
1281+
1282+
let callback_invoked = Arc::new(Mutex::new(false));
1283+
let callback_invoked_clone = callback_invoked.clone();
1284+
1285+
let _client = ConnectOptions::new()
1286+
.server_info_callback(move |server_info| {
1287+
let invoked = callback_invoked_clone.clone();
1288+
async move {
1289+
*invoked.lock().await = true;
1290+
// Sanity check: the info carries a server id or a non-zero port
1291+
assert!(!server_info.server_id.is_empty() || server_info.port > 0);
1292+
}
1293+
})
1294+
.connect(server.client_url())
1295+
.await
1296+
.unwrap();
1297+
1298+
// Give callback time to be invoked
1299+
tokio::time::sleep(Duration::from_millis(500)).await;
1300+
1301+
// Verify callback was invoked
1302+
assert!(
1303+
*callback_invoked.lock().await,
1304+
"Server info callback should have been invoked"
1305+
);
1306+
}
1307+
1308+
#[tokio::test]
1309+
async fn test_reconnect_with_force_reconnect() {
1310+
// Test that the client reconnects after force_reconnect when a reconnect_server_callback is set
1311+
use std::sync::Arc;
1312+
use tokio::sync::Mutex;
1313+
1314+
let server = nats_server::run_basic_server();
1315+
let server_addr: ServerAddr = server.client_url().parse().unwrap();
1316+
1317+
let callback_count = Arc::new(Mutex::new(0));
1318+
let callback_count_clone = callback_count.clone();
1319+
1320+
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(10);
1321+
1322+
let client = ConnectOptions::new()
1323+
.reconnect_server_callback(move |(servers, _)| {
1324+
let count = callback_count_clone.clone();
1325+
async move {
1326+
*count.lock().await += 1;
1327+
servers.first().cloned().unwrap()
1328+
}
1329+
})
1330+
.event_callback(move |event| {
1331+
let tx = event_tx.clone();
1332+
async move {
1333+
tx.send(event).await.ok();
1334+
}
1335+
})
1336+
.connect(server_addr.clone())
1337+
.await
1338+
.unwrap();
1339+
1340+
// Wait for initial connection
1341+
tokio::time::timeout(Duration::from_secs(5), async {
1342+
while let Some(event) = event_rx.recv().await {
1343+
if matches!(event, Event::Connected) {
1344+
break;
1345+
}
1346+
}
1347+
})
1348+
.await
1349+
.expect("Initial connection should succeed");
1350+
1351+
// Force a reconnect
1352+
client.force_reconnect().await.unwrap();
1353+
1354+
// Wait for reconnection
1355+
tokio::time::timeout(Duration::from_secs(5), async {
1356+
loop {
1357+
if let Some(event) = event_rx.recv().await {
1358+
match event {
1359+
Event::Disconnected => continue,
1360+
Event::Connected => break,
1361+
_ => continue,
1362+
}
1363+
}
1364+
}
1365+
})
1366+
.await
1367+
.expect("Reconnection should succeed");
1368+
1369+
// Verify callback was invoked at least once. NOTE(review): the initial connection presumably already invokes it, so `>= 1` does not prove it ran for the forced reconnect; consider asserting `>= 2` — confirm.
1370+
let count = *callback_count.lock().await;
1371+
assert!(
1372+
count >= 1,
1373+
"Callback should have been invoked at least once, got: {}",
1374+
count
1375+
);
1376+
}
11841377
}

0 commit comments

Comments
 (0)