From aa9361ed8ed8a8e6f6b24a1a624deafc55e0fd85 Mon Sep 17 00:00:00 2001 From: bennyhodl Date: Wed, 15 Jan 2025 12:43:18 -0500 Subject: [PATCH] feat: channels for receiving dlc messages --- dlc-messages/Cargo.toml | 4 + dlc-messages/src/contract_msgs.rs | 2 +- dlc-messages/src/lib.rs | 2 + dlc-messages/src/message_handler.rs | 169 +++++++++++++++++++++++++++- 4 files changed, 170 insertions(+), 7 deletions(-) diff --git a/dlc-messages/Cargo.toml b/dlc-messages/Cargo.toml index a2346c65..3285a4ba 100644 --- a/dlc-messages/Cargo.toml +++ b/dlc-messages/Cargo.toml @@ -6,6 +6,7 @@ license-file = "../LICENSE" name = "dlc-messages" repository = "https://github.com/p2pderivatives/rust-dlc/tree/master/dlc-messages" version = "0.7.1" +edition = "2018" [features] default = ["std"] @@ -15,10 +16,12 @@ use-serde = ["serde", "secp256k1-zkp/serde", "bitcoin/serde"] [dependencies] bitcoin = { version = "0.32.2", default-features = false } +crossbeam = "0.8.4" dlc = { version = "0.7.1", path = "../dlc", default-features = false } lightning = { version = "0.0.125", default-features = false } secp256k1-zkp = {version = "0.11.0"} serde = {version = "1.0", features = ["derive"], optional = true} +tokio = { version = "1.43.0", features = ["sync"] } [dev-dependencies] bitcoin = { version = "0.32.2", default-features = false, features = ["serde"] } @@ -26,3 +29,4 @@ dlc-messages = {path = "./", default-features = false, features = ["use-serde"]} secp256k1-zkp = {version = "0.11.0", features = ["serde", "global-context"]} serde = {version = "1.0", features = ["derive"]} serde_json = "1.0" +tokio = { version = "1.43.0", features = ["sync", "macros", "rt"] } diff --git a/dlc-messages/src/contract_msgs.rs b/dlc-messages/src/contract_msgs.rs index 8347444f..3b2e2af8 100644 --- a/dlc-messages/src/contract_msgs.rs +++ b/dlc-messages/src/contract_msgs.rs @@ -1,9 +1,9 @@ //! Structure containing information about contract details. +use crate::oracle_msgs::OracleInfo; use bitcoin::Amount; use lightning::ln::msgs::DecodeError; use lightning::util::ser::{Readable, Writeable, Writer}; -use oracle_msgs::OracleInfo; #[derive(Clone, PartialEq, Debug, Eq)] #[cfg_attr( diff --git a/dlc-messages/src/lib.rs b/dlc-messages/src/lib.rs index c8d47859..5a9eca30 100644 --- a/dlc-messages/src/lib.rs +++ b/dlc-messages/src/lib.rs @@ -14,6 +14,8 @@ extern crate bitcoin; extern crate dlc; extern crate lightning; extern crate secp256k1_zkp; +extern crate tokio; + #[macro_use] pub mod ser_macros; pub mod ser_impls; diff --git a/dlc-messages/src/message_handler.rs b/dlc-messages/src/message_handler.rs index 6e587871..48a0a9ac 100644 --- a/dlc-messages/src/message_handler.rs +++ b/dlc-messages/src/message_handler.rs @@ -3,7 +3,7 @@ use std::{ collections::{HashMap, VecDeque}, fmt::Display, - sync::Mutex, + sync::{Arc, Mutex}, }; use lightning::ln::features::{InitFeatures, NodeFeatures}; @@ -17,6 +17,8 @@ use lightning::{ util::ser::{Readable, Writeable, MAX_BUF_SIZE}, }; use secp256k1_zkp::PublicKey; +use std::sync::mpsc; +use tokio::sync::broadcast; use crate::{ segmentation::{get_segments, segment_reader::SegmentReader}, @@ -31,6 +33,15 @@ pub struct MessageHandler { msg_events: Mutex>, msg_received: Mutex>, segment_readers: Mutex>, + // Async message handling + message_sender: broadcast::Sender, + // Sync message handling, one consumer + sync_message_sender: mpsc::Sender, + /// Receiver for messages sent from the sync thread. Mutex for thread safety. + pub sync_message_receiver: Arc>>, + // Sync message handling, multiple consumers + crossbeam_sender: crossbeam::channel::Sender, + crossbeam_receiver: crossbeam::channel::Receiver, } impl Default for MessageHandler { @@ -42,10 +53,18 @@ impl Default for MessageHandler { impl MessageHandler { /// Creates a new instance of a [`MessageHandler`] pub fn new() -> Self { + let (message_sender, _) = broadcast::channel(100); + let (sync_message_sender, sync_message_receiver) = mpsc::channel(); + let (crossbeam_sender, crossbeam_receiver) = crossbeam::channel::unbounded(); MessageHandler { msg_events: Mutex::new(VecDeque::new()), msg_received: Mutex::new(Vec::new()), segment_readers: Mutex::new(HashMap::new()), + message_sender, + sync_message_sender, + sync_message_receiver: Arc::new(Mutex::new(sync_message_receiver)), + crossbeam_sender, + crossbeam_receiver, } } @@ -62,6 +81,7 @@ impl MessageHandler { /// [`lightning::ln::peer_handler::PeerManager::process_events`] is next called. pub fn send_message(&self, node_id: PublicKey, msg: Message) { if msg.serialized_length() > MAX_BUF_SIZE { + println!("Sending segmented message"); let (seg_start, seg_chunks) = get_segments(msg.encode(), msg.type_id()); let mut msg_events = self.msg_events.lock().unwrap(); msg_events.push_back((node_id, WireMessage::SegmentStart(seg_start))); @@ -76,6 +96,16 @@ impl MessageHandler { } } + /// Returns the notifications channel. + pub fn notifications(&self) -> broadcast::Receiver { + self.message_sender.subscribe() + } + + /// Returns the crossbeam notifications channel. + pub fn crossbeam_notifications(&self) -> crossbeam::channel::Receiver { + self.crossbeam_receiver.clone() + } + /// Returns whether the message handler has any message to be sent. pub fn has_pending_messages(&self) -> bool { !self.msg_events.lock().unwrap().is_empty() @@ -167,6 +197,7 @@ impl CustomMessageHandler for MessageHandler { let segment_reader = segment_readers.entry(*org).or_default(); if segment_reader.expecting_chunk() { + println!("Expecting segment chunk"); match msg { WireMessage::SegmentChunk(s) => { if let Some(msg) = segment_reader @@ -184,7 +215,13 @@ impl CustomMessageHandler for MessageHandler { })? .expect("to have a message") { - self.msg_received.lock().unwrap().push((*org, m)); + self.msg_received.lock().unwrap().push((*org, m.clone())); + // Tokio sender + let _ = self.message_sender.send(m.clone()); + // Sync sender + let _ = self.sync_message_sender.send(m.clone()); + // Crossbeam sender + let _ = self.crossbeam_sender.send(m.clone()); } else { return Err(to_ln_error( "Unexpected message type", @@ -203,17 +240,30 @@ impl CustomMessageHandler for MessageHandler { } match msg { - WireMessage::Message(m) => self.msg_received.lock().unwrap().push((*org, m)), - WireMessage::SegmentStart(s) => segment_reader - .process_segment_start(s) - .map_err(|e| to_ln_error(e, "Error processing segment start"))?, + WireMessage::Message(m) => { + self.msg_received.lock().unwrap().push((*org, m.clone())); + // Tokio sender + let _ = self.message_sender.send(m.clone()); + // Sync sender + let _ = self.sync_message_sender.send(m.clone()); + // Crossbeam sender + let _ = self.crossbeam_sender.send(m.clone()); + } + WireMessage::SegmentStart(s) => { + println!("Processing segment start"); + segment_reader + .process_segment_start(s) + .map_err(|e| to_ln_error(e, "Error processing segment start"))? + } WireMessage::SegmentChunk(_) => { + println!("Processing segment chunk"); return Err(LightningError { err: "Received a SegmentChunk while not expecting one.".to_string(), action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None }, }); } }; + Ok(()) } @@ -371,4 +421,111 @@ mod tests { panic!("Expected an accept message"); } } + + #[tokio::test] + async fn notifications_test() { + let input = include_str!("./test_inputs/offer_msg.json"); + let msg: OfferDlc = serde_json::from_str(input).unwrap(); + let handler = MessageHandler::new(); + let notifications = handler.notifications(); + handler + .handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk()) + .unwrap(); + assert_eq!(notifications.len(), 1); + } + + #[tokio::test] + async fn notifications_segment_test() { + let input1 = include_str!("./test_inputs/segment_start_msg.json"); + let input2 = include_str!("./test_inputs/segment_chunk_msg.json"); + let segment_start: SegmentStart = serde_json::from_str(input1).unwrap(); + let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap(); + + let handler = MessageHandler::new(); + let mut notifications = handler.notifications(); + handler + .handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk()) + .expect("to be able to process segment start"); + handler + .handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk()) + .expect("to be able to process segment start"); + + assert_eq!(1, notifications.len()); + let msg = notifications.recv().await.unwrap(); + assert!(matches!(msg, Message::Accept(_))); + } + + #[test] + fn sync_message_test() { + let input = include_str!("./test_inputs/offer_msg.json"); + let msg: OfferDlc = serde_json::from_str(input).unwrap(); + let handler = MessageHandler::new(); + handler + .handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk()) + .expect("to be able to process segment start"); + let msg = handler + .sync_message_receiver + .lock() + .unwrap() + .recv() + .unwrap(); + assert!(matches!(msg, Message::Offer(_))); + } + + #[test] + fn sync_segment_test() { + let input1 = include_str!("./test_inputs/segment_start_msg.json"); + let input2 = include_str!("./test_inputs/segment_chunk_msg.json"); + let segment_start: SegmentStart = serde_json::from_str(input1).unwrap(); + let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap(); + + let handler = MessageHandler::new(); + handler + .handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk()) + .expect("to be able to process segment start"); + handler + .handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk()) + .expect("to be able to process segment start"); + + let msg = handler + .sync_message_receiver + .lock() + .unwrap() + .recv() + .unwrap(); + assert!(matches!(msg, Message::Accept(_))); + } + + #[test] + fn crossbeam_message_test() { + let input = include_str!("./test_inputs/offer_msg.json"); + let msg: OfferDlc = serde_json::from_str(input).unwrap(); + let handler = MessageHandler::new(); + handler + .handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk()) + .expect("to be able to process segment start"); + println!("Waiting for message"); + let msg = handler.crossbeam_notifications().recv().unwrap(); + println!("Received message"); + assert!(matches!(msg, Message::Offer(_))); + } + + #[test] + fn crossbeam_segment_test() { + let input1 = include_str!("./test_inputs/segment_start_msg.json"); + let input2 = include_str!("./test_inputs/segment_chunk_msg.json"); + let segment_start: SegmentStart = serde_json::from_str(input1).unwrap(); + let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap(); + + let handler = MessageHandler::new(); + handler + .handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk()) + .expect("to be able to process segment start"); + handler + .handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk()) + .expect("to be able to process segment start"); + + let msg = handler.crossbeam_notifications().recv().unwrap(); + assert!(matches!(msg, Message::Accept(_))); + } }