Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 159 additions & 6 deletions src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@ use rust_decimal::Decimal;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::json;
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
sync::Mutex,
task::JoinHandle,
};
use tokio_tungstenite::{accept_async, tungstenite::Message};
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{
handshake::server::{write_response, ErrorResponse, Request, Response},
http::{HeaderValue, Method as HttpMethod, Response as HttpResponse, StatusCode},
Message,
},
};

use crate::{
types::{get_market_decimals, Market, PRICE_DECIMALS},
Expand Down Expand Up @@ -49,27 +57,170 @@ pub async fn start_ws_server(
});
}

fn build_ws_handshake_error(status: StatusCode, reason: String) -> ErrorResponse {
let mut builder = HttpResponse::builder()
.status(status)
.header("content-type", "text/plain; charset=utf-8");
if let Some(headers) = builder.headers_mut() {
headers.insert("connection", HeaderValue::from_static("close"));
}
builder
.body(Some(reason))
.unwrap_or_else(|_| ErrorResponse::new(Some("websocket handshake failed".to_string())))
}

fn handle_ws_handshake(
request: &Request,
mut response: Response,
) -> Result<Response, ErrorResponse> {
if request.method() != HttpMethod::GET {
return Err(build_ws_handshake_error(
StatusCode::METHOD_NOT_ALLOWED,
"websocket endpoint requires GET".to_string(),
));
}

let path = request.uri().path();
if path != "/" && path != "/ws" {
return Err(build_ws_handshake_error(
StatusCode::NOT_FOUND,
format!("unsupported websocket path: {path}"),
));
}

response
.headers_mut()
.insert("server", HeaderValue::from_static("gateway"));
Ok(response)
}

fn serialize_http_error_response(status: StatusCode, reason: &str) -> Option<Vec<u8>> {
let response = build_ws_handshake_error(status, reason.to_string());
let mut output = Vec::new();
if write_response(&mut output, &response).is_err() {
return None;
}
if let Some(body) = response.body() {
output.extend_from_slice(body.as_bytes());
}
Some(output)
}

async fn maybe_write_http_handshake_error(stream: &mut TcpStream) -> bool {
let mut buf = [0_u8; 4096];
let read = match stream.peek(&mut buf).await {
Ok(n) => n,
Err(_) => return false,
};
if read == 0 {
return true;
}

let request = String::from_utf8_lossy(&buf[..read]);
// Only attempt HTTP error responses when we have a full header.
if !request.contains("\r\n\r\n") {
return false;
}

let mut lines = request.split("\r\n");
let request_line = match lines.next() {
Some(line) => line,
None => return false,
};

let mut parts = request_line.split_whitespace();
let method = match parts.next() {
Some(m) => m,
None => return false,
};
let path = match parts.next() {
Some(p) => p,
None => return false,
};

// Ignore non-HTTP traffic.
if !request_line.contains("HTTP/") {
return false;
}

let lower = request.to_ascii_lowercase();
let status_reason = if method != "GET" {
Some((
StatusCode::METHOD_NOT_ALLOWED,
"websocket endpoint requires GET",
))
} else if path != "/" && path != "/ws" {
Some((StatusCode::NOT_FOUND, "unsupported websocket path"))
} else if !lower.contains("\r\nupgrade: websocket\r\n") {
Some((
StatusCode::UPGRADE_REQUIRED,
"missing Upgrade: websocket header",
))
} else if !lower.contains("connection: upgrade")
&& !lower.contains("connection: keep-alive, upgrade")
{
Some((
StatusCode::UPGRADE_REQUIRED,
"missing Connection: Upgrade header",
))
} else if !lower.contains("\r\nsec-websocket-key:") {
Some((StatusCode::BAD_REQUEST, "missing Sec-WebSocket-Key header"))
} else if !lower.contains("\r\nsec-websocket-version: 13\r\n") {
Some((
StatusCode::UPGRADE_REQUIRED,
"unsupported Sec-WebSocket-Version (expected 13)",
))
} else {
None
};

if let Some((status, reason)) = status_reason {
if let Some(response) = serialize_http_error_response(status, reason) {
let _ = stream.write_all(&response).await;
}
let _ = stream.shutdown().await;
return true;
}

false
}

async fn accept_connection(
stream: TcpStream,
mut stream: TcpStream,
ws_client: Arc<PubsubClient>,
wallet: Arc<Wallet>,
program_data: &'static ProgramData,
) {
let addr = stream.peer_addr().expect("peer address");
let ws_stream = accept_async(stream).await.expect("Ws handshake");
let addr = stream
.peer_addr()
.map(|peer| peer.to_string())
.unwrap_or_else(|_| "<unknown-peer>".to_string());
if maybe_write_http_handshake_error(&mut stream).await {
log::warn!(target: LOG_TARGET, "rejected invalid Ws handshake: {addr}");
return;
}
// Returning Err(ErrorResponse) from the callback writes the HTTP error to this stream.
let ws_stream = match accept_hdr_async(stream, handle_ws_handshake).await {
Ok(ws) => ws,
Err(err) => {
log::error!(target: LOG_TARGET, "Ws connection failed: {addr}, err={err}");
return;
}
};
info!(target: LOG_TARGET, "accepted Ws connection: {}", addr);

let (mut ws_out, mut ws_in) = ws_stream.split();
let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::<Message>(64);
let subscriptions = Arc::new(Mutex::new(HashMap::<u8, JoinHandle<()>>::default()));

// writes messages to the connection
let addr_for_writer = addr.clone();
tokio::spawn(async move {
while let Some(msg) = message_rx.recv().await {
if msg.is_close() {
let _ = ws_out.send(msg).await;
let _ = ws_out.close().await;
debug!(target: LOG_TARGET, "closing Ws connection (send half): {}", addr);
debug!(target: LOG_TARGET, "closing Ws connection (send half): {addr_for_writer}");
break;
}
ws_out.send(msg).await.expect("sent");
Expand Down Expand Up @@ -100,13 +251,15 @@ async fn accept_connection(
continue;
}
info!(target: LOG_TARGET, "subscribing to events for: {}", request.sub_account_id);
let addr_for_subscription = addr.clone();
let join_handle = tokio::spawn({
let ws_client_ref = Arc::clone(&ws_client);
let sub_account_address =
wallet.sub_account(request.sub_account_id as u16);
let subscription_map = Arc::clone(&subscriptions);
let sub_account_id = request.sub_account_id;
let message_tx = message_tx.clone();
let addr_for_subscription = addr_for_subscription.clone();

async move {
loop {
Expand Down Expand Up @@ -145,7 +298,7 @@ async fn accept_connection(
.await
.is_err()
{
warn!(target: LOG_TARGET, "failed sending Ws message: {}", addr);
warn!(target: LOG_TARGET, "failed sending Ws message: {}", addr_for_subscription);
break;
}
}
Expand Down
Loading