diff --git a/Cargo.lock b/Cargo.lock index 3cf5728e6c..99b2ae1435 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9340,6 +9340,7 @@ dependencies = [ "reth-storage-api", "serde", "serde_json", + "test-case", "tokio", "tokio-tungstenite", "tracing", diff --git a/crates/optimism/flashblocks/Cargo.toml b/crates/optimism/flashblocks/Cargo.toml index 5f10fd2eb2..35e98783a3 100644 --- a/crates/optimism/flashblocks/Cargo.toml +++ b/crates/optimism/flashblocks/Cargo.toml @@ -46,3 +46,4 @@ tracing.workspace = true eyre.workspace = true [dev-dependencies] +test-case.workspace = true diff --git a/crates/optimism/flashblocks/src/ws/stream.rs b/crates/optimism/flashblocks/src/ws/stream.rs index 8c0601606e..c18857eee3 100644 --- a/crates/optimism/flashblocks/src/ws/stream.rs +++ b/crates/optimism/flashblocks/src/ws/stream.rs @@ -98,7 +98,7 @@ where { fn connect(&mut self) { let ws_url = self.ws_url.clone(); - let connector = self.connector.clone(); + let mut connector = self.connector.clone(); Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await })); @@ -154,7 +154,7 @@ pub trait WsConnect { /// /// See the [`WsConnect`] documentation for details. fn connect( - &self, + &mut self, ws_url: Url, ) -> impl Future> + Send + Sync; } @@ -168,9 +168,120 @@ pub struct WsConnector; impl WsConnect for WsConnector { type Stream = WssStream; - async fn connect(&self, ws_url: Url) -> eyre::Result { + async fn connect(&mut self, ws_url: Url) -> eyre::Result { let (stream, _response) = connect_async(ws_url.as_str()).await?; Ok(stream.split().1) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ExecutionPayloadBaseV1; + use alloy_primitives::bytes::Bytes; + use brotli::enc::BrotliEncoderParams; + use std::future; + use tokio_tungstenite::tungstenite::Error; + + /// A `FakeConnector` creates [`FakeStream`]. + /// + /// It simulates the websocket stream instead of connecting to a real websocket. + #[derive(Clone)] + struct FakeConnector(FakeStream); + + /// Simulates a websocket stream while using a preprogrammed set of messages instead. + #[derive(Default)] + struct FakeStream(Vec>); + + impl Clone for FakeStream { + fn clone(&self) -> Self { + Self( + self.0 + .iter() + .map(|v| match v { + Ok(msg) => Ok(msg.clone()), + Err(err) => unimplemented!("Cannot clone this error: {err}"), + }) + .collect(), + ) + } + } + + impl Stream for FakeStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + Poll::Ready(this.0.pop()) + } + } + + impl WsConnect for FakeConnector { + type Stream = FakeStream; + + fn connect( + &mut self, + _ws_url: Url, + ) -> impl Future> + Send + Sync { + future::ready(Ok(self.0.clone())) + } + } + + impl>> From for FakeConnector { + fn from(value: T) -> Self { + Self(FakeStream(value.into_iter().collect())) + } + } + + fn to_json_message(block: &FlashBlock) -> Result { + Ok(Message::Binary(Bytes::from(serde_json::to_vec(block).unwrap()))) + } + + fn to_brotli_message(block: &FlashBlock) -> Result { + let json = serde_json::to_vec(block).unwrap(); + let mut compressed = Vec::new(); + brotli::BrotliCompress( + &mut json.as_slice(), + &mut compressed, + &BrotliEncoderParams::default(), + )?; + + Ok(Message::Binary(Bytes::from(compressed))) + } + + #[test_case::test_case(to_json_message; "json")] + #[test_case::test_case(to_brotli_message; "brotli")] + #[tokio::test] + async fn test_stream_decodes_messages_successfully( + to_message: impl Fn(&FlashBlock) -> Result, + ) { + let flashblocks = [FlashBlock { + payload_id: Default::default(), + index: 0, + base: Some(ExecutionPayloadBaseV1 { + parent_beacon_block_root: Default::default(), + parent_hash: Default::default(), + fee_recipient: Default::default(), + prev_randao: Default::default(), + block_number: 0, + gas_limit: 0, + timestamp: 0, + extra_data: Default::default(), + base_fee_per_gas: Default::default(), + }), + diff: Default::default(), + metadata: Default::default(), + }]; + + let messages = FakeConnector::from(flashblocks.iter().map(to_message)); + let ws_url = "http://localhost".parse().unwrap(); + let stream = WsFlashBlockStream::with_connector(ws_url, messages); + + let actual_messages: Vec<_> = stream.map(Result::unwrap).collect().await; + let expected_messages = flashblocks.to_vec(); + + assert_eq!(actual_messages, expected_messages); + } +}