Skip to content

Commit ee47b7e

Browse files
nanoqshsdroege
authored andcommitted
Add reunite and is_pair_of methods
1 parent f4f78cd commit ee47b7e

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/lib.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,27 @@ impl<S> WebSocketStream<S> {
335335
let receiver = WebSocketReceiver { shared };
336336
(sender, receiver)
337337
}
338+
339+
/// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
340+
/// parts back into a single stream. If both parts originate from the same
341+
/// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
342+
/// Otherwise, returns `Err` containing the provided parts.
343+
pub fn reunite(
344+
sender: WebSocketSender<S>,
345+
receiver: WebSocketReceiver<S>,
346+
) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
347+
if sender.is_pair_of(&receiver) {
348+
drop(receiver);
349+
let stream = Arc::try_unwrap(sender.shared)
350+
.ok()
351+
.expect("reunite the stream")
352+
.into_inner();
353+
354+
Ok(stream)
355+
} else {
356+
Err((sender, receiver))
357+
}
358+
}
338359
}
339360

340361
impl<T> WebSocketStream<T>
@@ -551,6 +572,7 @@ where
551572
}
552573

553574
/// The sender part of a [websocket](WebSocketStream) stream.
575+
#[derive(Debug)]
554576
pub struct WebSocketSender<S> {
555577
shared: Arc<Shared<S>>,
556578
}
@@ -575,6 +597,12 @@ impl<S> WebSocketSender<S> {
575597
{
576598
self.send(Message::Close(msg)).await
577599
}
600+
601+
/// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
602+
/// were split from the same [websocket](WebSocketStream) stream.
603+
pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
604+
Arc::ptr_eq(&self.shared, &other.shared)
605+
}
578606
}
579607

580608
#[cfg(feature = "futures-03-sink")]
@@ -602,10 +630,19 @@ where
602630
}
603631

604632
/// The receiver part of a [websocket](WebSocketStream) stream.
633+
#[derive(Debug)]
605634
pub struct WebSocketReceiver<S> {
606635
shared: Arc<Shared<S>>,
607636
}
608637

638+
impl<S> WebSocketReceiver<S> {
639+
/// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
640+
/// were split from the same [websocket](WebSocketStream) stream.
641+
pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
642+
Arc::ptr_eq(&self.shared, &other.shared)
643+
}
644+
}
645+
609646
impl<S> Stream for WebSocketReceiver<S>
610647
where
611648
S: AsyncRead + AsyncWrite + Unpin,
@@ -626,12 +663,17 @@ where
626663
}
627664
}
628665

666+
#[derive(Debug)]
629667
struct Shared<S>(Mutex<WebSocketStream<S>>);
630668

631669
impl<S> Shared<S> {
632670
fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
633671
self.0.lock().expect("lock shared stream")
634672
}
673+
674+
fn into_inner(self) -> WebSocketStream<S> {
675+
self.0.into_inner().expect("get shared stream")
676+
}
635677
}
636678

637679
#[cfg(any(

tests/communication.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async fn split_communication() {
101101
.await
102102
.expect("Client failed to connect");
103103

104-
let (tx, _rx) = stream.split();
104+
let (tx, rx) = stream.split();
105105

106106
for i in 1..10 {
107107
info!("Sending message");
@@ -115,6 +115,10 @@ async fn split_communication() {
115115
info!("Waiting for response messages");
116116
let messages = msg_rx.await.expect("Failed to receive messages");
117117
assert_eq!(messages.len(), 10);
118+
119+
assert!(tx.is_pair_of(&rx));
120+
assert!(rx.is_pair_of(&tx));
121+
WebSocketStream::reunite(tx, rx).expect("Failed to reunite the stream");
118122
}
119123

120124
#[async_std::test]

0 commit comments

Comments
 (0)