Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 120 additions & 44 deletions src/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with a `WebSocketStream`.
//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with
//! a [`WebSocketStream`](crate::WebSocketStream) or a [`WebSocketSender`](crate::WebSocketSender).

use std::{
io,
Expand All @@ -10,99 +11,174 @@ use futures_core::stream::Stream;

use crate::{tungstenite::Bytes, Message, WsError};

/// Treat a `WebSocketStream` as an `AsyncWrite` implementation.
/// Treat a websocket [sender](Sender) as an `AsyncWrite` implementation.
///
/// Every write sends a binary message. If you want to group writes together, consider wrapping
/// this with a `BufWriter`.
#[cfg(feature = "futures-03-sink")]
#[derive(Debug)]
pub struct ByteWriter<S>(S);
pub struct ByteWriter<S> {
sender: S,
state: State,
}

#[cfg(feature = "futures-03-sink")]
impl<S> ByteWriter<S> {
/// Create a new `ByteWriter` from a `Sink` that accepts a WebSocket `Message`
/// Create a new `ByteWriter` from a [sender](Sender) that accepts a websocket [`Message`].
#[inline(always)]
pub fn new(s: S) -> Self {
Self(s)
pub fn new(sender: S) -> Self
where
S: Sender,
{
Self {
sender,
state: State::Open,
}
}

/// Get the underlying `Sink` back.
/// Get the underlying [sender](Sender) back.
#[inline(always)]
pub fn into_inner(self) -> S {
self.0
self.sender
}
}

#[derive(Debug)]
enum State {
Open,
Closing(Option<Message>),
}

impl State {
fn close(&mut self) -> &mut Option<Message> {
match self {
State::Open => {
*self = State::Closing(Some(Message::Close(None)));
if let State::Closing(msg) = self {
msg
} else {
unreachable!()
}
}
State::Closing(msg) => msg,
}
}
}

/// Sends bytes as a websocket [`Message`].
///
/// It's implemented for [`WebSocketStream`](crate::WebSocketStream)
/// and [`WebSocketSender`](crate::WebSocketSender).
/// It's also implemeted for every `Sink` type that accepts
/// a websocket [`Message`] and returns [`WsError`] type as
/// an error when `futures-03-sink` feature is enabled.
pub trait Sender: private::SealedSender {}

pub(crate) mod private {
use super::*;

pub trait SealedSender {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, WsError>>;

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>>;

fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
msg: &mut Option<Message>,
) -> Poll<Result<(), WsError>>;
}

impl<S> Sender for S where S: SealedSender {}
}

#[cfg(feature = "futures-03-sink")]
fn poll_write_helper<S>(
mut s: Pin<&mut ByteWriter<S>>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>
impl<S> private::SealedSender for S
where
S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let len = buf.len();
let msg = Message::binary(buf.to_owned());
Poll::Ready(
Pin::new(&mut s.0)
.start_send(msg)
.map_err(convert_err)
.map(|()| len),
)
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, WsError>> {
use std::task::ready;

ready!(self.as_mut().poll_ready(cx))?;
let len = buf.len();
self.start_send(Message::binary(buf.to_owned()))?;
Poll::Ready(Ok(len))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
<S as futures_util::Sink<_>>::poll_flush(self, cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
_: &mut Option<Message>,
) -> Poll<Result<(), WsError>> {
<S as futures_util::Sink<_>>::poll_close(self, cx)
}
}

#[cfg(feature = "futures-03-sink")]
impl<S> futures_io::AsyncWrite for ByteWriter<S>
where
S: futures_util::Sink<Message, Error = WsError> + Unpin,
S: Sender + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
poll_write_helper(self, cx, buf)
<S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
.map_err(convert_err)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err)
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
.map_err(convert_err)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_close(cx).map_err(convert_err)
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let me = self.get_mut();
let msg = me.state.close();
<S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
.map_err(convert_err)
}
}

#[cfg(feature = "futures-03-sink")]
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
S: futures_util::Sink<Message, Error = WsError> + Unpin,
S: Sender + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
poll_write_helper(self, cx, buf)
<S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
.map_err(convert_err)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err)
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
.map_err(convert_err)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_close(cx).map_err(convert_err)
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let me = self.get_mut();
let msg = me.state.close();
<S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
.map_err(convert_err)
}
}

/// Treat a `WebSocketStream` as an `AsyncRead` implementation.
/// Treat a websocket [stream](Stream) as an `AsyncRead` implementation.
///
/// This also works with any other `Stream` of `Message`, such as a `SplitStream`.
///
Expand All @@ -115,7 +191,7 @@ pub struct ByteReader<S> {
}

impl<S> ByteReader<S> {
/// Create a new `ByteReader` from a `Stream` that returns a WebSocket `Message`
/// Create a new `ByteReader` from a [stream](Stream) that returns a WebSocket [`Message`].
#[inline(always)]
pub fn new(stream: S) -> Self {
Self {
Expand Down
81 changes: 72 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ pub mod tokio;

pub mod bytes;
pub use bytes::ByteReader;
#[cfg(feature = "futures-03-sink")]
pub use bytes::ByteWriter;

use tungstenite::protocol::CloseFrame;
Expand Down Expand Up @@ -358,9 +357,9 @@ impl<S> WebSocketStream<S> {
}
}

impl<T> WebSocketStream<T>
impl<S> WebSocketStream<S>
where
T: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
#[cfg(feature = "verbose-logging")]
Expand Down Expand Up @@ -465,9 +464,9 @@ where
}
}

impl<T> Stream for WebSocketStream<T>
impl<S> Stream for WebSocketStream<S>
where
T: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<Message, WsError>;

Expand All @@ -476,19 +475,19 @@ where
}
}

impl<T> FusedStream for WebSocketStream<T>
impl<S> FusedStream for WebSocketStream<S>
where
T: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
fn is_terminated(&self) -> bool {
self.ended
}
}

#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
T: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
type Error = WsError;

Expand All @@ -509,6 +508,37 @@ where
}
}

#[cfg(not(feature = "futures-03-sink"))]
impl<S> bytes::private::SealedSender for WebSocketStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, WsError>> {
let me = self.get_mut();
ready!(me.poll_ready(cx))?;
let len = buf.len();
me.start_send(Message::binary(buf.to_owned()))?;
Poll::Ready(Ok(len))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
self.get_mut().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
msg: &mut Option<Message>,
) -> Poll<Result<(), WsError>> {
let me = self.get_mut();
send_helper(me, msg, cx)
}
}

impl<S> WebSocketStream<S> {
/// Simple send method to replace `futures_sink::Sink` (till v0.3).
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
Expand Down Expand Up @@ -629,6 +659,39 @@ where
}
}

#[cfg(not(feature = "futures-03-sink"))]
impl<S> bytes::private::SealedSender for WebSocketSender<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, WsError>> {
let me = self.get_mut();
let mut ws = me.shared.lock();
ready!(ws.poll_ready(cx))?;
let len = buf.len();
ws.start_send(Message::binary(buf.to_owned()))?;
Poll::Ready(Ok(len))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
self.shared.lock().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
msg: &mut Option<Message>,
) -> Poll<Result<(), WsError>> {
let me = self.get_mut();
let mut ws = me.shared.lock();
send_helper(&mut ws, msg, cx)
}
}

/// The receiver part of a [websocket](WebSocketStream) stream.
#[derive(Debug)]
pub struct WebSocketReceiver<S> {
Expand Down