|
| 1 | +use std::convert::TryFrom; |
| 2 | + |
| 3 | +use bytes::BytesMut; |
| 4 | +use http::header::SEC_WEBSOCKET_EXTENSIONS; |
| 5 | + |
| 6 | +use util::{Comma, FlatCsv, HeaderValueString, SemiColon}; |
| 7 | +use {Error, Header, HeaderValue}; |
| 8 | + |
| 9 | +/// `Sec-WebSocket-Extensions` header, defined in [RFC6455][RFC6455_11.3.2] |
| 10 | +/// |
| 11 | +/// The `Sec-WebSocket-Extensions` header field is used in the WebSocket |
| 12 | +/// opening handshake. It is initially sent from the client to the |
| 13 | +/// server, and then subsequently sent from the server to the client, to |
| 14 | +/// agree on a set of protocol-level extensions to use for the duration |
| 15 | +/// of the connection. |
| 16 | +/// |
| 17 | +/// ## ABNF |
| 18 | +/// |
| 19 | +/// ```text |
| 20 | +/// Sec-WebSocket-Extensions = extension-list |
| 21 | +/// extension-list = 1#extension |
| 22 | +/// extension = extension-token *( ";" extension-param ) |
| 23 | +/// extension-token = registered-token |
| 24 | +/// registered-token = token |
| 25 | +/// extension-param = token [ "=" (token | quoted-string) ] |
| 26 | +/// ``` |
| 27 | +/// |
| 28 | +/// ## Example Values |
| 29 | +/// |
| 30 | +/// * `permessage-deflate` (defined in [RFC7692][RFC7692_7]) |
| 31 | +/// * `permessage-deflate; server_max_window_bits=10` |
| 32 | +/// * `permessage-deflate; server_max_window_bits=10, permessage-deflate` |
| 33 | +/// |
| 34 | +/// ## Example |
| 35 | +/// |
| 36 | +/// ```rust |
| 37 | +/// # extern crate headers; |
| 38 | +/// use headers::SecWebsocketExtensions; |
| 39 | +/// |
| 40 | +/// let extensions = SecWebsocketExtensions::from_static("permessage-deflate"); |
| 41 | +/// ``` |
| 42 | +/// |
| 43 | +/// ## Splitting and Combining |
| 44 | +/// |
| 45 | +/// Note that `Sec-WebSocket-Extensions` may be split or combined across multiple headers. |
| 46 | +/// The following are equivalent: |
| 47 | +/// ```text |
| 48 | +/// Sec-WebSocket-Extensions: foo |
| 49 | +/// Sec-WebSocket-Extensions: bar; baz=2 |
| 50 | +/// ``` |
| 51 | +/// ```text |
| 52 | +/// Sec-WebSocket-Extensions: foo, bar; baz=2 |
| 53 | +/// ``` |
| 54 | +/// |
| 55 | +/// `SecWebsocketExtensions` splits extensions when decoding and combines them into a single |
| 56 | +/// value when encoding. |
| 57 | +/// |
| 58 | +/// [RFC6455_11.3.2]: https://tools.ietf.org/html/rfc6455#section-11.3.2 |
| 59 | +/// [RFC7692_7]: https://tools.ietf.org/html/rfc7692#section-7 |
| 60 | +#[derive(Clone, Debug, Eq, PartialEq)] |
| 61 | +pub struct SecWebsocketExtensions(pub Vec<WebsocketExtension>); |
| 62 | + |
| 63 | +impl Header for SecWebsocketExtensions { |
| 64 | + fn name() -> &'static ::HeaderName { |
| 65 | + &SEC_WEBSOCKET_EXTENSIONS |
| 66 | + } |
| 67 | + |
| 68 | + fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> { |
| 69 | + let extensions = values |
| 70 | + .cloned() |
| 71 | + .flat_map(|v| { |
| 72 | + FlatCsv::<Comma>::from(v) |
| 73 | + .iter() |
| 74 | + .map(WebsocketExtension::try_from) |
| 75 | + .collect::<Vec<_>>() |
| 76 | + }) |
| 77 | + .collect::<Result<Vec<_>, _>>()?; |
| 78 | + if extensions.is_empty() { |
| 79 | + Err(Error::invalid()) |
| 80 | + } else { |
| 81 | + Ok(SecWebsocketExtensions(extensions)) |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) { |
| 86 | + if !self.is_empty() { |
| 87 | + values.extend(std::iter::once(self.to_value())); |
| 88 | + } |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | +impl SecWebsocketExtensions { |
| 93 | + /// Construct a `SecWebSocketExtensions` from a static string. |
| 94 | + /// |
| 95 | + /// ## Panic |
| 96 | + /// |
| 97 | + /// Panics if the static string is not a valid extensions valie. |
| 98 | + pub fn from_static(s: &'static str) -> Self { |
| 99 | + let value = HeaderValue::from_static(s); |
| 100 | + Self::try_from(&value).expect("valid static string") |
| 101 | + } |
| 102 | + |
| 103 | + /// Convert this `SecWebsocketExtensions` to a single `HeaderValue`. |
| 104 | + pub fn to_value(&self) -> HeaderValue { |
| 105 | + let values = self.0.iter().map(HeaderValue::from).collect::<FlatCsv>(); |
| 106 | + HeaderValue::from(&values) |
| 107 | + } |
| 108 | + |
| 109 | + /// An iterator over the `WebsocketExtension`s in `SecWebsocketExtensions` header(s). |
| 110 | + pub fn iter(&self) -> impl Iterator<Item = &WebsocketExtension> { |
| 111 | + self.0.iter() |
| 112 | + } |
| 113 | + |
| 114 | + /// Get the number of extensions. |
| 115 | + pub fn len(&self) -> usize { |
| 116 | + self.0.len() |
| 117 | + } |
| 118 | + |
| 119 | + /// Returns `true` if headers contain no extensions. |
| 120 | + pub fn is_empty(&self) -> bool { |
| 121 | + self.0.is_empty() |
| 122 | + } |
| 123 | +} |
| 124 | + |
| 125 | +impl TryFrom<&str> for SecWebsocketExtensions { |
| 126 | + type Error = Error; |
| 127 | + |
| 128 | + fn try_from(value: &str) -> Result<Self, Self::Error> { |
| 129 | + let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?; |
| 130 | + Self::try_from(&value) |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +impl TryFrom<&HeaderValue> for SecWebsocketExtensions { |
| 135 | + type Error = Error; |
| 136 | + |
| 137 | + fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> { |
| 138 | + let mut values = std::iter::once(value); |
| 139 | + SecWebsocketExtensions::decode(&mut values) |
| 140 | + } |
| 141 | +} |
| 142 | + |
| 143 | +/// A WebSocket extension containing the name and parameters. |
| 144 | +#[derive(Clone, Debug, Eq, PartialEq)] |
| 145 | +pub struct WebsocketExtension { |
| 146 | + name: HeaderValueString, |
| 147 | + params: Vec<(HeaderValueString, Option<HeaderValueString>)>, |
| 148 | +} |
| 149 | + |
| 150 | +impl WebsocketExtension { |
| 151 | + /// Construct a `WebSocketExtension` from a static string. |
| 152 | + /// |
| 153 | + /// ## Panics |
| 154 | + /// |
| 155 | + /// This function panics if the argument is invalid. |
| 156 | + pub fn from_static(src: &'static str) -> Self { |
| 157 | + WebsocketExtension::try_from(HeaderValue::from_static(src)).expect("valid static value") |
| 158 | + } |
| 159 | + |
| 160 | + /// Get the name of the extension. |
| 161 | + pub fn name(&self) -> &str { |
| 162 | + self.name.as_str() |
| 163 | + } |
| 164 | + |
| 165 | + /// An iterator over the parameters of this extension. |
| 166 | + pub fn params(&self) -> impl Iterator<Item = (&str, Option<&str>)> { |
| 167 | + self.params |
| 168 | + .iter() |
| 169 | + .map(|(k, v)| (k.as_str(), v.as_ref().map(|v| v.as_str()))) |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +impl TryFrom<&str> for WebsocketExtension { |
| 174 | + type Error = Error; |
| 175 | + |
| 176 | + fn try_from(value: &str) -> Result<Self, Self::Error> { |
| 177 | + if value.is_empty() { |
| 178 | + Err(Error::invalid()) |
| 179 | + } else { |
| 180 | + let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?; |
| 181 | + WebsocketExtension::try_from(value) |
| 182 | + } |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +impl TryFrom<HeaderValue> for WebsocketExtension { |
| 187 | + type Error = Error; |
| 188 | + |
| 189 | + fn try_from(value: HeaderValue) -> Result<Self, Self::Error> { |
| 190 | + let csv = FlatCsv::<Comma>::from(value); |
| 191 | + // More than one extension was found |
| 192 | + if csv.iter().count() > 1 { |
| 193 | + return Err(Error::invalid()); |
| 194 | + } |
| 195 | + |
| 196 | + let params = FlatCsv::<SemiColon>::from(csv.value); |
| 197 | + let mut params_iter = params.iter(); |
| 198 | + let name = params_iter |
| 199 | + .next() |
| 200 | + .ok_or_else(Error::invalid) |
| 201 | + .and_then(HeaderValueString::from_str)?; |
| 202 | + let params = params_iter |
| 203 | + .map(|p| { |
| 204 | + let mut kv = p.splitn(2, '='); |
| 205 | + let key = kv |
| 206 | + .next() |
| 207 | + .ok_or_else(Error::invalid) |
| 208 | + .and_then(HeaderValueString::from_str)?; |
| 209 | + let val = kv |
| 210 | + .next() |
| 211 | + .map(|v| HeaderValueString::from_str(v.trim_matches('"'))) |
| 212 | + .transpose()?; |
| 213 | + Ok((key, val)) |
| 214 | + }) |
| 215 | + .collect::<Result<Vec<_>, _>>()?; |
| 216 | + Ok(WebsocketExtension { name, params }) |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +impl From<&WebsocketExtension> for HeaderValue { |
| 221 | + fn from(extension: &WebsocketExtension) -> Self { |
| 222 | + let mut buf = BytesMut::from(extension.name.as_str().as_bytes()); |
| 223 | + for (key, val) in &extension.params { |
| 224 | + buf.extend_from_slice(b"; "); |
| 225 | + buf.extend_from_slice(key.as_str().as_bytes()); |
| 226 | + if let Some(val) = val { |
| 227 | + buf.extend_from_slice(b"="); |
| 228 | + buf.extend_from_slice(val.as_str().as_bytes()); |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + HeaderValue::from_maybe_shared(buf.freeze()) |
| 233 | + .expect("semicolon separated HeaderValueStrings are valid") |
| 234 | + } |
| 235 | +} |
| 236 | + |
| 237 | +#[cfg(test)] |
| 238 | +mod tests { |
| 239 | + use super::super::{test_decode, test_encode}; |
| 240 | + use super::*; |
| 241 | + |
| 242 | + #[test] |
| 243 | + fn extensions_decode() { |
| 244 | + let extensions = |
| 245 | + test_decode::<SecWebsocketExtensions>(&["key1; val1", "key2; val2"]).unwrap(); |
| 246 | + assert_eq!(extensions.0.len(), 2); |
| 247 | + assert_eq!( |
| 248 | + extensions.0[0], |
| 249 | + WebsocketExtension::try_from("key1; val1").unwrap() |
| 250 | + ); |
| 251 | + assert_eq!( |
| 252 | + extensions.0[1], |
| 253 | + WebsocketExtension::try_from("key2; val2").unwrap() |
| 254 | + ); |
| 255 | + |
| 256 | + assert_eq!(test_decode::<SecWebsocketExtensions>(&[""]), None); |
| 257 | + } |
| 258 | + |
| 259 | + #[test] |
| 260 | + fn extensions_decode_split() { |
| 261 | + // Split each extension into separate headers |
| 262 | + let extensions = |
| 263 | + test_decode::<SecWebsocketExtensions>(&["key1; val1, key2; val2", "key3; val3"]) |
| 264 | + .unwrap(); |
| 265 | + assert_eq!(extensions.0.len(), 3); |
| 266 | + assert_eq!( |
| 267 | + extensions.0[0], |
| 268 | + WebsocketExtension::try_from("key1; val1").unwrap() |
| 269 | + ); |
| 270 | + assert_eq!( |
| 271 | + extensions.0[1], |
| 272 | + WebsocketExtension::try_from("key2; val2").unwrap() |
| 273 | + ); |
| 274 | + assert_eq!( |
| 275 | + extensions.0[2], |
| 276 | + WebsocketExtension::try_from("key3; val3").unwrap() |
| 277 | + ); |
| 278 | + } |
| 279 | + |
| 280 | + #[test] |
| 281 | + fn extensions_encode() { |
| 282 | + let extensions = |
| 283 | + SecWebsocketExtensions(vec![WebsocketExtension::from_static("foo; bar; baz=1")]); |
| 284 | + let headers = test_encode(extensions); |
| 285 | + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); |
| 286 | + assert_eq!(vals.next().unwrap(), "foo; bar; baz=1"); |
| 287 | + assert_eq!(vals.next(), None); |
| 288 | + |
| 289 | + let extensions = SecWebsocketExtensions(vec![]); |
| 290 | + let headers = test_encode(extensions); |
| 291 | + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); |
| 292 | + assert_eq!(vals.next(), None); |
| 293 | + } |
| 294 | + |
| 295 | + #[test] |
| 296 | + fn extensions_encode_combine() { |
| 297 | + // Multiple extensions are combined into a single header |
| 298 | + let extensions = SecWebsocketExtensions(vec![ |
| 299 | + WebsocketExtension::from_static("foo1; bar"), |
| 300 | + WebsocketExtension::from_static("foo2; bar"), |
| 301 | + WebsocketExtension::from_static("baz; quux"), |
| 302 | + ]); |
| 303 | + let headers = test_encode(extensions); |
| 304 | + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); |
| 305 | + assert_eq!(vals.next().unwrap(), "foo1; bar, foo2; bar, baz; quux"); |
| 306 | + assert_eq!(vals.next(), None); |
| 307 | + } |
| 308 | + |
| 309 | + #[test] |
| 310 | + fn extensions_iter() { |
| 311 | + let extensions = SecWebsocketExtensions(vec![ |
| 312 | + WebsocketExtension::from_static("foo; bar1; bar2=3"), |
| 313 | + WebsocketExtension::from_static("baz; quux"), |
| 314 | + ]); |
| 315 | + assert_eq!(extensions.len(), 2); |
| 316 | + |
| 317 | + let mut iter = extensions.iter(); |
| 318 | + let extension = iter.next().unwrap(); |
| 319 | + assert_eq!(extension.name(), "foo"); |
| 320 | + let mut params = extension.params(); |
| 321 | + assert_eq!(params.next(), Some(("bar1", None))); |
| 322 | + assert_eq!(params.next(), Some(("bar2", Some("3")))); |
| 323 | + assert!(params.next().is_none()); |
| 324 | + |
| 325 | + let extension = iter.next().unwrap(); |
| 326 | + assert_eq!(extension.name(), "baz"); |
| 327 | + let mut params = extension.params(); |
| 328 | + assert_eq!(params.next(), Some(("quux", None))); |
| 329 | + assert!(params.next().is_none()); |
| 330 | + |
| 331 | + assert!(iter.next().is_none()); |
| 332 | + } |
| 333 | + |
| 334 | + #[test] |
| 335 | + fn extension_try_from_str_ok() { |
| 336 | + let ext = WebsocketExtension::try_from("permessage-deflate").unwrap(); |
| 337 | + assert_eq!(ext.name(), "permessage-deflate"); |
| 338 | + let mut params = ext.params(); |
| 339 | + assert_eq!(params.next(), None); |
| 340 | + |
| 341 | + let ext = |
| 342 | + WebsocketExtension::try_from("permessage-deflate; client_max_window_bits").unwrap(); |
| 343 | + assert_eq!(ext.name(), "permessage-deflate"); |
| 344 | + let mut params = ext.params(); |
| 345 | + assert_eq!(params.next(), Some(("client_max_window_bits", None))); |
| 346 | + assert_eq!(params.next(), None); |
| 347 | + |
| 348 | + let ext = |
| 349 | + WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=10").unwrap(); |
| 350 | + assert_eq!(ext.name(), "permessage-deflate"); |
| 351 | + let mut params = ext.params(); |
| 352 | + assert_eq!(params.next(), Some(("server_max_window_bits", Some("10")))); |
| 353 | + assert_eq!(params.next(), None); |
| 354 | + |
| 355 | + let ext = WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=\"10\"") |
| 356 | + .unwrap(); |
| 357 | + assert_eq!(ext.name(), "permessage-deflate"); |
| 358 | + let mut params = ext.params(); |
| 359 | + assert_eq!(params.next(), Some(("server_max_window_bits", Some("10")))); |
| 360 | + assert_eq!(params.next(), None); |
| 361 | + } |
| 362 | + |
| 363 | + #[test] |
| 364 | + fn extension_try_from_str_err() { |
| 365 | + assert!(WebsocketExtension::try_from("").is_err()); |
| 366 | + // Only single extension is allowed |
| 367 | + assert!(WebsocketExtension::try_from("permessage-deflate, permessage-snappy").is_err()); |
| 368 | + } |
| 369 | +} |
0 commit comments