Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Sink and future-util optional #143

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ jobs:
strategy:
matrix:
rust:
- 1.63.0
- 1.64.0

steps:
- name: Checkout sources
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock.msrv

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ version = "0.28.0"
edition = "2018"
readme = "README.md"
include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"]
rust-version = "1.63"
rust-version = "1.64"

[features]
default = ["handshake"]
default = ["handshake", "futures-03-sink"]
futures-03-sink = ["futures-util"]
handshake = ["tungstenite/handshake"]
async-std-runtime = ["async-std", "handshake"]
tokio-runtime = ["tokio", "handshake"]
Expand All @@ -37,10 +38,17 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a

[dependencies]
log = "0.4"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
futures-core = { version = "0.3", default-features = false }
atomic-waker = { version = "1.1", default-features = false }
futures-io = { version = "0.3", default-features = false, features = ["std"] }
pin-project-lite = "0.2"

[dependencies.futures-util]
optional = true
version = "0.3"
default-features = false
features = ["sink"]

[dependencies.tungstenite]
version = "0.24"
default-features = false
Expand Down
1 change: 1 addition & 0 deletions examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async fn run_test(case: u32) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
// for Sink of futures 0.3, see autobahn-server example
ws_stream.send(msg).await?;
}
}
Expand Down
4 changes: 3 additions & 1 deletion examples/autobahn-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ async fn handle_connection(peer: SocketAddr, stream: TcpStream) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await?;
// here we explicitly using futures 0.3's Sink implementation for send message
// for WebSocketStream::send, see autobahn-client example
futures::SinkExt::send(&mut ws_stream, msg).await?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/server-headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use async_tungstenite::{
use url::Url;
#[macro_use]
extern crate log;
use futures_util::{SinkExt, StreamExt};
use futures_util::StreamExt;

#[async_std::main]
async fn main() {
Expand Down
26 changes: 14 additions & 12 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use log::*;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::task;
use std::sync::Arc;
use tungstenite::Error as WsError;

Expand Down Expand Up @@ -49,18 +49,20 @@ pub(crate) struct AllowStd<S> {
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
fn set_waker(&self, waker: &Waker);
}

#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
fn set_waker(&self, waker: &Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}

impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
pub(crate) fn new(inner: S, waker: &Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
Expand All @@ -83,7 +85,7 @@ impl<S> AllowStd<S> {
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
Expand All @@ -103,11 +105,11 @@ impl<S> AllowStd<S> {
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
read_waker: AtomicWaker,
write_waker: AtomicWaker,
}

impl std::task::Wake for WakerProxy {
impl Wake for WakerProxy {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
Expand All @@ -129,10 +131,10 @@ where
#[cfg(feature = "verbose-logging")]
trace!("{}:{} AllowStd.with_context", file!(), line!());
let waker = match kind {
ContextWaker::Read => task::Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => task::Waker::from(self.write_waker_proxy.clone()),
ContextWaker::Read => Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => Waker::from(self.write_waker_proxy.clone()),
};
let mut context = task::Context::from_waker(&waker);
let mut context = Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}

Expand Down
98 changes: 89 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,16 @@ mod handshake;
))]
pub mod stream;

use std::io::{Read, Write};
use std::{
io::{Read, Write},
pin::Pin,
task::{ready, Context, Poll},
};

use compat::{cvt, AllowStd, ContextWaker};
use futures_core::stream::{FusedStream, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{
sink::{Sink, SinkExt},
stream::{FusedStream, Stream},
};
use log::*;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(feature = "handshake")]
use tungstenite::{
Expand Down Expand Up @@ -227,6 +226,7 @@ where
#[derive(Debug)]
pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
#[cfg(feature = "futures-03-sink")]
closing: bool,
ended: bool,
/// Tungstenite is probably ready to receive more data.
Expand Down Expand Up @@ -269,6 +269,7 @@ impl<S> WebSocketStream<S> {
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
Self {
inner: ws,
#[cfg(feature = "futures-03-sink")]
closing: false,
ended: false,
ready: true,
Expand Down Expand Up @@ -337,7 +338,7 @@ where
return Poll::Ready(None);
}

match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
#[cfg(feature = "verbose-logging")]
trace!(
"{}:{} Stream.with_context poll_next -> read()",
Expand Down Expand Up @@ -368,7 +369,8 @@ where
}
}

impl<T> Sink<Message> for WebSocketStream<T>
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -446,6 +448,84 @@ where
}
}

impl<S> WebSocketStream<S> {
/// Simple send method to replace `futures_sink::Sink` (till v0.3).
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't hurt to provide this always, and should make sure that both this and the Sink impl are covered by the tests and running on the CI

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think apart from this one it's all ready

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new function is already used by everything (e.g. the autobahn tests). Maybe you can change it so that the server uses Sink::send and the client the new function (together with a comment why this is done)

where
S: AsyncRead + AsyncWrite + Unpin,
{
Send::new(self, msg).await
}
}

struct Send<'a, S> {
ws: &'a mut WebSocketStream<S>,
msg: Option<Message>,
}

impl<'a, S> Send<'a, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
Self { ws, msg: Some(msg) }
}
}

impl<S> std::future::Future for Send<'_, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), WsError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.msg.is_some() {
if !self.ws.ready {
// Currently blocked so try to flush the blockage away
let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
r
});
ready!(polled)?
}

let msg = self.msg.take().expect("unreachable");
match self.ws.with_context(None, |s| s.write(msg)) {
Ok(_) => Ok(()),
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued so not an error
//
// set to false here for cancellation safety of *this* Future
self.ws.ready = false;
Ok(())
}
Err(e) => {
debug!("websocket start_send error: {}", e);
Err(e)
}
}?;
}

let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
});
ready!(polled)?;

Poll::Ready(Ok(()))
}
}

#[cfg(any(
feature = "async-tls",
feature = "async-std-runtime",
Expand Down
Loading