diff --git a/engine/src/workers/stream/stream.rs b/engine/src/workers/stream/stream.rs index 24fe476d92..c1da8b2e9b 100644 --- a/engine/src/workers/stream/stream.rs +++ b/engine/src/workers/stream/stream.rs @@ -74,6 +74,19 @@ async fn ws_handler( ) -> impl IntoResponse { let module = module.clone(); + let has_authorization_protocol = headers + .get_all("sec-websocket-protocol") + .iter() + .filter_map(|v| v.to_str().ok()) + .flat_map(|s| s.split(',')) + .any(|p| p.trim() == "Authorization"); + + let ws = if has_authorization_protocol { + ws.protocols(["Authorization"]) + } else { + ws + }; + if let Some(auth_function) = module.auth_function.clone() { let engine = module.engine.clone(); let input = StreamAuthInput { @@ -1993,4 +2006,136 @@ mod tests { assert!(message.contains(&format!("127.0.0.1:{port}"))); assert!(message.contains("already in use")); } + + // ── Sec-WebSocket-Protocol echo tests ─────────────────────────────────── + + async fn start_ws_test_server() -> u16 { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + crate::workers::observability::metrics::ensure_default_meter(); + let engine = Arc::new(Engine::new()); + let adapter: Arc = Arc::new( + crate::workers::stream::adapters::kv_store::BuiltinKvStoreAdapter::new(None), + ); + let config = StreamModuleConfig { + port: 0, + host: "127.0.0.1".to_string(), + auth_function: None, + adapter: Some(crate::workers::traits::AdapterEntry { + name: "kv".to_string(), + config: None, + }), + }; + let worker = Arc::new(StreamWorker::build(engine.clone(), config, adapter.clone())); + let mgr = Arc::new(StreamSocketManager::new( + engine, + adapter, + worker.clone(), + None, + worker.triggers.clone(), + )); + let app = Router::new() + .route("/", get(ws_handler)) + .with_state(mgr) + .into_make_service_with_connect_info::(); + + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + tokio::task::yield_now().await; + port + } + + // Does a raw HTTP WebSocket upgrade and returns (status_code, Sec-WebSocket-Protocol echo). + // Using raw TCP instead of tokio_tungstenite because the tungstenite client rejects any + // handshake where the server doesn't echo at least one of the offered protocols, which would + // make it impossible to write negative-case tests. + async fn raw_ws_handshake(port: u16, protocols: &[&str]) -> (u16, Option) { + use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")) + .await + .unwrap(); + let (reader, mut writer) = tokio::io::split(stream); + + let proto_lines: String = protocols + .iter() + .map(|p| format!("Sec-WebSocket-Protocol: {p}\r\n")) + .collect(); + let request = format!( + "GET / HTTP/1.1\r\n\ + Host: localhost\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + {proto_lines}\ + \r\n" + ); + writer.write_all(request.as_bytes()).await.unwrap(); + + let mut lines = BufReader::new(reader).lines(); + let status_line = lines.next_line().await.unwrap().unwrap_or_default(); + let status: u16 = status_line + .split_whitespace() + .nth(1) + .unwrap_or("0") + .parse() + .unwrap_or(0); + + let mut echoed_protocol = None; + while let Some(line) = lines.next_line().await.unwrap() { + if line.is_empty() { + break; + } + if line.to_lowercase().starts_with("sec-websocket-protocol:") { + echoed_protocol = line + .splitn(2, ':') + .nth(1) + .map(|v| v.trim().to_string()); + } + } + (status, echoed_protocol) + } + + #[tokio::test] + async fn test_ws_authorization_protocol_single_header_is_echoed() { + let port = start_ws_test_server().await; + let (status, proto) = raw_ws_handshake(port, &["Authorization"]).await; + assert_eq!(status, 101); + assert_eq!(proto.as_deref(), Some("Authorization")); + } + + #[tokio::test] + async fn test_ws_non_authorization_protocol_not_echoed() { + let port = start_ws_test_server().await; + let (status, proto) = raw_ws_handshake(port, &["graphql-ws"]).await; + assert_eq!(status, 101); + assert!(proto.is_none()); + } + + #[tokio::test] + async fn test_ws_no_protocol_header_not_echoed() { + let port = start_ws_test_server().await; + let (status, proto) = raw_ws_handshake(port, &[]).await; + assert_eq!(status, 101); + assert!(proto.is_none()); + } + + #[tokio::test] + async fn test_ws_authorization_protocol_comma_separated_last_is_echoed() { + let port = start_ws_test_server().await; + let (status, proto) = raw_ws_handshake(port, &["graphql-ws, Authorization"]).await; + assert_eq!(status, 101); + assert_eq!(proto.as_deref(), Some("Authorization")); + } + + // axum reads only the first Sec-WebSocket-Protocol header for protocol selection, so the + // 101 response won't echo Authorization when it's in the second header — but the connection + // is established, which is what matters (the browser won't close the socket). + #[tokio::test] + async fn test_ws_authorization_protocol_in_second_header_connection_succeeds() { + let port = start_ws_test_server().await; + let (status, _) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await; + assert_eq!(status, 101); + } }