Skip to content

Commit d3b4d98

Browse files
committed
feat(close): send after close returns a ConnectionClosed error
Signed-off-by: Jad K. Haddad <jadkhaddad@gmail.com>
1 parent 8cecd61 commit d3b4d98

5 files changed

Lines changed: 46 additions & 10 deletions

File tree

src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ pub enum ReadError<I> {
8585

8686
#[derive(Debug, thiserror::Error)]
8787
pub enum WriteError<I> {
88+
/// Websocket connection is closed.
89+
///
90+
/// To close the TCP connection, you should drop the [`WebSocket`](crate::WebSocket) instance.
8891
#[error("Connection closed")]
8992
ConnectionClosed,
9093
#[error("Write frame error: {0}")]

src/functions.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::{
1313
pub struct ReadAutoCaller;
1414

1515
impl ReadAutoCaller {
16+
#[allow(clippy::too_many_arguments)]
1617
pub async fn call<'this, F, RW, Rng>(
1718
&self,
1819
auto: F,
@@ -21,6 +22,7 @@ impl ReadAutoCaller {
2122
read_state: &'this mut ReadState<'_>,
2223
write_state: &mut WriteState<'_>,
2324
fragments_state: &'this mut FragmentsState<'_>,
25+
state: &mut ConnectionState,
2426
) -> Option<Result<Option<Message<'this>>, Error<RW::Error>>>
2527
where
2628
RW: Read + Write,
@@ -37,10 +39,10 @@ impl ReadAutoCaller {
3739
let frame = match auto(frame) {
3840
Ok(on_frame) => match on_frame {
3941
OnFrame::Send(message) => {
40-
let is_close = message.is_close();
42+
state.closed = message.is_close();
4143

4244
match framez::functions::send(write_state, codec, inner, message).await {
43-
Ok(_) => match is_close {
45+
Ok(_) => match state.closed {
4446
false => return Some(Ok(None)),
4547
true => return None,
4648
},
@@ -61,6 +63,7 @@ impl ReadAutoCaller {
6163
pub struct ReadCaller;
6264

6365
impl ReadCaller {
66+
#[allow(clippy::too_many_arguments)]
6467
pub async fn call<'this, RW, Rng>(
6568
&self,
6669
_auto: (),
@@ -69,6 +72,7 @@ impl ReadCaller {
6972
read_state: &'this mut ReadState<'_>,
7073
_write_state: &mut WriteState<'_>,
7174
fragments_state: &'this mut FragmentsState<'_>,
75+
_state: &mut ConnectionState,
7276
) -> Option<Result<Option<Message<'this>>, Error<RW::Error>>>
7377
where
7478
RW: Read,
@@ -96,6 +100,10 @@ where
96100
RW: Write,
97101
Rng: RngCore,
98102
{
103+
if state.closed {
104+
return Err(Error::Write(WriteError::ConnectionClosed));
105+
}
106+
99107
state.closed = message.is_close();
100108

101109
framez::functions::send(write_state, codec, inner, message)
@@ -109,13 +117,18 @@ pub async fn send_fragmented<RW, Rng>(
109117
codec: &mut FramesCodec<Rng>,
110118
inner: &mut RW,
111119
write_state: &mut WriteState<'_>,
120+
state: &mut ConnectionState,
112121
message: Message<'_>,
113122
fragment_size: usize,
114123
) -> Result<(), Error<RW::Error>>
115124
where
116125
RW: Write,
117126
Rng: RngCore,
118127
{
128+
if state.closed {
129+
return Err(Error::Write(WriteError::ConnectionClosed));
130+
}
131+
119132
for frame in message
120133
.fragments(fragment_size)
121134
.map_err(Error::Fragmentation)?

src/macros.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ macro_rules! next {
1111
&mut $websocketz.core.framed.core.state.read,
1212
&mut $websocketz.core.framed.core.state.write,
1313
&mut $websocketz.core.fragments_state,
14+
&mut $websocketz.core.state,
1415
)
1516
.await
1617
{
@@ -44,6 +45,7 @@ macro_rules! send_fragmented {
4445
&mut $websocketz.core.framed.core.codec,
4546
&mut $websocketz.core.framed.core.inner,
4647
&mut $websocketz.core.framed.core.state.write,
48+
&mut $websocketz.core.state,
4749
$message,
4850
$fragment_size,
4951
)

src/tests.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,11 @@ mod fragmentation {
956956
}
957957

958958
mod auto {
959-
use crate::CloseFrame;
959+
960+
use crate::{
961+
CloseFrame,
962+
error::{Error, WriteError},
963+
};
960964

961965
use super::*;
962966

@@ -1048,6 +1052,17 @@ mod auto {
10481052

10491053
// Ensure the connection is closed
10501054
assert!(next!(websocketz).is_none());
1055+
1056+
// Attempt to send another message after close should fail
1057+
match websocketz.send(Message::Text("test")).await {
1058+
Ok(_) => panic!("Expected error after close, but got Ok"),
1059+
Err(error) => {
1060+
assert!(matches!(error, Error::Write(WriteError::ConnectionClosed)));
1061+
}
1062+
}
1063+
1064+
// Reading after eof should always return None
1065+
assert!(next!(websocketz).is_none());
10511066
};
10521067

10531068
let server = async move {
@@ -1064,6 +1079,14 @@ mod auto {
10641079
);
10651080

10661081
while next!(websocketz).is_some() {}
1082+
1083+
// Attempt to send another message after close should fail
1084+
match websocketz.send(Message::Text("test")).await {
1085+
Ok(_) => panic!("Expected error after close, but got Ok"),
1086+
Err(error) => {
1087+
assert!(matches!(error, Error::Write(WriteError::ConnectionClosed)));
1088+
}
1089+
}
10671090
};
10681091

10691092
tokio::join!(server, client);

src/websocket_core.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,14 @@ impl Auto {
6666
#[derive(Debug, Clone, Copy)]
6767
#[doc(hidden)]
6868
pub struct ConnectionState {
69-
/// If the user sends a close frame, we should not send a close frame back.
70-
///
7169
/// Must be set to `true` if the user sends a close frame or the other side sends a close frame.
7270
///
73-
/// If the connection is closed, every read will return `None` and every write will return a [`WriteError::ConnectionClosed`].
71+
/// If the connection is closed, every write will return a [`WriteError::ConnectionClosed`].
7472
pub closed: bool,
7573
/// Auto handling of ping/pong and close frames.
7674
auto: Auto,
7775
}
7876

79-
// TODO: Set ConnectionState.closed to true if the user sends a close frame or the other side sends a close frame.
80-
// TODO: If ConnectionState.closed: Every read will then return (None, means connection closed) and every write will return a write error with ConnectionClosed.
81-
// TODO: And then add the tests for that. If the user closes the connection or the server closed the connection, and then the user tries to read or write a frame
82-
8377
impl ConnectionState {
8478
#[inline]
8579
#[allow(clippy::new_without_default)]
@@ -632,6 +626,7 @@ impl<'buf, RW, Rng> WebSocketCore<'buf, RW, Rng> {
632626
&mut self.framed.core.codec,
633627
&mut self.framed.core.inner,
634628
&mut self.framed.core.state.write,
629+
&mut self.state,
635630
message,
636631
fragment_size,
637632
)

0 commit comments

Comments
 (0)