Skip to content

Commit c6ee5f0

Browse files
committed
Add SecWebsocketExtensions
1 parent 2d9a5c4 commit c6ee5f0

File tree

4 files changed

+381
-1
lines changed

4 files changed

+381
-1
lines changed

src/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ pub use self::referer::Referer;
5656
pub use self::referrer_policy::ReferrerPolicy;
5757
pub use self::retry_after::RetryAfter;
5858
pub use self::sec_websocket_accept::SecWebsocketAccept;
59+
pub use self::sec_websocket_extensions::{SecWebsocketExtensions, WebsocketExtension};
5960
pub use self::sec_websocket_key::SecWebsocketKey;
6061
pub use self::sec_websocket_version::SecWebsocketVersion;
6162
pub use self::server::Server;
@@ -175,6 +176,7 @@ mod referer;
175176
mod referrer_policy;
176177
mod retry_after;
177178
mod sec_websocket_accept;
179+
mod sec_websocket_extensions;
178180
mod sec_websocket_key;
179181
mod sec_websocket_version;
180182
mod server;
Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
use std::convert::TryFrom;
2+
3+
use bytes::BytesMut;
4+
use http::header::SEC_WEBSOCKET_EXTENSIONS;
5+
6+
use util::{Comma, FlatCsv, HeaderValueString, SemiColon};
7+
use {Error, Header, HeaderValue};
8+
9+
/// `Sec-WebSocket-Extensions` header, defined in [RFC6455][RFC6455_11.3.2]
10+
///
11+
/// The `Sec-WebSocket-Extensions` header field is used in the WebSocket
12+
/// opening handshake. It is initially sent from the client to the
13+
/// server, and then subsequently sent from the server to the client, to
14+
/// agree on a set of protocol-level extensions to use for the duration
15+
/// of the connection.
16+
///
17+
/// ## ABNF
18+
///
19+
/// ```text
20+
/// Sec-WebSocket-Extensions = extension-list
21+
/// extension-list = 1#extension
22+
/// extension = extension-token *( ";" extension-param )
23+
/// extension-token = registered-token
24+
/// registered-token = token
25+
/// extension-param = token [ "=" (token | quoted-string) ]
26+
/// ```
27+
///
28+
/// ## Example Values
29+
///
30+
/// * `permessage-deflate` (defined in [RFC7692][RFC7692_7])
31+
/// * `permessage-deflate; server_max_window_bits=10`
32+
/// * `permessage-deflate; server_max_window_bits=10, permessage-deflate`
33+
///
34+
/// ## Example
35+
///
36+
/// ```rust
37+
/// # extern crate headers;
38+
/// use headers::SecWebsocketExtensions;
39+
///
40+
/// let extensions = SecWebsocketExtensions::from_static("permessage-deflate");
41+
/// ```
42+
///
43+
/// ## Splitting and Combining
44+
///
45+
/// Note that `Sec-WebSocket-Extensions` may be split or combined across multiple headers.
46+
/// The following are equivalent:
47+
/// ```text
48+
/// Sec-WebSocket-Extensions: foo
49+
/// Sec-WebSocket-Extensions: bar; baz=2
50+
/// ```
51+
/// ```text
52+
/// Sec-WebSocket-Extensions: foo, bar; baz=2
53+
/// ```
54+
///
55+
/// `SecWebsocketExtensions` splits extensions when decoding and combines them into a single
56+
/// value when encoding.
57+
///
58+
/// [RFC6455_11.3.2]: https://tools.ietf.org/html/rfc6455#section-11.3.2
59+
/// [RFC7692_7]: https://tools.ietf.org/html/rfc7692#section-7
60+
#[derive(Clone, Debug, Eq, PartialEq)]
61+
pub struct SecWebsocketExtensions(pub Vec<WebsocketExtension>);
62+
63+
impl Header for SecWebsocketExtensions {
64+
fn name() -> &'static ::HeaderName {
65+
&SEC_WEBSOCKET_EXTENSIONS
66+
}
67+
68+
fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> {
69+
let extensions = values
70+
.cloned()
71+
.flat_map(|v| {
72+
FlatCsv::<Comma>::from(v)
73+
.iter()
74+
.map(WebsocketExtension::try_from)
75+
.collect::<Vec<_>>()
76+
})
77+
.collect::<Result<Vec<_>, _>>()?;
78+
if extensions.is_empty() {
79+
Err(Error::invalid())
80+
} else {
81+
Ok(SecWebsocketExtensions(extensions))
82+
}
83+
}
84+
85+
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
86+
if !self.is_empty() {
87+
values.extend(std::iter::once(self.to_value()));
88+
}
89+
}
90+
}
91+
92+
impl SecWebsocketExtensions {
93+
/// Construct a `SecWebSocketExtensions` from a static string.
94+
///
95+
/// ## Panic
96+
///
97+
/// Panics if the static string is not a valid extensions valie.
98+
pub fn from_static(s: &'static str) -> Self {
99+
let value = HeaderValue::from_static(s);
100+
Self::try_from(&value).expect("valid static string")
101+
}
102+
103+
/// Convert this `SecWebsocketExtensions` to a single `HeaderValue`.
104+
pub fn to_value(&self) -> HeaderValue {
105+
let values = self.0.iter().map(HeaderValue::from).collect::<FlatCsv>();
106+
HeaderValue::from(&values)
107+
}
108+
109+
/// An iterator over the `WebsocketExtension`s in `SecWebsocketExtensions` header(s).
110+
pub fn iter(&self) -> impl Iterator<Item = &WebsocketExtension> {
111+
self.0.iter()
112+
}
113+
114+
/// Get the number of extensions.
115+
pub fn len(&self) -> usize {
116+
self.0.len()
117+
}
118+
119+
/// Returns `true` if headers contain no extensions.
120+
pub fn is_empty(&self) -> bool {
121+
self.0.is_empty()
122+
}
123+
}
124+
125+
impl TryFrom<&str> for SecWebsocketExtensions {
126+
type Error = Error;
127+
128+
fn try_from(value: &str) -> Result<Self, Self::Error> {
129+
let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?;
130+
Self::try_from(&value)
131+
}
132+
}
133+
134+
impl TryFrom<&HeaderValue> for SecWebsocketExtensions {
135+
type Error = Error;
136+
137+
fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
138+
let mut values = std::iter::once(value);
139+
SecWebsocketExtensions::decode(&mut values)
140+
}
141+
}
142+
143+
/// A WebSocket extension containing the name and parameters.
144+
#[derive(Clone, Debug, Eq, PartialEq)]
145+
pub struct WebsocketExtension {
146+
name: HeaderValueString,
147+
params: Vec<(HeaderValueString, Option<HeaderValueString>)>,
148+
}
149+
150+
impl WebsocketExtension {
151+
/// Construct a `WebSocketExtension` from a static string.
152+
///
153+
/// ## Panics
154+
///
155+
/// This function panics if the argument is invalid.
156+
pub fn from_static(src: &'static str) -> Self {
157+
WebsocketExtension::try_from(HeaderValue::from_static(src)).expect("valid static value")
158+
}
159+
160+
/// Get the name of the extension.
161+
pub fn name(&self) -> &str {
162+
self.name.as_str()
163+
}
164+
165+
/// An iterator over the parameters of this extension.
166+
pub fn params(&self) -> impl Iterator<Item = (&str, Option<&str>)> {
167+
self.params
168+
.iter()
169+
.map(|(k, v)| (k.as_str(), v.as_ref().map(|v| v.as_str())))
170+
}
171+
}
172+
173+
impl TryFrom<&str> for WebsocketExtension {
174+
type Error = Error;
175+
176+
fn try_from(value: &str) -> Result<Self, Self::Error> {
177+
if value.is_empty() {
178+
Err(Error::invalid())
179+
} else {
180+
let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?;
181+
WebsocketExtension::try_from(value)
182+
}
183+
}
184+
}
185+
186+
impl TryFrom<HeaderValue> for WebsocketExtension {
187+
type Error = Error;
188+
189+
fn try_from(value: HeaderValue) -> Result<Self, Self::Error> {
190+
let csv = FlatCsv::<Comma>::from(value);
191+
// More than one extension was found
192+
if csv.iter().count() > 1 {
193+
return Err(Error::invalid());
194+
}
195+
196+
let params = FlatCsv::<SemiColon>::from(csv.value);
197+
let mut params_iter = params.iter();
198+
let name = params_iter
199+
.next()
200+
.ok_or_else(Error::invalid)
201+
.and_then(HeaderValueString::from_str)?;
202+
let params = params_iter
203+
.map(|p| {
204+
let mut kv = p.splitn(2, '=');
205+
let key = kv
206+
.next()
207+
.ok_or_else(Error::invalid)
208+
.and_then(HeaderValueString::from_str)?;
209+
let val = kv
210+
.next()
211+
.map(|v| HeaderValueString::from_str(v.trim_matches('"')))
212+
.transpose()?;
213+
Ok((key, val))
214+
})
215+
.collect::<Result<Vec<_>, _>>()?;
216+
Ok(WebsocketExtension { name, params })
217+
}
218+
}
219+
220+
impl From<&WebsocketExtension> for HeaderValue {
221+
fn from(extension: &WebsocketExtension) -> Self {
222+
let mut buf = BytesMut::from(extension.name.as_str().as_bytes());
223+
for (key, val) in &extension.params {
224+
buf.extend_from_slice(b"; ");
225+
buf.extend_from_slice(key.as_str().as_bytes());
226+
if let Some(val) = val {
227+
buf.extend_from_slice(b"=");
228+
buf.extend_from_slice(val.as_str().as_bytes());
229+
}
230+
}
231+
232+
HeaderValue::from_maybe_shared(buf.freeze())
233+
.expect("semicolon separated HeaderValueStrings are valid")
234+
}
235+
}
236+
237+
#[cfg(test)]
238+
mod tests {
239+
use super::super::{test_decode, test_encode};
240+
use super::*;
241+
242+
#[test]
243+
fn extensions_decode() {
244+
let extensions =
245+
test_decode::<SecWebsocketExtensions>(&["key1; val1", "key2; val2"]).unwrap();
246+
assert_eq!(extensions.0.len(), 2);
247+
assert_eq!(
248+
extensions.0[0],
249+
WebsocketExtension::try_from("key1; val1").unwrap()
250+
);
251+
assert_eq!(
252+
extensions.0[1],
253+
WebsocketExtension::try_from("key2; val2").unwrap()
254+
);
255+
256+
assert_eq!(test_decode::<SecWebsocketExtensions>(&[""]), None);
257+
}
258+
259+
#[test]
260+
fn extensions_decode_split() {
261+
// Split each extension into separate headers
262+
let extensions =
263+
test_decode::<SecWebsocketExtensions>(&["key1; val1, key2; val2", "key3; val3"])
264+
.unwrap();
265+
assert_eq!(extensions.0.len(), 3);
266+
assert_eq!(
267+
extensions.0[0],
268+
WebsocketExtension::try_from("key1; val1").unwrap()
269+
);
270+
assert_eq!(
271+
extensions.0[1],
272+
WebsocketExtension::try_from("key2; val2").unwrap()
273+
);
274+
assert_eq!(
275+
extensions.0[2],
276+
WebsocketExtension::try_from("key3; val3").unwrap()
277+
);
278+
}
279+
280+
#[test]
281+
fn extensions_encode() {
282+
let extensions =
283+
SecWebsocketExtensions(vec![WebsocketExtension::from_static("foo; bar; baz=1")]);
284+
let headers = test_encode(extensions);
285+
let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter();
286+
assert_eq!(vals.next().unwrap(), "foo; bar; baz=1");
287+
assert_eq!(vals.next(), None);
288+
289+
let extensions = SecWebsocketExtensions(vec![]);
290+
let headers = test_encode(extensions);
291+
let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter();
292+
assert_eq!(vals.next(), None);
293+
}
294+
295+
#[test]
296+
fn extensions_encode_combine() {
297+
// Multiple extensions are combined into a single header
298+
let extensions = SecWebsocketExtensions(vec![
299+
WebsocketExtension::from_static("foo1; bar"),
300+
WebsocketExtension::from_static("foo2; bar"),
301+
WebsocketExtension::from_static("baz; quux"),
302+
]);
303+
let headers = test_encode(extensions);
304+
let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter();
305+
assert_eq!(vals.next().unwrap(), "foo1; bar, foo2; bar, baz; quux");
306+
assert_eq!(vals.next(), None);
307+
}
308+
309+
#[test]
310+
fn extensions_iter() {
311+
let extensions = SecWebsocketExtensions(vec![
312+
WebsocketExtension::from_static("foo; bar1; bar2=3"),
313+
WebsocketExtension::from_static("baz; quux"),
314+
]);
315+
assert_eq!(extensions.len(), 2);
316+
317+
let mut iter = extensions.iter();
318+
let extension = iter.next().unwrap();
319+
assert_eq!(extension.name(), "foo");
320+
let mut params = extension.params();
321+
assert_eq!(params.next(), Some(("bar1", None)));
322+
assert_eq!(params.next(), Some(("bar2", Some("3"))));
323+
assert!(params.next().is_none());
324+
325+
let extension = iter.next().unwrap();
326+
assert_eq!(extension.name(), "baz");
327+
let mut params = extension.params();
328+
assert_eq!(params.next(), Some(("quux", None)));
329+
assert!(params.next().is_none());
330+
331+
assert!(iter.next().is_none());
332+
}
333+
334+
#[test]
335+
fn extension_try_from_str_ok() {
336+
let ext = WebsocketExtension::try_from("permessage-deflate").unwrap();
337+
assert_eq!(ext.name(), "permessage-deflate");
338+
let mut params = ext.params();
339+
assert_eq!(params.next(), None);
340+
341+
let ext =
342+
WebsocketExtension::try_from("permessage-deflate; client_max_window_bits").unwrap();
343+
assert_eq!(ext.name(), "permessage-deflate");
344+
let mut params = ext.params();
345+
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
346+
assert_eq!(params.next(), None);
347+
348+
let ext =
349+
WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=10").unwrap();
350+
assert_eq!(ext.name(), "permessage-deflate");
351+
let mut params = ext.params();
352+
assert_eq!(params.next(), Some(("server_max_window_bits", Some("10"))));
353+
assert_eq!(params.next(), None);
354+
355+
let ext = WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=\"10\"")
356+
.unwrap();
357+
assert_eq!(ext.name(), "permessage-deflate");
358+
let mut params = ext.params();
359+
assert_eq!(params.next(), Some(("server_max_window_bits", Some("10"))));
360+
assert_eq!(params.next(), None);
361+
}
362+
363+
#[test]
364+
fn extension_try_from_str_err() {
365+
assert!(WebsocketExtension::try_from("").is_err());
366+
// Only single extension is allowed
367+
assert!(WebsocketExtension::try_from("permessage-deflate, permessage-snappy").is_err());
368+
}
369+
}

src/util/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use HeaderValue;
33
//pub use self::charset::Charset;
44
//pub use self::encoding::Encoding;
55
pub(crate) use self::entity::{EntityTag, EntityTagRange};
6-
pub(crate) use self::flat_csv::{FlatCsv, SemiColon};
6+
pub(crate) use self::flat_csv::{Comma, FlatCsv, SemiColon};
77
pub(crate) use self::fmt::fmt;
88
pub(crate) use self::http_date::HttpDate;
99
pub(crate) use self::iter::IterExt;

0 commit comments

Comments
 (0)