Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions engine/src/workers/stream/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ async fn ws_handler(
) -> impl IntoResponse {
let module = module.clone();

let has_authorization_protocol = headers
.get_all("sec-websocket-protocol")
.iter()
.filter_map(|v| v.to_str().ok())
.flat_map(|s| s.split(','))
.any(|p| p.trim() == "Authorization");

let ws = if has_authorization_protocol {
ws.protocols(["Authorization"])
} else {
ws
};

if let Some(auth_function) = module.auth_function.clone() {
let engine = module.engine.clone();
let input = StreamAuthInput {
Expand Down Expand Up @@ -1993,4 +2006,135 @@ mod tests {
assert!(message.contains(&format!("127.0.0.1:{port}")));
assert!(message.contains("already in use"));
}

// ── Sec-WebSocket-Protocol echo tests ───────────────────────────────────

fn build_ws_test_router() -> axum::Router {
use axum::extract::connect_info::MockConnectInfo;
crate::workers::observability::metrics::ensure_default_meter();
let engine = Arc::new(Engine::new());
let adapter: Arc<dyn StreamAdapter> = Arc::new(
crate::workers::stream::adapters::kv_store::BuiltinKvStoreAdapter::new(None),
);
let config = StreamModuleConfig {
port: 0,
host: "127.0.0.1".to_string(),
auth_function: None,
adapter: Some(crate::workers::traits::AdapterEntry {
name: "kv".to_string(),
config: None,
}),
};
let worker = Arc::new(StreamWorker::build(engine.clone(), config, adapter.clone()));
let mgr = Arc::new(StreamSocketManager::new(
engine,
adapter,
worker.clone(),
None,
worker.triggers.clone(),
));
Router::new()
.route("/", get(ws_handler))
.with_state(mgr)
.layer(MockConnectInfo(std::net::SocketAddr::from(([127, 0, 0, 1], 0))))
}

fn ws_upgrade_request(protocol_headers: &[&str]) -> axum::http::Request<axum::body::Body> {
let mut builder = axum::http::Request::builder()
.uri("/")
.header("host", "localhost")
.header("connection", "Upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("sec-websocket-version", "13");
for proto in protocol_headers {
builder = builder.header("sec-websocket-protocol", *proto);
}
builder.body(axum::body::Body::empty()).unwrap()
}

#[tokio::test]
async fn test_ws_authorization_protocol_single_header_is_echoed() {
use axum::http::StatusCode;
use tower::ServiceExt;

let resp = build_ws_test_router()
.oneshot(ws_upgrade_request(&["Authorization"]))
.await
.unwrap();

assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
assert_eq!(
resp.headers()
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok()),
Some("Authorization"),
);
}

#[tokio::test]
async fn test_ws_non_authorization_protocol_not_echoed() {
use axum::http::StatusCode;
use tower::ServiceExt;

let resp = build_ws_test_router()
.oneshot(ws_upgrade_request(&["graphql-ws"]))
.await
.unwrap();

assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
assert!(
resp.headers().get("sec-websocket-protocol").is_none(),
"Sec-WebSocket-Protocol must not be echoed for non-Authorization protocols",
);
}

#[tokio::test]
async fn test_ws_no_protocol_header_not_echoed() {
use axum::http::StatusCode;
use tower::ServiceExt;

let resp = build_ws_test_router()
.oneshot(ws_upgrade_request(&[]))
.await
.unwrap();

assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
assert!(resp.headers().get("sec-websocket-protocol").is_none());
}

#[tokio::test]
async fn test_ws_authorization_protocol_comma_separated_last_is_echoed() {
use axum::http::StatusCode;
use tower::ServiceExt;

let resp = build_ws_test_router()
.oneshot(ws_upgrade_request(&["graphql-ws, Authorization"]))
.await
.unwrap();

assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
assert_eq!(
resp.headers()
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok()),
Some("Authorization"),
);
}

// axum reads only the first Sec-WebSocket-Protocol header for protocol selection, so the
// 101 response won't echo Authorization when it's in the second header — but the connection
// is established, which is what matters (the browser won't close the socket).
#[tokio::test]
async fn test_ws_authorization_protocol_in_second_header_connection_succeeds() {
use axum::http::StatusCode;
use tower::ServiceExt;

let resp = build_ws_test_router()
.oneshot(ws_upgrade_request(&["graphql-ws", "Authorization"]))
.await
.unwrap();

assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
}
Comment on lines +2132 to +2140
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cd engine && find . -name "stream.rs" -type f | head -5

Repository: iii-hq/iii

Length of output: 84


🏁 Script executed:

# Get line count to understand file size
wc -l engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 94


🏁 Script executed:

# Read the test function and surrounding context (lines 2120-2160)
sed -n '2120,2160p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1025


🏁 Script executed:

# Read the detection logic at lines 77-82
sed -n '77,82p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 281


🏁 Script executed:

# Find raw_ws_handshake function definition
rg -n "fn raw_ws_handshake" engine/src/workers/stream/stream.rs -A 15

Repository: iii-hq/iii

Length of output: 764


🏁 Script executed:

# Check what raw_ws_handshake returns
rg -n "raw_ws_handshake" engine/src/workers/stream/stream.rs | head -20

Repository: iii-hq/iii

Length of output: 575


🏁 Script executed:

# Read more context around the detection logic - what function/scope is it in?
sed -n '70,95p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 864


🏁 Script executed:

# Read the full raw_ws_handshake function to understand protocol extraction
sed -n '2052,2090p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1502


🏁 Script executed:

# Search for ws.protocols calls to understand how protocols are set
rg -n "ws\.protocols" engine/src/workers/stream/stream.rs -B 2 -A 2

Repository: iii-hq/iii

Length of output: 178


🏁 Script executed:

# Read the context around where ws.protocols is called
sed -n '80,110p' engine/src/workers/stream/stream.rs

Repository: iii-hq/iii

Length of output: 1325


Add assertion to verify Authorization protocol is echoed in the 101 response.

The detection logic at lines 77–82 correctly finds Authorization in the second header via get_all() and flat_map(), which triggers ws.protocols(["Authorization"]) at line 85. This means the server will echo the protocol in the 101 response, matching the behavior of the comma-separated test (test_ws_authorization_protocol_comma_separated_last_is_echoed).

The inline comment suggesting the echo won't happen is misleading — the code already handles this case. The test should assert the echo to align with the comma-separated variant and lock in the correct behavior:

Suggested assertion
-        let (status, _) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await;
+        let (status, proto) = raw_ws_handshake(port, &["graphql-ws", "Authorization"]).await;
         assert_eq!(status, 101);
+        assert_eq!(proto.as_deref(), Some("Authorization"));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@engine/src/workers/stream/stream.rs` around lines 2132 - 2140, The test
test_ws_authorization_protocol_in_second_header_connection_succeeds currently
only asserts the 101 status but should also assert that the server echoes the
"Authorization" protocol; update the test (which uses raw_ws_handshake and the
ws.protocols selection behavior) to capture the response headers/body from
raw_ws_handshake and add an assertion verifying that the Sec-WebSocket-Protocol
in the 101 response equals "Authorization" (matching the comma-separated variant
and the ws.protocols(["Authorization"]) behavior).

}
Loading