diff --git a/src/websocket.rs b/src/websocket.rs index 4cf9281..689ee86 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -14,11 +14,19 @@ use rust_decimal::Decimal; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::json; use tokio::{ + io::AsyncWriteExt, net::{TcpListener, TcpStream}, sync::Mutex, task::JoinHandle, }; -use tokio_tungstenite::{accept_async, tungstenite::Message}; +use tokio_tungstenite::{ + accept_hdr_async, + tungstenite::{ + handshake::server::{write_response, ErrorResponse, Request, Response}, + http::{HeaderValue, Method as HttpMethod, Response as HttpResponse, StatusCode}, + Message, + }, +}; use crate::{ types::{get_market_decimals, Market, PRICE_DECIMALS}, @@ -49,14 +57,156 @@ pub async fn start_ws_server( }); } +fn build_ws_handshake_error(status: StatusCode, reason: String) -> ErrorResponse { + let mut builder = HttpResponse::builder() + .status(status) + .header("content-type", "text/plain; charset=utf-8"); + if let Some(headers) = builder.headers_mut() { + headers.insert("connection", HeaderValue::from_static("close")); + } + builder + .body(Some(reason)) + .unwrap_or_else(|_| ErrorResponse::new(Some("websocket handshake failed".to_string()))) +} + +fn handle_ws_handshake( + request: &Request, + mut response: Response, +) -> Result { + if request.method() != HttpMethod::GET { + return Err(build_ws_handshake_error( + StatusCode::METHOD_NOT_ALLOWED, + "websocket endpoint requires GET".to_string(), + )); + } + + let path = request.uri().path(); + if path != "/" && path != "/ws" { + return Err(build_ws_handshake_error( + StatusCode::NOT_FOUND, + format!("unsupported websocket path: {path}"), + )); + } + + response + .headers_mut() + .insert("server", HeaderValue::from_static("gateway")); + Ok(response) +} + +fn serialize_http_error_response(status: StatusCode, reason: &str) -> Option> { + let response = build_ws_handshake_error(status, reason.to_string()); + let mut output = Vec::new(); + if write_response(&mut output, &response).is_err() { + return None; + } + if let Some(body) = response.body() { + output.extend_from_slice(body.as_bytes()); + } + Some(output) +} + +async fn maybe_write_http_handshake_error(stream: &mut TcpStream) -> bool { + let mut buf = [0_u8; 4096]; + let read = match stream.peek(&mut buf).await { + Ok(n) => n, + Err(_) => return false, + }; + if read == 0 { + return true; + } + + let request = String::from_utf8_lossy(&buf[..read]); + // Only attempt HTTP error responses when we have a full header. + if !request.contains("\r\n\r\n") { + return false; + } + + let mut lines = request.split("\r\n"); + let request_line = match lines.next() { + Some(line) => line, + None => return false, + }; + + let mut parts = request_line.split_whitespace(); + let method = match parts.next() { + Some(m) => m, + None => return false, + }; + let path = match parts.next() { + Some(p) => p, + None => return false, + }; + + // Ignore non-HTTP traffic. + if !request_line.contains("HTTP/") { + return false; + } + + let lower = request.to_ascii_lowercase(); + let status_reason = if method != "GET" { + Some(( + StatusCode::METHOD_NOT_ALLOWED, + "websocket endpoint requires GET", + )) + } else if path != "/" && path != "/ws" { + Some((StatusCode::NOT_FOUND, "unsupported websocket path")) + } else if !lower.contains("\r\nupgrade: websocket\r\n") { + Some(( + StatusCode::UPGRADE_REQUIRED, + "missing Upgrade: websocket header", + )) + } else if !lower.contains("connection: upgrade") + && !lower.contains("connection: keep-alive, upgrade") + { + Some(( + StatusCode::UPGRADE_REQUIRED, + "missing Connection: Upgrade header", + )) + } else if !lower.contains("\r\nsec-websocket-key:") { + Some((StatusCode::BAD_REQUEST, "missing Sec-WebSocket-Key header")) + } else if !lower.contains("\r\nsec-websocket-version: 13\r\n") { + Some(( + StatusCode::UPGRADE_REQUIRED, + "unsupported Sec-WebSocket-Version (expected 13)", + )) + } else { + None + }; + + if let Some((status, reason)) = status_reason { + if let Some(response) = serialize_http_error_response(status, reason) { + let _ = stream.write_all(&response).await; + } + let _ = stream.shutdown().await; + return true; + } + + false +} + async fn accept_connection( - stream: TcpStream, + mut stream: TcpStream, ws_client: Arc, wallet: Arc, program_data: &'static ProgramData, ) { - let addr = stream.peer_addr().expect("peer address"); - let ws_stream = accept_async(stream).await.expect("Ws handshake"); + let addr = stream + .peer_addr() + .map(|peer| peer.to_string()) + .unwrap_or_else(|_| "".to_string()); + if maybe_write_http_handshake_error(&mut stream).await { + log::warn!(target: LOG_TARGET, "rejected invalid Ws handshake: {addr}"); + return; + } + // Returning Err(ErrorResponse) from the callback writes the HTTP error to this stream. + let ws_stream = match accept_hdr_async(stream, handle_ws_handshake).await { + Ok(ws) => ws, + Err(err) => { + log::error!(target: LOG_TARGET, "Ws connection failed: {addr}, err={err}"); + return; + } + }; info!(target: LOG_TARGET, "accepted Ws connection: {}", addr); let (mut ws_out, mut ws_in) = ws_stream.split(); @@ -64,12 +214,13 @@ async fn accept_connection( let subscriptions = Arc::new(Mutex::new(HashMap::>::default())); // writes messages to the connection + let addr_for_writer = addr.clone(); tokio::spawn(async move { while let Some(msg) = message_rx.recv().await { if msg.is_close() { let _ = ws_out.send(msg).await; let _ = ws_out.close().await; - debug!(target: LOG_TARGET, "closing Ws connection (send half): {}", addr); + debug!(target: LOG_TARGET, "closing Ws connection (send half): {addr_for_writer}"); break; } ws_out.send(msg).await.expect("sent"); @@ -100,6 +251,7 @@ async fn accept_connection( continue; } info!(target: LOG_TARGET, "subscribing to events for: {}", request.sub_account_id); + let addr_for_subscription = addr.clone(); let join_handle = tokio::spawn({ let ws_client_ref = Arc::clone(&ws_client); let sub_account_address = @@ -107,6 +259,7 @@ async fn accept_connection( let subscription_map = Arc::clone(&subscriptions); let sub_account_id = request.sub_account_id; let message_tx = message_tx.clone(); + let addr_for_subscription = addr_for_subscription.clone(); async move { loop { @@ -145,7 +298,7 @@ async fn accept_connection( .await .is_err() { - warn!(target: LOG_TARGET, "failed sending Ws message: {}", addr); + warn!(target: LOG_TARGET, "failed sending Ws message: {}", addr_for_subscription); break; } }