Skip to content

Commit

Permalink
feat: channels for receiving dlc messages
Browse files Browse the repository at this point in the history
  • Loading branch information
bennyhodl committed Jan 15, 2025
1 parent ec130e6 commit aa9361e
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 7 deletions.
4 changes: 4 additions & 0 deletions dlc-messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ license-file = "../LICENSE"
name = "dlc-messages"
repository = "https://github.com/p2pderivatives/rust-dlc/tree/master/dlc-messages"
version = "0.7.1"
edition = "2018"

[features]
default = ["std"]
Expand All @@ -15,14 +16,17 @@ use-serde = ["serde", "secp256k1-zkp/serde", "bitcoin/serde"]

[dependencies]
bitcoin = { version = "0.32.2", default-features = false }
crossbeam = "0.8.4"
dlc = { version = "0.7.1", path = "../dlc", default-features = false }
lightning = { version = "0.0.125", default-features = false }
secp256k1-zkp = {version = "0.11.0"}
serde = {version = "1.0", features = ["derive"], optional = true}
tokio = { version = "1.43.0", features = ["sync"] }

[dev-dependencies]
bitcoin = { version = "0.32.2", default-features = false, features = ["serde"] }
dlc-messages = {path = "./", default-features = false, features = ["use-serde"]}
secp256k1-zkp = {version = "0.11.0", features = ["serde", "global-context"]}
serde = {version = "1.0", features = ["derive"]}
serde_json = "1.0"
tokio = { version = "1.43.0", features = ["sync", "macros", "rt"] }
2 changes: 1 addition & 1 deletion dlc-messages/src/contract_msgs.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Structure containing information about contract details.
use crate::oracle_msgs::OracleInfo;
use bitcoin::Amount;
use lightning::ln::msgs::DecodeError;
use lightning::util::ser::{Readable, Writeable, Writer};
use oracle_msgs::OracleInfo;

#[derive(Clone, PartialEq, Debug, Eq)]
#[cfg_attr(
Expand Down
2 changes: 2 additions & 0 deletions dlc-messages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ extern crate bitcoin;
extern crate dlc;
extern crate lightning;
extern crate secp256k1_zkp;
extern crate tokio;

#[macro_use]
pub mod ser_macros;
pub mod ser_impls;
Expand Down
169 changes: 163 additions & 6 deletions dlc-messages/src/message_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
sync::Mutex,
sync::{Arc, Mutex},
};

use lightning::ln::features::{InitFeatures, NodeFeatures};
Expand All @@ -17,6 +17,8 @@ use lightning::{
util::ser::{Readable, Writeable, MAX_BUF_SIZE},
};
use secp256k1_zkp::PublicKey;
use std::sync::mpsc;
use tokio::sync::broadcast;

