Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions engine/src/workers/stream/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ async fn ws_handler(
) -> impl IntoResponse {
let module = module.clone();

let has_authorization_protocol = headers
.get_all("sec-websocket-protocol")
.iter()
.filter_map(|v| v.to_str().ok())
.flat_map(|s| s.split(','))
.any(|p| p.trim() == "Authorization");

let ws = if has_authorization_protocol {
ws.protocols(["Authorization"])
} else {
ws
};

if let Some(auth_function) = module.auth_function.clone() {
let engine = module.engine.clone();
let input = StreamAuthInput {
Expand Down Expand Up @@ -1993,4 +2006,136 @@ mod tests {
assert!(message.contains(&format!("127.0.0.1:{port}")));
assert!(message.contains("already in use"));
}

// ── Sec-WebSocket-Protocol echo tests ───────────────────────────────────

async fn start_ws_test_server() -> u16 {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();

crate::workers::observability::metrics::ensure_default_meter();
let engine = Arc::new(Engine::new());
let adapter: Arc<dyn StreamAdapter> = Arc::new(
crate::workers::stream::adapters::kv_store::BuiltinKvStoreAdapter::new(None),
);
let config = StreamModuleConfig {
port: 0,
host: "127.0.0.1".to_string(),
auth_function: None,
adapter: Some(crate::workers::traits::AdapterEntry {
name: "kv".to_string(),
config: None,
}),
};
let worker = Arc::new(StreamWorker::build(engine.clone(), config, adapter.clone()));
let mgr = Arc::new(StreamSocketManager::new(
engine,
adapter,
worker.clone(),
None,
worker.triggers.clone(),
));
let app = Router::new()
.route("/", get(ws_handler))
.with_state(mgr)
.into_make_service_with_connect_info::<std::net::SocketAddr>();

tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
tokio::task::yield_now().await;
port
}

// Does a raw HTTP WebSocket upgrade and returns (status_code, Sec-WebSocket-Protocol echo).
// Using raw TCP instead of tokio_tungstenite because the tungstenite client rejects any
// handshake where the server doesn't echo at least one of the offered protocols, which would
// make it impossible to write negative-case tests.
async fn raw_ws_handshake(port: u16, protocols: &[&str]) -> (u16, Option<String>) {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}"))
.await
.unwrap();
let (reader, mut writer) = tokio::io::split(stream);

let proto_lines: String = protocols
.iter()
.map(|p| format!("Sec-WebSocket-Protocol: {p}\r\n"))
.collect();
let request = format!(
"GET / HTTP/1.1\r\n\
Host: localhost\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
{proto_lines}\
\r\n"
);
writer.write_all(request.as_bytes()).await.unwrap();

let mut lines = BufReader::new(reader).lines();
let status_line = lines.next_line().await.unwrap().unwrap_or_default();
let status: u16 = status_line
.split_whitespace()
.nth(1)
.unwrap_or("0")
.parse()
.unwrap_or(0);

let mut echoed_protocol = None;
while let Some(line) = lines.next_line().await.unwrap() {
if line.is_empty() {
break;
}
if line.to_lowercase().starts_with("sec-websocket-protocol:") {
echoed_protocol = line
.splitn(2, ':')
.nth(1)
.map(|v| v.trim().to_string());
}
}
(status, echoed_protocol)
}

#[tokio::test]
async fn test_ws_authorization_protocol_single_header_is_echoed() {
let port = start_ws_test_server().await;
let (status, proto) = raw_ws_handshake(port, &["Authorization"]).await;
assert_eq!(status, 101);
assert_eq!(proto.as_deref(), Some("Authorization"));
}

#[tokio::test]
async fn test_ws_non_authorization_protocol_not_echoed() {
let port = start_ws_test_server().await;
let (status, proto) = raw_ws_handshake(port, &["graphql-ws"]).await;
assert_eq!(status, 101);
assert!(proto.is_none());
}

#[tokio::test]
async fn test_ws_no_protocol_header_not_echoed() {
let port = start_ws_test_server().await;
let (status, proto) = raw_ws_handshake(port, &[]).await;
assert_eq!(status, 101);
assert!(proto.is_none());
}

#[tokio::test]
async fn test_ws_authorization_protocol_comma_separated_last_is_echoed() {
let port = start_ws_test_server().await;
let (status, proto) = raw_ws_handshake(port, &["graphql-ws, Authorization"]).await;
assert_eq!(status, 101);
assert_eq!(proto.as_deref(), Some("Authorization"));
}

// axum reads only the first Sec-WebSocket-Protocol header for protocol selection, so the
// 101 response won't echo Authorization when it's in the second header — but the connection
// is established, which is what matters (the browser won't close the socket).
#[tokio::test]
async fn test_ws_authorization_protocol_in_second_header_connection_succeeds() {
let port = start_ws_test_server().await;
let (status, _) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await;
assert_eq!(status, 101);
}
Comment on lines +2132 to +2140
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cd engine && find . -name "stream.rs" -type f | head -5

Repository: iii-hq/iii

Length of output: 84


🏁 Script executed:

# Get line count to understand file size
wc -l engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 94


🏁 Script executed:

# Read the test function and surrounding context (lines 2120-2160)
sed -n '2120,2160p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1025


🏁 Script executed:

# Read the detection logic at lines 77-82
sed -n '77,82p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 281


🏁 Script executed:

# Find raw_ws_handshake function definition
rg -n "fn raw_ws_handshake" engine/src/workers/stream/stream.rs -A 15

Repository: iii-hq/iii

Length of output: 764


🏁 Script executed:

# Check what raw_ws_handshake returns
rg -n "raw_ws_handshake" engine/src/workers/stream/stream.rs | head -20

Repository: iii-hq/iii

Length of output: 575


🏁 Script executed:

# Read more context around the detection logic - what function/scope is it in?
sed -n '70,95p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 864


🏁 Script executed:

# Read the full raw_ws_handshake function to understand protocol extraction
sed -n '2052,2090p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1502


🏁 Script executed:

# Search for ws.protocols calls to understand how protocols are set
rg -n "ws\.protocols" engine/src/workers/stream/stream.rs -B 2 -A 2

Repository: iii-hq/iii

Length of output: 178


🏁 Script executed:

# Read the context around where ws.protocols is called
sed -n '80,110p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1325


Add assertion to verify Authorization protocol is echoed in the 101 response.

The detection logic at lines 77–82 correctly finds Authorization in the second header via get_all() and flat_map(), which triggers ws.protocols(["Authorization"]) at line 85. This means the server will echo the protocol in the 101 response, matching the behavior of the comma-separated test (test_ws_authorization_protocol_comma_separated_last_is_echoed).

The inline comment suggesting the echo won't happen is misleading — the code already handles this case. The test should assert the echo to align with the comma-separated variant and lock in the correct behavior:

Suggested assertion
-        let (status, _) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await;
+        let (status, proto) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await;
         assert_eq!(status, 101);
+        assert_eq!(proto.as_deref(), Some("Authorization"));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@engine/src/workers/stream/stream.rs` around lines 2132 - 2140, The test
test_ws_authorization_protocol_in_second_header_connection_succeeds currently
only asserts the 101 status but should also assert that the server echoes the
"Authorization" protocol; update the test (which uses raw_ws_handshake and the
ws.protocols selection behavior) to capture the response headers/body from
raw_ws_handshake and add an assertion verifying that the Sec-WebSocket-Protocol
in the 101 response equals "Authorization" (matching the comma-separated variant
and the ws.protocols(["Authorization"]) behavior).

}
Loading