diff --git a/crates/flashblocks-rpc/src/cache.rs b/crates/flashblocks-rpc/src/cache.rs index f3068764..49d94973 100644 --- a/crates/flashblocks-rpc/src/cache.rs +++ b/crates/flashblocks-rpc/src/cache.rs @@ -17,9 +17,8 @@ use reth_primitives::Recovered; use reth_primitives_traits::block::body::BlockBody; use reth_rpc_eth_api::{RpcBlock, RpcReceipt}; -use rollup_boost::{ - FlashblockBuilder, FlashblocksPayloadV1, OpExecutionPayloadEnvelope, PayloadVersion, -}; +use rollup_boost::provider::FlashblockBuilder; +use rollup_boost::{FlashblocksPayloadV1, OpExecutionPayloadEnvelope, PayloadVersion}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, str::FromStr, sync::Arc}; diff --git a/crates/flashblocks-rpc/src/tests/mod.rs b/crates/flashblocks-rpc/src/tests/mod.rs index cc4e99ed..66704cb0 100644 --- a/crates/flashblocks-rpc/src/tests/mod.rs +++ b/crates/flashblocks-rpc/src/tests/mod.rs @@ -176,7 +176,8 @@ mod tests { let tx2 = Bytes::from_str("0xf8cd82016d8316e5708302c01c94f39635f2adf40608255779ff742afe13de31f57780b8646e530e9700000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000001bc16d674ec8000000000000000000000000000000000000000000000000000156ddc81eed2a36d68302948ba0a608703e79b22164f74523d188a11f81c25a65dd59535bab1cd1d8b30d115f3ea07f4cfbbad77a139c9209d3bded89091867ff6b548dd714109c61d1f8e7a84d14").unwrap(); // Send another test flashblock payload - let payload = FlashblocksPayloadV1 { + + FlashblocksPayloadV1 { payload_id: PayloadId::new([0; 8]), index: 1, base: None, @@ -222,9 +223,7 @@ mod tests { }, }) .unwrap(), - }; - - payload + } } #[tokio::test] @@ -243,7 +242,7 @@ mod tests { let pending_block = provider .get_block_by_number(alloy_eips::BlockNumberOrTag::Pending) .await?; - assert_eq!(pending_block.is_none(), true); + assert!(pending_block.is_none()); let base_payload = create_first_payload(); node.send_payload(base_payload).await?; @@ -296,7 +295,7 @@ mod tests { let provider = node.provider().await?; let receipt = provider.get_transaction_receipt(TX1_HASH).await?; - assert_eq!(receipt.is_none(), true); + assert!(receipt.is_none()); node.send_test_payloads().await?; diff --git a/crates/rollup-boost/src/cli.rs b/crates/rollup-boost/src/cli.rs index 3eb4d3d6..33c3bdf4 100644 --- a/crates/rollup-boost/src/cli.rs +++ b/crates/rollup-boost/src/cli.rs @@ -1,3 +1,13 @@ +use crate::{ + BlockSelectionPolicy, FlashblocksArgs, ProxyLayer, RollupBoostServer, RpcClient, + client::rpc::{BuilderArgs, L2ClientArgs}, + debug_api::ExecutionMode, + get_version, init_metrics, + payload::PayloadSource, + probe::ProbeLayer, + provider::FlashblocksProvider, + pubsub::FlashblocksPubSubManager, +}; use alloy_rpc_types_engine::JwtSecret; use clap::Parser; use eyre::bail; @@ -8,18 +18,13 @@ use std::{ path::PathBuf, str::FromStr, sync::Arc, + time::Duration, }; -use tokio::signal::unix::{SignalKind, signal as unix_signal}; -use tracing::{Level, info}; - -use crate::{ - BlockSelectionPolicy, Flashblocks, FlashblocksArgs, ProxyLayer, RollupBoostServer, RpcClient, - client::rpc::{BuilderArgs, L2ClientArgs}, - debug_api::ExecutionMode, - get_version, init_metrics, - payload::PayloadSource, - probe::ProbeLayer, +use tokio::{ + net::TcpListener, + signal::unix::{SignalKind, signal as unix_signal}, }; +use tracing::{Level, info}; #[derive(Clone, Parser, Debug)] #[clap(author, version = get_version(), about)] @@ -140,23 +145,24 @@ impl RollupBoostArgs { let execution_mode = Arc::new(Mutex::new(self.execution_mode)); let (rpc_module, health_handle): (RpcModule<()>, _) = if self.flashblocks.flashblocks { - let flashblocks_args = self.flashblocks; - let inbound_url = flashblocks_args.flashblocks_builder_url; - let outbound_addr = SocketAddr::new( - IpAddr::from_str(&flashblocks_args.flashblocks_host)?, - flashblocks_args.flashblocks_port, + let builder_ws_url = self.flashblocks.flashblocks_builder_url; + let listener_addr = SocketAddr::new( + IpAddr::from_str(&self.flashblocks.flashblocks_host)?, + self.flashblocks.flashblocks_port, ); - let builder_client = Arc::new(Flashblocks::run( - builder_client.clone(), - inbound_url, - outbound_addr, - flashblocks_args.flashblock_builder_ws_reconnect_ms, - )?); + let listener = TcpListener::bind(listener_addr).await?; + let flashblocks_provider = Arc::new(FlashblocksProvider::new(builder_client)); + FlashblocksPubSubManager::spawn( + builder_ws_url, + listener, + flashblocks_provider.clone(), + Duration::from_millis(self.flashblocks.flashblock_builder_ws_reconnect_ms), + ); let rollup_boost = RollupBoostServer::new( l2_client, - builder_client, + flashblocks_provider, execution_mode.clone(), self.block_selection_policy, probes.clone(), diff --git a/crates/rollup-boost/src/flashblocks/inbound.rs b/crates/rollup-boost/src/flashblocks/inbound.rs deleted file mode 100644 index a5f6e053..00000000 --- a/crates/rollup-boost/src/flashblocks/inbound.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::time::Duration; - -use super::{metrics::FlashblocksWsInboundMetrics, primitives::FlashblocksPayloadV1}; -use futures::{SinkExt, StreamExt}; -use tokio::{sync::mpsc, time::interval}; -use tokio_tungstenite::{connect_async, tungstenite::Message}; -use tokio_util::sync::CancellationToken; -use tracing::{error, info}; -use url::Url; - -#[derive(Debug, thiserror::Error)] -enum FlashblocksReceiverError { - #[error("WebSocket connection failed: {0}")] - Connection(#[from] tokio_tungstenite::tungstenite::Error), - - #[error("Ping failed")] - PingFailed, - - #[error("Read timeout")] - ReadTimeout, - - #[error("Connection error: {0}")] - ConnectionError(String), - - #[error("Connection closed")] - ConnectionClosed, - - #[error("Task panicked: {0}")] - TaskPanic(String), - - #[error("Failed to send message to sender: {0}")] - SendError(#[from] Box>), -} - -pub struct FlashblocksReceiverService { - url: Url, - sender: mpsc::Sender, - reconnect_ms: u64, - metrics: FlashblocksWsInboundMetrics, -} - -impl FlashblocksReceiverService { - pub fn new(url: Url, sender: mpsc::Sender, reconnect_ms: u64) -> Self { - Self { - url, - sender, - reconnect_ms, - metrics: Default::default(), - } - } - - pub async fn run(self) { - loop { - if let Err(e) = self.connect_and_handle().await { - error!("Flashblocks receiver connection error, retrying in 5 seconds: {e}"); - self.metrics.reconnect_attempts.increment(1); - self.metrics.connection_status.set(0); - tokio::time::sleep(std::time::Duration::from_millis(self.reconnect_ms)).await; - } else { - break; - } - } - } - - async fn connect_and_handle(&self) -> Result<(), FlashblocksReceiverError> { - let (ws_stream, _) = connect_async(self.url.as_str()).await?; - let (mut write, mut read) = ws_stream.split(); - - info!("Connected to Flashblocks receiver at {}", self.url); - self.metrics.connection_status.set(1); - - let cancel_token = CancellationToken::new(); - let cancel_for_ping = cancel_token.clone(); - - let ping_task = tokio::spawn(async move { - let mut ping_interval = interval(Duration::from_millis(500)); - - loop { - tokio::select! { - _ = ping_interval.tick() => { - if write.send(Message::Ping(Default::default())).await.is_err() { - return Err(FlashblocksReceiverError::PingFailed); - } - } - _ = cancel_for_ping.cancelled() => { - tracing::debug!("Ping task cancelled"); - return Ok(()); - } - } - } - }); - - let sender = self.sender.clone(); - let metrics = self.metrics.clone(); - - let read_timeout = Duration::from_millis(1500); - let message_handle = tokio::spawn(async move { - loop { - let result = tokio::time::timeout(read_timeout, read.next()) - .await - .map_err(|_| FlashblocksReceiverError::ReadTimeout)?; - - match result { - Some(Ok(msg)) => match msg { - Message::Text(text) => { - metrics.messages_received.increment(1); - if let Ok(flashblocks_msg) = - serde_json::from_str::(&text) - { - sender.send(flashblocks_msg).await.map_err(|e| { - FlashblocksReceiverError::SendError(Box::new(e)) - })?; - } - } - Message::Close(_) => { - return Err(FlashblocksReceiverError::ConnectionClosed); - } - _ => {} - }, - Some(Err(e)) => { - return Err(FlashblocksReceiverError::ConnectionError(e.to_string())); - } - None => { - return Err(FlashblocksReceiverError::ReadTimeout); - } - }; - } - }); - - let result = tokio::select! { - result = message_handle => { - result.map_err(|e| FlashblocksReceiverError::TaskPanic(e.to_string()))? - }, - result = ping_task => { - result.map_err(|e| FlashblocksReceiverError::TaskPanic(e.to_string()))? - }, - }; - - cancel_token.cancel(); - result - } -} - -#[cfg(test)] -mod tests { - use futures::SinkExt; - use tokio::sync::watch; - use tokio_tungstenite::{accept_async, tungstenite::Utf8Bytes}; - - use super::*; - use std::net::{SocketAddr, TcpListener}; - - async fn start( - addr: SocketAddr, - ) -> eyre::Result<( - watch::Sender, - mpsc::Sender, - mpsc::Receiver<()>, - url::Url, - )> { - let (term_tx, mut term_rx) = watch::channel(false); - let (send_tx, mut send_rx) = mpsc::channel::(100); - let (send_ping_tx, send_ping_rx) = mpsc::channel::<()>(100); - - let listener = TcpListener::bind(addr)?; - let url = Url::parse(&format!("ws://{addr}"))?; - - listener - .set_nonblocking(true) - .expect("Failed to set TcpListener socket to non-blocking"); - - let listener = tokio::net::TcpListener::from_std(listener) - .expect("Failed to convert TcpListener to tokio TcpListener"); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = term_rx.changed() => { - if *term_rx.borrow() { - return; - } - } - - result = listener.accept() => { - match result { - Ok((connection, _addr)) => { - match accept_async(connection).await { - Ok(ws_stream) => { - let (mut write, mut read) = ws_stream.split(); - - loop { - tokio::select! { - Some(msg) = send_rx.recv() => { - let serialized = serde_json::to_string(&msg).unwrap(); - let utf8_bytes = Utf8Bytes::from(serialized); - - write.send(Message::Text(utf8_bytes)).await.unwrap(); - }, - msg = read.next() => { - match msg { - // we need to read for the library to handle pong messages - Some(Ok(Message::Ping(_))) => { - send_ping_tx.send(()).await.unwrap(); - }, - _ => {} - } - } - _ = term_rx.changed() => { - if *term_rx.borrow() { - return; - } - } - } - } - } - Err(e) => { - eprintln!("Failed to accept WebSocket connection: {}", e); - } - } - } - Err(e) => { - // Optionally break or continue based on error type - if e.kind() == std::io::ErrorKind::Interrupted { - break; - } - } - } - } - } - } - }); - - Ok((term_tx, send_tx, send_ping_rx, url)) - } - - #[tokio::test] - async fn test_flashblocks_receiver_service() -> eyre::Result<()> { - let addr = "127.0.0.1:8080".parse::().unwrap(); - let (term, send_msg, _, url) = start(addr).await?; - - let (tx, mut rx) = mpsc::channel(100); - - let service = FlashblocksReceiverService::new(url, tx, 100); - let _ = tokio::spawn(async move { - service.run().await; - }); - - // Send a message to the websocket server - send_msg - .send(FlashblocksPayloadV1::default()) - .await - .expect("Failed to send message"); - - let msg = rx.recv().await.expect("Failed to receive message"); - assert_eq!(msg, FlashblocksPayloadV1::default()); - - // Drop the websocket server and start another one with the same address - // The FlashblocksReceiverService should reconnect to the new server - term.send(true).unwrap(); - - // sleep for 1 second to ensure the server is dropped - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - // start a new server with the same address - let (term, send_msg, _, _url) = start(addr).await?; - send_msg - .send(FlashblocksPayloadV1::default()) - .await - .expect("Failed to send message"); - - let msg = rx.recv().await.expect("Failed to receive message"); - assert_eq!(msg, FlashblocksPayloadV1::default()); - term.send(true).unwrap(); - - Ok(()) - } - - #[tokio::test] - async fn test_flashblocks_receiver_service_ping_pong() -> eyre::Result<()> { - // test that if the builder is not sending any messages back, the service will send - // ping messages to test the connection periodically - - let addr = "127.0.0.1:8081".parse::().unwrap(); - let (_term, _send_msg, mut ping_rx, url) = start(addr).await?; - - let (tx, _rx) = mpsc::channel(100); - let service = FlashblocksReceiverService::new(url, tx, 100); - let _ = tokio::spawn(async move { - service.run().await; - }); - - // even if we do not send any messages, we should receive pings to keep the connection alive - for _ in 0..10 { - ping_rx.recv().await.expect("Failed to receive ping"); - } - - Ok(()) - } -} diff --git a/crates/rollup-boost/src/flashblocks/launcher.rs b/crates/rollup-boost/src/flashblocks/launcher.rs deleted file mode 100644 index fd8f0fb9..00000000 --- a/crates/rollup-boost/src/flashblocks/launcher.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::flashblocks::inbound::FlashblocksReceiverService; -use crate::{FlashblocksService, RpcClient}; -use core::net::SocketAddr; -use tokio::sync::mpsc; -use url::Url; - -pub struct Flashblocks {} - -impl Flashblocks { - pub fn run( - builder_url: RpcClient, - flashblocks_url: Url, - outbound_addr: SocketAddr, - reconnect_ms: u64, - ) -> eyre::Result { - let (tx, rx) = mpsc::channel(100); - - let receiver = FlashblocksReceiverService::new(flashblocks_url, tx, reconnect_ms); - tokio::spawn(async move { - let _ = receiver.run().await; - }); - - let service = FlashblocksService::new(builder_url, outbound_addr)?; - let mut service_handle = service.clone(); - tokio::spawn(async move { - service_handle.run(rx).await; - }); - - Ok(service) - } -} diff --git a/crates/rollup-boost/src/flashblocks/metrics.rs b/crates/rollup-boost/src/flashblocks/metrics.rs index e1cf3508..4c97658d 100644 --- a/crates/rollup-boost/src/flashblocks/metrics.rs +++ b/crates/rollup-boost/src/flashblocks/metrics.rs @@ -2,8 +2,8 @@ use metrics::{Counter, Gauge, Histogram}; use metrics_derive::Metrics; #[derive(Metrics, Clone)] -#[metrics(scope = "flashblocks.ws_inbound")] -pub struct FlashblocksWsInboundMetrics { +#[metrics(scope = "flashblocks.subscriber")] +pub struct FlashblocksSubscriberMetrics { /// Total number of WebSocket reconnection attempts #[metric(describe = "Total number of WebSocket reconnection attempts")] pub reconnect_attempts: Counter, @@ -14,20 +14,17 @@ pub struct FlashblocksWsInboundMetrics { #[metric(describe = "Number of flashblock messages received from builder")] pub messages_received: Counter, -} -#[derive(Metrics, Clone)] -#[metrics(scope = "flashblocks.service")] -pub struct FlashblocksServiceMetrics { #[metric(describe = "Number of errors when extending payload")] pub extend_payload_errors: Counter, #[metric(describe = "Number of times the current payload ID has been set")] pub current_payload_id_mismatch: Counter, +} - #[metric(describe = "Number of messages processed by the service")] - pub messages_processed: Counter, - +#[derive(Metrics, Clone)] +#[metrics(scope = "flashblocks.provider")] +pub struct FlashblocksProviderMetrics { #[metric(describe = "Number of flashblocks used to build a block")] pub flashblocks_used: Histogram, } diff --git a/crates/rollup-boost/src/flashblocks/mod.rs b/crates/rollup-boost/src/flashblocks/mod.rs index 3cdd3e7e..62674db8 100644 --- a/crates/rollup-boost/src/flashblocks/mod.rs +++ b/crates/rollup-boost/src/flashblocks/mod.rs @@ -1,15 +1,8 @@ -mod launcher; - -pub use launcher::*; +pub mod provider; +pub mod pubsub; mod primitives; -mod service; - pub use primitives::*; -pub use service::*; - -mod inbound; -mod outbound; mod args; pub use args::*; diff --git a/crates/rollup-boost/src/flashblocks/outbound.rs b/crates/rollup-boost/src/flashblocks/outbound.rs deleted file mode 100644 index e1b2c982..00000000 --- a/crates/rollup-boost/src/flashblocks/outbound.rs +++ /dev/null @@ -1,227 +0,0 @@ -use super::primitives::FlashblocksPayloadV1; -use core::{ - fmt::{Debug, Formatter}, - net::SocketAddr, - pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, - task::{Context, Poll}, -}; -use futures::{Sink, SinkExt}; -use std::{io, net::TcpListener, sync::Arc}; -use tokio::{ - net::TcpStream, - sync::{ - broadcast::{self, Receiver, error::RecvError}, - watch, - }, -}; -use tokio_tungstenite::WebSocketStream; -use tokio_tungstenite::tungstenite::Utf8Bytes; -use tokio_tungstenite::{accept_async, tungstenite::Message}; - -/// A WebSockets publisher that accepts connections from client websockets and broadcasts to them -/// updates about new flashblocks. It maintains a count of sent messages and active subscriptions. -/// -/// This is modelled as a `futures::Sink` that can be used to send `FlashblocksPayloadV1` messages. -pub struct WebSocketPublisher { - sent: Arc, - subs: Arc, - term: watch::Sender, - pipe: broadcast::Sender, -} - -impl WebSocketPublisher { - pub fn new(addr: SocketAddr) -> io::Result { - let (pipe, _) = broadcast::channel(100); - let (term, _) = watch::channel(false); - - let sent = Arc::new(AtomicUsize::new(0)); - let subs = Arc::new(AtomicUsize::new(0)); - let listener = TcpListener::bind(addr)?; - - tokio::spawn(listener_loop( - listener, - pipe.subscribe(), - term.subscribe(), - Arc::clone(&sent), - Arc::clone(&subs), - )); - - Ok(Self { - sent, - subs, - term, - pipe, - }) - } - - pub fn publish(&self, payload: &FlashblocksPayloadV1) -> io::Result<()> { - // Serialize the payload to a UTF-8 string - // serialize only once, then just copy around only a pointer - // to the serialized data for each subscription. - let serialized = serde_json::to_string(payload)?; - let utf8_bytes = Utf8Bytes::from(serialized); - - // Send the serialized payload to all subscribers - self.pipe - .send(utf8_bytes) - .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e))?; - Ok(()) - } -} - -impl Drop for WebSocketPublisher { - fn drop(&mut self) { - // Notify the listener loop to terminate - let _ = self.term.send(true); - tracing::info!("WebSocketPublisher dropped, terminating listener loop"); - } -} - -async fn listener_loop( - listener: TcpListener, - receiver: Receiver, - term: watch::Receiver, - sent: Arc, - subs: Arc, -) { - listener - .set_nonblocking(true) - .expect("Failed to set TcpListener socket to non-blocking"); - - let listener = tokio::net::TcpListener::from_std(listener) - .expect("Failed to convert TcpListener to tokio TcpListener"); - - let listen_addr = listener - .local_addr() - .expect("Failed to get local address of listener"); - tracing::info!("Flashblocks WebSocketPublisher listening on {listen_addr}"); - - let mut term = term; - - loop { - let subs = Arc::clone(&subs); - - tokio::select! { - // drop this connection if the `WebSocketPublisher` is dropped - _ = term.changed() => { - if *term.borrow() { - return; - } - } - - // Accept new connections on the websocket listener - // when a new connection is established, spawn a dedicated task to handle - // the connection and broadcast with that connection. - Ok((connection, peer_addr)) = listener.accept() => { - let sent = Arc::clone(&sent); - let term = term.clone(); - let receiver_clone = receiver.resubscribe(); - - match accept_async(connection).await { - Ok(stream) => { - tokio::spawn(async move { - subs.fetch_add(1, Ordering::Relaxed); - tracing::debug!("WebSocket connection established with {}", peer_addr); - - // Handle the WebSocket connection in a dedicated task - broadcast_loop(stream, term, receiver_clone, sent).await; - - subs.fetch_sub(1, Ordering::Relaxed); - tracing::debug!("WebSocket connection closed for {}", peer_addr); - }); - } - Err(e) => { - tracing::warn!("Failed to accept WebSocket connection from {peer_addr}: {e}"); - } - } - } - } - } -} - -/// An instance of this loop is spawned for each connected WebSocket client. -/// It listens for broadcast updates about new flashblocks and sends them to the client. -/// It also handles termination signals to gracefully close the connection. -/// Any connectivity errors will terminate the loop, which will in turn -/// decrement the subscription count in the `WebSocketPublisher`. -async fn broadcast_loop( - stream: WebSocketStream, - term: watch::Receiver, - blocks: broadcast::Receiver, - sent: Arc, -) { - let mut term = term; - let mut blocks = blocks; - let mut stream = stream; - let Ok(peer_addr) = stream.get_ref().peer_addr() else { - return; - }; - - loop { - tokio::select! { - // Check if the publisher is terminated - _ = term.changed() => { - if *term.borrow() { - tracing::info!("WebSocketPublisher is terminating, closing broadcast loop"); - return; - } - } - - // Receive payloads from the broadcast channel - payload = blocks.recv() => match payload { - Ok(payload) => { - // Here you would typically send the payload to the WebSocket clients. - // For this example, we just increment the sent counter. - sent.fetch_add(1, Ordering::Relaxed); - - tracing::info!("Broadcasted payload: {:?}", payload); - if let Err(e) = stream.send(Message::Text(payload)).await { - tracing::debug!("Closing flashblocks subscription for {peer_addr}: {e}"); - break; // Exit the loop if sending fails - } - } - Err(RecvError::Closed) => { - tracing::debug!("Broadcast channel closed, exiting broadcast loop"); - return; - } - Err(RecvError::Lagged(_)) => { - tracing::warn!("Broadcast channel lagged, some messages were dropped"); - } - }, - } - } -} - -impl Debug for WebSocketPublisher { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - let subs = self.subs.load(Ordering::Relaxed); - let sent = self.sent.load(Ordering::Relaxed); - - f.debug_struct("WebSocketPublisher") - .field("subs", &subs) - .field("payloads_sent", &sent) - .finish() - } -} - -impl Sink<&FlashblocksPayloadV1> for WebSocketPublisher { - type Error = eyre::Report; - - fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: &FlashblocksPayloadV1) -> Result<(), Self::Error> { - self.publish(item)?; - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} diff --git a/crates/rollup-boost/src/flashblocks/provider.rs b/crates/rollup-boost/src/flashblocks/provider.rs new file mode 100644 index 00000000..f55064d5 --- /dev/null +++ b/crates/rollup-boost/src/flashblocks/provider.rs @@ -0,0 +1,472 @@ +use super::metrics::FlashblocksProviderMetrics; +use super::primitives::{ + ExecutionPayloadBaseV1, ExecutionPayloadFlashblockDeltaV1, FlashblocksPayloadV1, +}; +use crate::RpcClient; +use crate::{ClientResult, EngineApiExt, NewPayload, OpExecutionPayloadEnvelope, PayloadVersion}; +use alloy_primitives::U256; +use alloy_rpc_types_engine::{ + BlobsBundleV1, ExecutionPayloadV1, ExecutionPayloadV2, ExecutionPayloadV3, +}; +use alloy_rpc_types_engine::{ForkchoiceState, ForkchoiceUpdated, PayloadId, PayloadStatus}; +use alloy_rpc_types_eth::{Block, BlockNumberOrTag}; +use jsonrpsee::core::async_trait; +use op_alloy_rpc_types_engine::{ + OpExecutionPayloadEnvelopeV3, OpExecutionPayloadEnvelopeV4, OpExecutionPayloadV4, + OpPayloadAttributes, +}; +use parking_lot::Mutex; +use reth_optimism_payload_builder::payload_id_optimism; +use std::sync::Arc; +use thiserror::Error; +use tokio::sync::broadcast::error::RecvError; +use tracing::error; + +pub struct FlashblocksProvider { + pub payload_id: Arc>, + pub payload_builder: Arc>, + builder_client: RpcClient, + metrics: FlashblocksProviderMetrics, +} + +impl FlashblocksProvider { + pub fn new(builder_client: RpcClient) -> Self { + let payload_id = Arc::new(Mutex::new(PayloadId::default())); + let payload_builder = Arc::new(Mutex::new(FlashblockBuilder::default())); + + Self { + builder_client, + payload_id, + payload_builder, + metrics: FlashblocksProviderMetrics::default(), + } + } + + fn take_payload( + &self, + version: PayloadVersion, + payload_id: PayloadId, + ) -> Result { + // Check that we have flashblocks for correct payload + if *self.payload_id.lock() != payload_id { + // We have outdated `current_payload_id` so we should fallback to get_payload + // Clearing best_payload in here would cause situation when old `get_payload` would clear + // currently built correct flashblocks. + // This will self-heal on the next FCU. + return Err(FlashblocksError::MissingPayload); + } + // consume the best payload and reset the builder + let payload = { + let mut builder = self.payload_builder.lock(); + self.metrics + .flashblocks_used + .record(builder.flashblocks.len() as f64); + // Take payload and place new one in its place in one go to avoid double locking + std::mem::replace(&mut *builder, FlashblockBuilder::new()).into_envelope(version)? + }; + + Ok(payload) + } +} + +#[async_trait] +impl EngineApiExt for FlashblocksProvider { + async fn fork_choice_updated_v3( + &self, + fork_choice_state: ForkchoiceState, + payload_attributes: Option, + ) -> ClientResult { + // Calculate and set expected payload_id + if let Some(attr) = &payload_attributes { + let payload_id = payload_id_optimism(&fork_choice_state.head_block_hash, attr, 3); + tracing::debug!(message = "Setting current payload ID", payload_id = %payload_id); + *self.payload_id.lock() = payload_id; + } + + let resp = self + .builder_client + .fork_choice_updated_v3(fork_choice_state, payload_attributes) + .await?; + + Ok(resp) + } + + async fn new_payload(&self, new_payload: NewPayload) -> ClientResult { + self.builder_client.new_payload(new_payload).await + } + + async fn get_payload( + &self, + payload_id: PayloadId, + version: PayloadVersion, + ) -> ClientResult { + match self.take_payload(version, payload_id) { + Ok(payload) => Ok(payload), + Err(e) => { + error!("Failed to get flashblocks payload, falling back to builder: {e}"); + self.builder_client.get_payload(payload_id, version).await + } + } + } + + async fn get_block_by_number( + &self, + number: BlockNumberOrTag, + full: bool, + ) -> ClientResult { + self.builder_client.get_block_by_number(number, full).await + } +} + +#[derive(Clone, Debug, Default)] +pub struct FlashblockBuilder { + pub base: Option, + pub flashblocks: Vec, +} + +impl FlashblockBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn extend(&mut self, payload: FlashblocksPayloadV1) -> Result<(), FlashblocksError> { + tracing::debug!(message = "Extending payload", payload_id = %payload.payload_id, index = payload.index, has_base=payload.base.is_some()); + + // Validate the index is contiguous + if payload.index != self.flashblocks.len() as u64 { + return Err(FlashblocksError::InvalidIndex); + } + + // Check base payload rules + if payload.index == 0 { + if let Some(base) = payload.base { + self.base = Some(base) + } else { + return Err(FlashblocksError::MissingBasePayload); + } + } else if payload.base.is_some() { + return Err(FlashblocksError::UnexpectedBasePayload); + } + + // Update latest diff and accumulate transactions and withdrawals + self.flashblocks.push(payload.diff); + + Ok(()) + } + + pub fn into_envelope( + self, + version: PayloadVersion, + ) -> Result { + self.build_envelope(version) + } + + pub fn build_envelope( + &self, + version: PayloadVersion, + ) -> Result { + let base = self.base.as_ref().ok_or(FlashblocksError::MissingPayload)?; + + // There must be at least one delta + let diff = self + .flashblocks + .last() + .ok_or(FlashblocksError::MissingDelta)?; + + let transactions = self + .flashblocks + .iter() + .flat_map(|diff| diff.transactions.clone()) + .collect(); + + let withdrawals = self + .flashblocks + .iter() + .flat_map(|diff| diff.withdrawals.clone()) + .collect(); + + let withdrawals_root = diff.withdrawals_root; + + let execution_payload = ExecutionPayloadV3 { + blob_gas_used: 0, + excess_blob_gas: 0, + payload_inner: ExecutionPayloadV2 { + withdrawals, + payload_inner: ExecutionPayloadV1 { + parent_hash: base.parent_hash, + fee_recipient: base.fee_recipient, + state_root: diff.state_root, + receipts_root: diff.receipts_root, + logs_bloom: diff.logs_bloom, + prev_randao: base.prev_randao, + block_number: base.block_number, + gas_limit: base.gas_limit, + gas_used: diff.gas_used, + timestamp: base.timestamp, + extra_data: base.extra_data.clone(), + base_fee_per_gas: base.base_fee_per_gas, + block_hash: diff.block_hash, + transactions, + }, + }, + }; + + match version { + PayloadVersion::V3 => Ok(OpExecutionPayloadEnvelope::V3( + OpExecutionPayloadEnvelopeV3 { + parent_beacon_block_root: base.parent_beacon_block_root, + block_value: U256::ZERO, + blobs_bundle: BlobsBundleV1::default(), + should_override_builder: false, + execution_payload, + }, + )), + PayloadVersion::V4 => Ok(OpExecutionPayloadEnvelope::V4( + OpExecutionPayloadEnvelopeV4 { + parent_beacon_block_root: base.parent_beacon_block_root, + block_value: U256::ZERO, + blobs_bundle: BlobsBundleV1::default(), + should_override_builder: false, + execution_payload: OpExecutionPayloadV4 { + withdrawals_root, + payload_inner: execution_payload, + }, + execution_requests: vec![], + }, + )), + } + } +} + +#[derive(Debug, Error)] +pub enum FlashblocksError { + #[error("Missing base payload for initial flashblock")] + MissingBasePayload, + #[error("Unexpected base payload for non-initial flashblock")] + UnexpectedBasePayload, + #[error("Missing delta for flashblock")] + MissingDelta, + #[error("Invalid index for flashblock")] + InvalidIndex, + #[error("Missing payload")] + MissingPayload, + #[error(transparent)] + RecvError(#[from] RecvError), + #[error(transparent)] + SerdeJsonError(#[from] serde_json::Error), +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{PayloadSource, RpcClient}; + use alloy_primitives::B256; + use alloy_rpc_types_engine::ForkchoiceState; + use alloy_rpc_types_engine::{ForkchoiceUpdated, PayloadStatus, PayloadStatusEnum}; + use jsonrpsee::RpcModule; + use jsonrpsee::server::ServerBuilder; + use op_alloy_rpc_types_engine::OpPayloadAttributes; + use reth_optimism_payload_builder::payload_id_optimism; + use reth_rpc_layer::JwtSecret; + + #[test] + fn test_take_payload() { + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + ) + .unwrap(); + + let provider = FlashblocksProvider::new(rpc_client); + + let test_payload_id = PayloadId::new([1u8; 8]); + *provider.payload_id.lock() = test_payload_id; + + { + let mut builder = provider.payload_builder.lock(); + builder.base = Some(ExecutionPayloadBaseV1::default()); + builder.flashblocks = vec![ExecutionPayloadFlashblockDeltaV1::default()]; + } + + let result = provider.take_payload(PayloadVersion::V3, test_payload_id); + assert!(result.is_ok()); + + // Verify the builder was reset + let builder = provider.payload_builder.lock(); + assert!(builder.base.is_none()); + assert!(builder.flashblocks.is_empty()); + } + + #[test] + fn test_missing_payload() { + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + ) + .unwrap(); + + let provider = FlashblocksProvider::new(rpc_client); + + let test_payload_id = PayloadId::new([1u8; 8]); + *provider.payload_id.lock() = test_payload_id; + + { + let mut builder = provider.payload_builder.lock(); + builder.base = Some(ExecutionPayloadBaseV1::default()); + builder.flashblocks = vec![ExecutionPayloadFlashblockDeltaV1::default()]; + } + + // Test with mismatched payload ID + let wrong_payload_id = PayloadId::new([2u8; 8]); + let result = provider.take_payload(PayloadVersion::V3, wrong_payload_id); + matches!(result, Err(FlashblocksError::MissingPayload)); + } + + #[tokio::test] + async fn test_fork_choice_updated() -> eyre::Result<()> { + // Create a mock server + let server = ServerBuilder::default().build("127.0.0.1:0").await?; + let server_addr = server.local_addr()?; + + let mut module = RpcModule::new(()); + module.register_async_method( + "engine_forkchoiceUpdatedV3", + |_params, _context, _state| async move { + let response = ForkchoiceUpdated { + payload_status: PayloadStatus::from_status(PayloadStatusEnum::Valid), + payload_id: Some(PayloadId::new([1u8; 8])), + }; + ClientResult::Ok(response) + }, + )?; + + let _handle = server.start(module); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let rpc_client = RpcClient::new( + format!("http://{}", server_addr).parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = FlashblocksProvider::new(rpc_client); + + let fork_choice_state = ForkchoiceState { + head_block_hash: B256::random(), + safe_block_hash: B256::random(), + finalized_block_hash: B256::random(), + }; + let payload_attributes = OpPayloadAttributes::default(); + + let expected_payload_id = + payload_id_optimism(&fork_choice_state.head_block_hash, &payload_attributes, 3); + + let result = provider + .fork_choice_updated_v3(fork_choice_state, Some(payload_attributes)) + .await?; + + assert_eq!(result.payload_status.status, PayloadStatusEnum::Valid,); + + let payload_id = *provider.payload_id.lock(); + assert_eq!(payload_id, expected_payload_id); + + Ok(()) + } + + #[test] + fn test_extend() { + let mut builder = FlashblockBuilder::new(); + let payload_0 = FlashblocksPayloadV1 { + payload_id: PayloadId::new([1u8; 8]), + index: 0, + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let result = builder.extend(payload_0); + assert!(result.is_ok()); + assert!(builder.base.is_some()); + assert_eq!(builder.flashblocks.len(), 1); + + let payload_1 = FlashblocksPayloadV1 { + payload_id: PayloadId::new([1u8; 8]), + index: 1, + ..Default::default() + }; + + let result = builder.extend(payload_1); + assert!(result.is_ok()); + assert_eq!(builder.flashblocks.len(), 2); + } + + #[test] + fn test_extend_missing_base_payload() { + let mut builder = FlashblockBuilder::new(); + + let payload_0 = FlashblocksPayloadV1 { + payload_id: PayloadId::new([1u8; 8]), + index: 0, + base: None, + ..Default::default() + }; + + let result = builder.extend(payload_0); + matches!(result, Err(FlashblocksError::MissingBasePayload)); + } + + #[test] + fn test_extend_unexpected_base_payload() { + let mut builder = FlashblockBuilder::new(); + + let payload_0 = FlashblocksPayloadV1 { + payload_id: PayloadId::new([1u8; 8]), + index: 0, + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let result = builder.extend(payload_0); + assert!(result.is_ok()); + + let payload_1 = FlashblocksPayloadV1 { + payload_id: PayloadId::new([1u8; 8]), + index: 1, + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let result = builder.extend(payload_1); + matches!(result, Err(FlashblocksError::UnexpectedBasePayload)); + } + + #[test] + fn test_into_envelope() { + let mut builder = FlashblockBuilder::new(); + builder.base = Some(ExecutionPayloadBaseV1::default()); + builder.flashblocks = vec![ExecutionPayloadFlashblockDeltaV1::default()]; + + // Test V3 envelope creation + let result = builder.build_envelope(PayloadVersion::V3); + matches!(result.unwrap(), OpExecutionPayloadEnvelope::V3(_)); + + // Test V4 envelope creation + let result = builder.build_envelope(PayloadVersion::V4); + matches!(result.unwrap(), OpExecutionPayloadEnvelope::V4(_)); + + // Test missing payload + let empty_builder = FlashblockBuilder::new(); + let result = empty_builder.build_envelope(PayloadVersion::V3); + matches!(result, Err(FlashblocksError::MissingPayload)); + + // Test missing delta + let mut builder_no_delta = FlashblockBuilder::new(); + builder_no_delta.base = Some(ExecutionPayloadBaseV1::default()); + let result = builder_no_delta.build_envelope(PayloadVersion::V3); + matches!(result, Err(FlashblocksError::MissingDelta)); + } +} diff --git a/crates/rollup-boost/src/flashblocks/pubsub.rs b/crates/rollup-boost/src/flashblocks/pubsub.rs new file mode 100644 index 00000000..4f81d8d3 --- /dev/null +++ b/crates/rollup-boost/src/flashblocks/pubsub.rs @@ -0,0 +1,624 @@ +#![allow(clippy::result_large_err, clippy::large_enum_variant)] + +use super::FlashblocksPayloadV1; +use super::metrics::FlashblocksSubscriberMetrics; +use super::provider::FlashblocksProvider; +use futures::stream::SplitStream; +use futures::{Sink, SinkExt, StreamExt}; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::broadcast::error::RecvError; +use tokio::sync::{broadcast, watch}; +use tokio::task::{JoinError, JoinHandle}; +use tokio_tungstenite::tungstenite::{self, Message, Utf8Bytes}; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async}; +use tokio_util::bytes::Bytes; +use url::Url; + +pub struct FlashblocksPubSubManager { + pub subscriber: FlashblocksSubscriber, + pub publisher: FlashblocksPublisher, +} + +impl FlashblocksPubSubManager { + pub fn spawn( + builder_ws_endpoint: Url, + listener: TcpListener, + flashblocks_provider: Arc, + reconnect_backoff: Duration, + ) -> Self { + let (payload_tx, payload_rx) = broadcast::channel(100); + + Self { + subscriber: FlashblocksSubscriber::new( + builder_ws_endpoint, + payload_tx, + flashblocks_provider, + reconnect_backoff, + ), + publisher: FlashblocksPublisher::new(listener, payload_rx), + } + } +} + +pub struct FlashblocksSubscriber { + pub handle: JoinHandle>, +} + +impl FlashblocksSubscriber { + fn new( + builder_ws_endpoint: Url, + payload_tx: broadcast::Sender, + flashblocks_provider: Arc, + reconnect_backoff: Duration, + ) -> Self { + let payload_tx = Arc::new(payload_tx); + let metrics = FlashblocksSubscriberMetrics::default(); + + let handle = tokio::spawn(async move { + loop { + let (ws_stream, _) = match connect_async(builder_ws_endpoint.as_str()).await { + Ok(stream) => stream, + Err(e) => { + tracing::error!("Could not connect to builder ws endpoint: {e}"); + metrics.reconnect_attempts.increment(1); + metrics.connection_status.set(0); + tokio::time::sleep(reconnect_backoff).await; + continue; + } + }; + metrics.connection_status.set(1); + + let (sink, stream) = ws_stream.split(); + let (pong_tx, pong_rx) = watch::channel(Message::Pong(Bytes::default())); + + let ping_handle = spawn_ping(sink, pong_rx); + let stream_handle = FlashblocksSubscriber::handle_flashblocks_stream( + stream, + flashblocks_provider.clone(), + payload_tx.clone(), + pong_tx, + metrics.clone(), + ); + + let abort_ping = ping_handle.abort_handle(); + let abort_stream = stream_handle.abort_handle(); + + tokio::select! { + result = ping_handle => { + if let Err(e) = result.unwrap_or_else(|e| Err(e.into())) { + tracing::error!("Ping handle error: {}", e); + abort_stream.abort(); + } + tracing::warn!("Ping handle resolved early, reestabling connection"); + } + result = stream_handle => { + if let Err(e) = result.unwrap_or_else(|e| Err(e.into())) { + + tracing::error!("Flashblocks stream handle error: {}", e); + abort_ping.abort(); + } + tracing::warn!("Flashblocks stream handle resolved early, reestabling connection"); + } + } + + tokio::time::sleep(reconnect_backoff).await; + } + }); + + Self { handle } + } + + fn handle_flashblocks_stream( + mut stream: SplitStream>>, + flashblocks_provider: Arc, + payload_tx: Arc>, + pong_tx: watch::Sender, + metrics: FlashblocksSubscriberMetrics, + ) -> JoinHandle> { + tokio::spawn(async move { + while let Some(msg) = stream.next().await { + let msg = msg.map_err(|e| { + tracing::error!("Ws connection error: {e}"); + e + })?; + metrics.messages_received.increment(1); + + match msg { + Message::Text(bytes) => { + // TODO: docs + if let Ok(flashblock) = serde_json::from_str::(&bytes) + { + let local_payload_id = flashblocks_provider.payload_id.lock(); + + if *local_payload_id == flashblock.payload_id { + let mut payload_builder = + flashblocks_provider.payload_builder.lock(); + let flashblock_index = flashblock.index; + if let Err(e) = payload_builder.extend(flashblock) { + metrics.extend_payload_errors.increment(1); + tracing::error!( + target: "pubsub::handle_flashblocks_stream", + message = "Failed to extend payload", + error = %e, + payload_id = %local_payload_id, + index = flashblock_index + ); + continue; + } + } else { + metrics.current_payload_id_mismatch.increment(1); + tracing::error!( + target: "pubsub::handle_flashblocks_stream", + message = "Payload ID mismatch", + payload_id = %flashblock.payload_id, + %local_payload_id, + index = flashblock.index, + ); + continue; + } + } else { + tracing::error!( + target: "pubsub::handle_flashblocks_stream", + message = "Failed deserialize payload", + ); + continue; + } + + payload_tx.send(bytes)?; + } + Message::Pong(_) => { + pong_tx.send(Message::Pong(Bytes::default()))?; + } + Message::Close(frame) => { + tracing::warn!( + target: "pubsub::handle_flashblocks_stream", + message = "Connection closed", + code = ?frame.as_ref().map(|f| f.code), + reason = ?frame.as_ref().map(|f| f.reason.as_ref() as &str), + ); + } + other => { + tracing::warn!( + target: "pubsub::handle_flashblocks_stream", + message = format!("Unexpected message {other}") + ); + } + } + } + + Ok(()) + }) + } +} + +fn spawn_ping( + mut sink: S, + mut pong_rx: watch::Receiver, +) -> JoinHandle> +where + S: Sink + Send + Unpin + 'static, +{ + pong_rx.mark_changed(); + tokio::spawn(async move { + let mut ping_interval = tokio::time::interval(Duration::from_millis(500)); + loop { + ping_interval.tick().await; + if pong_rx.has_changed()? { + sink.send(Message::Ping(Bytes::new())) + .await + .map_err(|_| FlashblocksPubSubError::PingFailed)?; + pong_rx.mark_unchanged(); + } else { + tracing::error!("Missing pong response from builder stream"); + return Err(FlashblocksPubSubError::MissingPong); + } + } + }) +} + +pub struct FlashblocksPublisher { + pub handle: JoinHandle>, +} + +impl FlashblocksPublisher { + fn new(listener: TcpListener, publisher_rx: broadcast::Receiver) -> Self { + let handle = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((tcp_stream, _)) => { + let ws_stream = tokio_tungstenite::accept_async(tcp_stream).await?; + let rx = publisher_rx.resubscribe(); + tokio::spawn(Self::handle_connection(ws_stream, rx)); + } + + Err(e) => { + tracing::error!( + target = "flashblocks_publisher::new", + "Error when accepting connection from listener {e}" + ); + } + } + } + }); + + Self { handle } + } + + async fn handle_connection( + mut stream: WebSocketStream, + mut publisher_rx: broadcast::Receiver, + ) { + loop { + match publisher_rx.recv().await { + Ok(payload) => { + // Here you would typically do any transformation or logging. + if let Err(e) = stream.send(Message::Text(payload)).await { + // If sending fails, close the connection. + tracing::debug!("Closing flashblocks subscription: {e}"); + break; + } + } + Err(RecvError::Closed) => { + tracing::debug!("Broadcast channel closed, exiting subscription loop"); + return; + } + Err(RecvError::Lagged(skipped)) => { + tracing::warn!("Broadcast channel lagged, skipped {skipped} messages"); + } + } + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum FlashblocksPubSubError { + #[error("Ping failed")] + PingFailed, + #[error("Missing pong response")] + MissingPong, + #[error(transparent)] + ConnectError(#[from] tungstenite::Error), + #[error(transparent)] + FlashblocksPayloadSendError(#[from] broadcast::error::SendError), + #[error(transparent)] + MessageSendError(#[from] watch::error::SendError), + #[error(transparent)] + Utf8BytesSendError(#[from] broadcast::error::SendError), + #[error(transparent)] + RecvError(#[from] watch::error::RecvError), + #[error(transparent)] + JoinError(#[from] JoinError), +} + +#[cfg(test)] +mod tests { + use crate::{ + ExecutionPayloadBaseV1, FlashblocksPayloadV1, PayloadSource, RpcClient, + provider::FlashblocksProvider, + pubsub::{FlashblocksPubSubError, FlashblocksPublisher, FlashblocksSubscriber, spawn_ping}, + }; + use alloy_primitives::B256; + use alloy_rpc_types_engine::PayloadId; + use bytes::Bytes; + use futures::{SinkExt, StreamExt, sink}; + use op_alloy_rpc_types_engine::OpPayloadAttributes; + use reth_optimism_payload_builder::payload_id_optimism; + use reth_rpc_layer::JwtSecret; + use std::{sync::Arc, time::Duration}; + use tokio::sync::watch; + use tokio::{net::TcpListener, sync::broadcast}; + use tokio_tungstenite::tungstenite::Message; + use url::Url; + + pub struct MockBuilder { + handle: tokio::task::JoinHandle>, + msg_tx: tokio::sync::mpsc::UnboundedSender, + } + + impl MockBuilder { + pub async fn spawn( + listener: TcpListener, + reconnect_backoff: Duration, + ) -> eyre::Result { + let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::(); + + let handle = tokio::spawn(async move { + loop { + let (tcp, _) = listener.accept().await?; + let ws = tokio_tungstenite::accept_async(tcp).await?; + let (mut sink, mut stream) = ws.split(); + + loop { + tokio::select! { + msg = stream.next() => { + if let Message::Close(_) = msg.unwrap()? { + break; + } + } + msg = msg_rx.recv() => { + match msg.unwrap() { + Message::Close(_) => { + drop(sink); + drop(stream); + break; + } + other => { + sink.send(other).await?; + } + } + } + } + } + + tokio::time::sleep(reconnect_backoff).await; + } + }); + + Ok(Self { handle, msg_tx }) + } + + async fn send_message(&self, msg: Message) -> eyre::Result<()> { + self.msg_tx.send(msg)?; + Ok(()) + } + } + + impl Drop for MockBuilder { + fn drop(&mut self) { + self.handle.abort(); + } + } + + #[tokio::test] + async fn test_ping_pong() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let ws_endpoint = + Url::parse(&format!("ws://{}", listener.local_addr().unwrap())).expect("invalid URL"); + + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = Arc::new(FlashblocksProvider::new(rpc_client)); + let (tx, _rx) = broadcast::channel(10); + let subscriber = + FlashblocksSubscriber::new(ws_endpoint, tx, provider, Duration::from_millis(100)); + let _mock = MockBuilder::spawn(listener, Duration::from_secs(1)).await?; + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + assert!(!subscriber.handle.is_finished()); + + Ok(()) + } + + #[tokio::test] + async fn test_reconnect_stream() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let ws_endpoint = + Url::parse(&format!("ws://{}", listener.local_addr().unwrap())).expect("invalid URL"); + + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = Arc::new(FlashblocksProvider::new(rpc_client)); + let (tx, mut rx) = broadcast::channel(10); + + let _subscriber = + FlashblocksSubscriber::new(ws_endpoint, tx, provider, Duration::from_millis(100)); + let mock = MockBuilder::spawn(listener, Duration::from_secs(3)).await?; + mock.send_message(Message::Close(None)).await?; + tokio::time::sleep(Duration::from_secs(3)).await; + + // Send a flashblock after reconnect + let flashblock_payload = FlashblocksPayloadV1 { + index: 0, + payload_id: PayloadId::default(), + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let json = serde_json::to_string(&flashblock_payload)?; + let msg = Message::Text(json.into()); + mock.send_message(msg).await?; + + let payload_bytes = rx.recv().await?; + assert!(!payload_bytes.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn test_missing_pong() -> eyre::Result<()> { + let sink = sink::drain(); + let (_pong_tx, pong_rx) = watch::channel(Message::Pong(Bytes::default())); + let handle = spawn_ping(sink, pong_rx); + + let result = handle.await.unwrap(); + assert!(matches!(result, Err(FlashblocksPubSubError::MissingPong))); + + Ok(()) + } + + #[tokio::test] + async fn test_send_flashblock() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let ws_endpoint = + Url::parse(&format!("ws://{}", listener.local_addr().unwrap())).expect("invalid URL"); + + let mock = MockBuilder::spawn(listener, Duration::from_secs(1)).await?; + + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = Arc::new(FlashblocksProvider::new(rpc_client)); + let (tx, _rx) = broadcast::channel(10); + + let _subscriber = FlashblocksSubscriber::new( + ws_endpoint, + tx, + provider.clone(), + Duration::from_millis(100), + ); + + let flashblock_payload = FlashblocksPayloadV1 { + index: 0, + payload_id: PayloadId::default(), + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let json = serde_json::to_string(&flashblock_payload)?; + let msg = Message::Text(json.into()); + mock.send_message(msg).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + assert_eq!(provider.payload_builder.lock().flashblocks.len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn test_payload_id_mismatch() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let ws_endpoint = + Url::parse(&format!("ws://{}", listener.local_addr().unwrap())).expect("invalid URL"); + + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = Arc::new(FlashblocksProvider::new(rpc_client)); + let (tx, _rx) = broadcast::channel(10); + let _subscriber = FlashblocksSubscriber::new( + ws_endpoint, + tx, + provider.clone(), + Duration::from_millis(100), + ); + + let mock = MockBuilder::spawn(listener, Duration::from_secs(1)).await?; + + // Send flashblock with mismatched payload id + let payload_id = payload_id_optimism(&B256::random(), &OpPayloadAttributes::default(), 3); + let flashblock_payload = FlashblocksPayloadV1 { + index: 0, + payload_id, + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let json = serde_json::to_string(&flashblock_payload)?; + let msg = Message::Text(json.into()); + mock.send_message(msg).await?; + + assert_eq!(provider.payload_builder.lock().flashblocks.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_malformed_flashblocks_payload() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let ws_endpoint = + Url::parse(&format!("ws://{}", listener.local_addr().unwrap())).expect("invalid URL"); + + let rpc_client = RpcClient::new( + "http://localhost:8545".parse().unwrap(), + JwtSecret::random(), + 1000, + PayloadSource::Builder, + )?; + + let provider = Arc::new(FlashblocksProvider::new(rpc_client)); + let (tx, _rx) = broadcast::channel(10); + + let _subscriber = FlashblocksSubscriber::new( + ws_endpoint, + tx, + provider.clone(), + Duration::from_millis(100), + ); + + let mock = MockBuilder::spawn(listener, Duration::from_secs(1)).await?; + + let msg = Message::Text("0xbad".into()); + mock.send_message(msg).await?; + + assert_eq!(provider.payload_builder.lock().flashblocks.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_publish_flashblock() -> eyre::Result<()> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let publisher_addr = listener.local_addr()?; + let (tx, rx) = broadcast::channel(10); + let _publisher = FlashblocksPublisher::new(listener, rx); + + let client_stream = tokio::net::TcpStream::connect(publisher_addr).await?; + let (ws_stream, _) = + tokio_tungstenite::client_async("ws://localhost", client_stream).await?; + let (mut _sink, mut stream) = ws_stream.split(); + + let num_flashblocks = 5_usize; + let mut sent_flashblocks = vec![]; + + // Base flashblock payload + let flashblock_payload = FlashblocksPayloadV1 { + index: 0, + payload_id: PayloadId::default(), + base: Some(ExecutionPayloadBaseV1::default()), + ..Default::default() + }; + + let json = serde_json::to_string(&flashblock_payload)?; + let message_bytes = json.into(); + tx.send(message_bytes)?; + sent_flashblocks.push(flashblock_payload); + + // Send additional flashlbocks + for i in 1..num_flashblocks { + let flashblock_payload = FlashblocksPayloadV1 { + index: i as u64, + payload_id: PayloadId::default(), + base: None, + ..Default::default() + }; + + let json = serde_json::to_string(&flashblock_payload)?; + let message_bytes = json.into(); + + tx.send(message_bytes)?; + sent_flashblocks.push(flashblock_payload); + } + + for flashblock in sent_flashblocks { + let Message::Text(msg) = stream.next().await.unwrap()? else { + panic!("Unexpected message"); + }; + + let received_flashblock: FlashblocksPayloadV1 = serde_json::from_str(&msg)?; + + assert_eq!(received_flashblock.index, flashblock.index); + assert_eq!(received_flashblock.payload_id, flashblock.payload_id,); + } + + Ok(()) + } +}