use crate::{
segmentation::{get_segments, segment_reader::SegmentReader},
Expand All @@ -31,6 +33,15 @@ pub struct MessageHandler {
msg_events: Mutex<VecDeque<(PublicKey, WireMessage)>>,
msg_received: Mutex<Vec<(PublicKey, Message)>>,
segment_readers: Mutex<HashMap<PublicKey, SegmentReader>>,
// Async message handling
message_sender: broadcast::Sender<Message>,
// Sync message handling, one consumer
sync_message_sender: mpsc::Sender<Message>,
/// Receiver for messages sent from the sync thread. Mutex for thread safety.
pub sync_message_receiver: Arc<Mutex<mpsc::Receiver<Message>>>,
// Sync message handling, multiple consumers
crossbeam_sender: crossbeam::channel::Sender<Message>,
crossbeam_receiver: crossbeam::channel::Receiver<Message>,
}

impl Default for MessageHandler {
Expand All @@ -42,10 +53,18 @@ impl Default for MessageHandler {
impl MessageHandler {
/// Creates a new instance of a [`MessageHandler`]
pub fn new() -> Self {
let (message_sender, _) = broadcast::channel(100);
let (sync_message_sender, sync_message_receiver) = mpsc::channel();
let (crossbeam_sender, crossbeam_receiver) = crossbeam::channel::unbounded();
MessageHandler {
msg_events: Mutex::new(VecDeque::new()),
msg_received: Mutex::new(Vec::new()),
segment_readers: Mutex::new(HashMap::new()),
message_sender,
sync_message_sender,
sync_message_receiver: Arc::new(Mutex::new(sync_message_receiver)),
crossbeam_sender,
crossbeam_receiver,
}
}

Expand All @@ -62,6 +81,7 @@ impl MessageHandler {
/// [`lightning::ln::peer_handler::PeerManager::process_events`] is next called.
pub fn send_message(&self, node_id: PublicKey, msg: Message) {
if msg.serialized_length() > MAX_BUF_SIZE {
println!("Sending segmented message");
let (seg_start, seg_chunks) = get_segments(msg.encode(), msg.type_id());
let mut msg_events = self.msg_events.lock().unwrap();
msg_events.push_back((node_id, WireMessage::SegmentStart(seg_start)));
Expand All @@ -76,6 +96,16 @@ impl MessageHandler {
}
}

/// Returns the notifications channel.
pub fn notifications(&self) -> broadcast::Receiver<Message> {
self.message_sender.subscribe()
}

/// Returns the crossbeam notifications channel.
pub fn crossbeam_notifications(&self) -> crossbeam::channel::Receiver<Message> {
self.crossbeam_receiver.clone()
}

/// Returns whether the message handler has any message to be sent.
pub fn has_pending_messages(&self) -> bool {
!self.msg_events.lock().unwrap().is_empty()
Expand Down Expand Up @@ -167,6 +197,7 @@ impl CustomMessageHandler for MessageHandler {
let segment_reader = segment_readers.entry(*org).or_default();

if segment_reader.expecting_chunk() {
println!("Expecting segment chunk");
match msg {
WireMessage::SegmentChunk(s) => {
if let Some(msg) = segment_reader
Expand All @@ -184,7 +215,13 @@ impl CustomMessageHandler for MessageHandler {
})?
.expect("to have a message")
{
self.msg_received.lock().unwrap().push((*org, m));
self.msg_received.lock().unwrap().push((*org, m.clone()));
// Tokio sender
let _ = self.message_sender.send(m.clone());
// Sync sender
let _ = self.sync_message_sender.send(m.clone());
// Crossbeam sender
let _ = self.crossbeam_sender.send(m.clone());
} else {
return Err(to_ln_error(
"Unexpected message type",
Expand All @@ -203,17 +240,30 @@ impl CustomMessageHandler for MessageHandler {
}

match msg {
WireMessage::Message(m) => self.msg_received.lock().unwrap().push((*org, m)),
WireMessage::SegmentStart(s) => segment_reader
.process_segment_start(s)
.map_err(|e| to_ln_error(e, "Error processing segment start"))?,
WireMessage::Message(m) => {
self.msg_received.lock().unwrap().push((*org, m.clone()));
// Tokio sender
let _ = self.message_sender.send(m.clone());
// Sync sender
let _ = self.sync_message_sender.send(m.clone());
// Crossbeam sender
let _ = self.crossbeam_sender.send(m.clone());
}
WireMessage::SegmentStart(s) => {
println!("Processing segment start");
segment_reader
.process_segment_start(s)
.map_err(|e| to_ln_error(e, "Error processing segment start"))?
}
WireMessage::SegmentChunk(_) => {
println!("Processing segment chunk");
return Err(LightningError {
err: "Received a SegmentChunk while not expecting one.".to_string(),
action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None },
});
}
};

Ok(())
}

Expand Down Expand Up @@ -371,4 +421,111 @@ mod tests {
panic!("Expected an accept message");
}
}

#[tokio::test]
async fn notifications_test() {
let input = include_str!("./test_inputs/offer_msg.json");
let msg: OfferDlc = serde_json::from_str(input).unwrap();
let handler = MessageHandler::new();
let notifications = handler.notifications();
handler
.handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk())
.unwrap();
assert_eq!(notifications.len(), 1);
}

#[tokio::test]
async fn notifications_segment_test() {
let input1 = include_str!("./test_inputs/segment_start_msg.json");
let input2 = include_str!("./test_inputs/segment_chunk_msg.json");
let segment_start: SegmentStart = serde_json::from_str(input1).unwrap();
let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap();

let handler = MessageHandler::new();
let mut notifications = handler.notifications();
handler
.handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk())
.expect("to be able to process segment start");
handler
.handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk())
.expect("to be able to process segment start");

assert_eq!(1, notifications.len());
let msg = notifications.recv().await.unwrap();
assert!(matches!(msg, Message::Accept(_)));
}

#[test]
fn sync_message_test() {
let input = include_str!("./test_inputs/offer_msg.json");
let msg: OfferDlc = serde_json::from_str(input).unwrap();
let handler = MessageHandler::new();
handler
.handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk())
.expect("to be able to process segment start");
let msg = handler
.sync_message_receiver
.lock()
.unwrap()
.recv()
.unwrap();
assert!(matches!(msg, Message::Offer(_)));
}

#[test]
fn sync_segment_test() {
let input1 = include_str!("./test_inputs/segment_start_msg.json");
let input2 = include_str!("./test_inputs/segment_chunk_msg.json");
let segment_start: SegmentStart = serde_json::from_str(input1).unwrap();
let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap();

let handler = MessageHandler::new();
handler
.handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk())
.expect("to be able to process segment start");
handler
.handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk())
.expect("to be able to process segment start");

let msg = handler
.sync_message_receiver
.lock()
.unwrap()
.recv()
.unwrap();
assert!(matches!(msg, Message::Accept(_)));
}

#[test]
fn crossbeam_message_test() {
let input = include_str!("./test_inputs/offer_msg.json");
let msg: OfferDlc = serde_json::from_str(input).unwrap();
let handler = MessageHandler::new();
handler
.handle_custom_message(WireMessage::Message(Message::Offer(msg)), &some_pk())
.expect("to be able to process segment start");
println!("Waiting for message");
let msg = handler.crossbeam_notifications().recv().unwrap();
println!("Received message");
assert!(matches!(msg, Message::Offer(_)));
}

#[test]
fn crossbeam_segment_test() {
let input1 = include_str!("./test_inputs/segment_start_msg.json");
let input2 = include_str!("./test_inputs/segment_chunk_msg.json");
let segment_start: SegmentStart = serde_json::from_str(input1).unwrap();
let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap();

let handler = MessageHandler::new();
handler
.handle_custom_message(WireMessage::SegmentStart(segment_start), &some_pk())
.expect("to be able to process segment start");
handler
.handle_custom_message(WireMessage::SegmentChunk(segment_chunk), &some_pk())
.expect("to be able to process segment start");

let msg = handler.crossbeam_notifications().recv().unwrap();
assert!(matches!(msg, Message::Accept(_)));
}
}

0 comments on commit aa9361e

Please sign in to comment.