From a52460926263717610424a170eef0d25c286059e Mon Sep 17 00:00:00 2001 From: bennyhodl Date: Wed, 2 Oct 2024 23:17:48 -0400 Subject: [PATCH] feat: Async support for dlc-manager --- dlc-manager/Cargo.toml | 2 + dlc-manager/src/lib.rs | 14 ++ dlc-manager/src/manager.rs | 344 +++++++++++++++++++++++++++++++++++++ 3 files changed, 360 insertions(+) diff --git a/dlc-manager/Cargo.toml b/dlc-manager/Cargo.toml index d14423d1..8fb44e19 100644 --- a/dlc-manager/Cargo.toml +++ b/dlc-manager/Cargo.toml @@ -14,6 +14,7 @@ std = ["dlc/std", "dlc-messages/std", "dlc-trie/std", "bitcoin/std", "lightning/ fuzztarget = ["rand_chacha"] parallel = ["dlc-trie/parallel"] use-serde = ["serde", "dlc/use-serde", "dlc-messages/use-serde", "dlc-trie/use-serde"] +async = ["dep:futures"] [dependencies] async-trait = "0.1.50" @@ -21,6 +22,7 @@ bitcoin = { version = "0.32.2", default-features = false } dlc = { version = "0.6.0", default-features = false, path = "../dlc" } dlc-messages = { version = "0.6.0", default-features = false, path = "../dlc-messages" } dlc-trie = { version = "0.6.0", default-features = false, path = "../dlc-trie" } +futures = { version = "0.3.30", optional = true } hex = { package = "hex-conservative", version = "0.1" } lightning = { version = "0.0.124", default-features = false, features = ["grind_signatures"] } log = "0.4.14" diff --git a/dlc-manager/src/lib.rs b/dlc-manager/src/lib.rs index b5250c00..ee74fdcf 100644 --- a/dlc-manager/src/lib.rs +++ b/dlc-manager/src/lib.rs @@ -230,9 +230,23 @@ pub trait Oracle { /// Returns the public key of the oracle. fn get_public_key(&self) -> XOnlyPublicKey; /// Returns the announcement for the event with the given id if found. + #[cfg(not(feature = "async"))] fn get_announcement(&self, event_id: &str) -> Result; /// Returns the attestation for the event with the given id if found. + #[cfg(not(feature = "async"))] fn get_attestation(&self, event_id: &str) -> Result; + /// Returns the announcement for the event with the given id if found. + #[cfg(feature = "async")] + fn get_announcement( + &self, + event_id: &str, + ) -> impl std::future::Future> + Send; + /// Returns the attestation for the event with the given id if found. + #[cfg(feature = "async")] + fn get_attestation( + &self, + event_id: &str, + ) -> impl std::future::Future> + Send; } /// Represents a UTXO. diff --git a/dlc-manager/src/manager.rs b/dlc-manager/src/manager.rs index 9335970d..f7f6c092 100644 --- a/dlc-manager/src/manager.rs +++ b/dlc-manager/src/manager.rs @@ -31,6 +31,10 @@ use dlc_messages::channel::{ }; use dlc_messages::oracle_msgs::{OracleAnnouncement, OracleAttestation}; use dlc_messages::{AcceptDlc, Message as DlcMessage, OfferDlc, SignDlc}; +#[cfg(feature = "async")] +use futures::stream::FuturesUnordered; +#[cfg(feature = "async")] +use futures::StreamExt; use hex::DisplayHex; use lightning::chain::chaininterface::FeeEstimator; use lightning::ln::chan_utils::{ @@ -269,6 +273,7 @@ where /// and an OfferDlc message returned. /// /// This function will fetch the oracle announcements from the oracle. + #[cfg(not(feature = "async"))] pub fn send_offer( &self, contract_input: &ContractInput, @@ -283,6 +288,29 @@ where self.send_offer_with_announcements(contract_input, counter_party, oracle_announcements) } + /// Function called to create a new DLC. The offered contract will be stored + /// and an OfferDlc message returned. + /// + /// This function will fetch the oracle announcements from the oracle. + #[cfg(feature = "async")] + pub async fn send_offer( + &self, + contract_input: &ContractInput, + counter_party: PublicKey, + ) -> Result { + let oracle_announcements = contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::>() + .collect::, Error>>>() + .await + .into_iter() + .collect::, Error>>()?; + + self.send_offer_with_announcements(contract_input, counter_party, oracle_announcements) + } + /// Function called to create a new DLC. The offered contract will be stored /// and an OfferDlc message returned. /// @@ -375,6 +403,7 @@ where /// Function to call to check the state of the currently executing DLCs and /// update them if possible. + #[cfg(not(feature = "async"))] pub fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { self.check_signed_contracts()?; self.check_confirmed_contracts()?; @@ -387,6 +416,21 @@ where Ok(()) } + /// Function to call to check the state of the currently executing DLCs and + /// update them if possible. + #[cfg(feature = "async")] + pub async fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { + self.check_signed_contracts()?; + self.check_confirmed_contracts().await?; + self.check_preclosed_contracts()?; + + if check_channels { + self.channel_checks().await?; + } + + Ok(()) + } + fn on_offer_message( &self, offered_message: &OfferDlc, @@ -470,6 +514,7 @@ where Ok(()) } + #[cfg(not(feature = "async"))] fn get_oracle_announcements( &self, oracle_inputs: &OracleInput, @@ -486,6 +531,28 @@ where Ok(announcements) } + #[cfg(feature = "async")] + async fn get_oracle_announcements( + &self, + oracle_inputs: &OracleInput, + ) -> Result, Error> { + let mut announcements = Vec::new(); + for pubkey in &oracle_inputs.public_keys { + let oracle = self + .oracles + .get(pubkey) + .ok_or_else(|| Error::InvalidParameters("Unknown oracle public key".to_string()))?; + announcements.push( + oracle + .get_announcement(&oracle_inputs.event_id) + .await? + .clone(), + ); + } + + Ok(announcements) + } + fn sign_fail_on_error( &self, accepted_contract: AcceptedContract, @@ -547,6 +614,7 @@ where Ok(()) } + #[cfg(not(feature = "async"))] fn check_confirmed_contracts(&self) -> Result<(), Error> { for c in self.store.get_confirmed_contracts()? { // Confirmed contracts from channel are processed in channel specific methods. @@ -564,7 +632,26 @@ where Ok(()) } + #[cfg(feature = "async")] + async fn check_confirmed_contracts(&self) -> Result<(), Error> { + for c in self.store.get_confirmed_contracts()? { + // Confirmed contracts from channel are processed in channel specific methods. + if c.channel_id.is_some() { + continue; + } + if let Err(e) = self.check_confirmed_contract(&c).await { + error!( + "Error checking confirmed contract {}: {}", + c.accepted_contract.get_contract_id_string(), + e + ) + } + } + + Ok(()) + } + #[cfg(not(feature = "async"))] fn get_closable_contract_info<'a>( &'a self, contract: &'a SignedContract, @@ -600,7 +687,53 @@ where } None } + #[cfg(feature = "async")] + async fn get_closable_contract_info<'a>( + &'a self, + contract: &'a SignedContract, + ) -> ClosableContractInfo<'a> { + let contract_infos = &contract.accepted_contract.offered_contract.contract_info; + let adaptor_infos = &contract.accepted_contract.adaptor_infos; + for (contract_info, adaptor_info) in contract_infos.iter().zip(adaptor_infos.iter()) { + let matured: Vec<_> = contract_info + .oracle_announcements + .iter() + .filter(|x| { + (x.oracle_event.event_maturity_epoch as u64) <= self.time.unix_time_now() + }) + .enumerate() + .collect(); + + if matured.len() >= contract_info.threshold { + let attestations: Vec<_> = matured + .iter() + .filter_map(|(i, announcement)| { + let oracle = self.oracles.get(&announcement.oracle_public_key)?; + Some((*i, oracle, announcement.oracle_event.event_id.clone())) + }) + .collect::>() + .into_iter() + .map(|(i, oracle, event_id)| async move { + oracle + .get_attestation(&event_id) + .await + .ok() + .map(|attestation| (i, attestation)) + }) + .collect::>() + .filter_map(|result| async move { result }) + .collect() + .await; + + if attestations.len() >= contract_info.threshold { + return Some((contract_info, adaptor_info, attestations)); + } + } + } + None + } + #[cfg(not(feature = "async"))] fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { let closable_contract_info = self.get_closable_contract_info(contract); if let Some((contract_info, adaptor_info, attestations)) = closable_contract_info { @@ -639,6 +772,45 @@ where Ok(()) } + #[cfg(feature = "async")] + async fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { + let closable_contract_info = self.get_closable_contract_info(contract).await; + if let Some((contract_info, adaptor_info, attestations)) = closable_contract_info { + let offer = &contract.accepted_contract.offered_contract; + let signer = self.signer_provider.derive_contract_signer(offer.keys_id)?; + let cet = crate::contract_updater::get_signed_cet( + &self.secp, + contract, + contract_info, + adaptor_info, + &attestations, + &signer, + )?; + match self.close_contract( + contract, + cet, + attestations.iter().map(|x| x.1.clone()).collect(), + ) { + Ok(closed_contract) => { + self.store.update_contract(&closed_contract)?; + return Ok(()); + } + Err(e) => { + warn!( + "Failed to close contract {}: {}", + contract.accepted_contract.get_contract_id_string(), + e + ); + return Err(e); + } + } + } + + self.check_refund(contract)?; + + Ok(()) + } + /// Manually close a contract with the oracle attestations. pub fn close_confirmed_contract( &self, @@ -902,6 +1074,7 @@ where { /// Create a new channel offer and return the [`dlc_messages::channel::OfferChannel`] /// message to be sent to the `counter_party`. + #[cfg(not(feature = "async"))] pub fn offer_channel( &self, contract_input: &ContractInput, @@ -936,6 +1109,47 @@ where Ok(msg) } + /// Create a new channel offer and return the [`dlc_messages::channel::OfferChannel`] + /// message to be sent to the `counter_party`. + #[cfg(feature = "async")] + pub async fn offer_channel( + &self, + contract_input: &ContractInput, + counter_party: PublicKey, + ) -> Result { + let oracle_announcements = contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::>() + .collect::, Error>>>() + .await + .into_iter() + .collect::, Error>>()?; + + let (offered_channel, offered_contract) = crate::channel_updater::offer_channel( + &self.secp, + contract_input, + &counter_party, + &oracle_announcements, + CET_NSEQUENCE, + REFUND_DELAY, + &self.wallet, + &self.signer_provider, + &self.blockchain, + &self.time, + crate::utils::get_new_temporary_id(), + )?; + + let msg = offered_channel.get_offer_channel_msg(&offered_contract); + + self.store.upsert_channel( + Channel::Offered(offered_channel), + Some(Contract::Offered(offered_contract)), + )?; + + Ok(msg) + } /// Reject a channel that was offered. Returns the [`dlc_messages::channel::Reject`] /// message to be sent as well as the public key of the offering node. @@ -1084,6 +1298,7 @@ where /// Returns a [`RenewOffer`] message as well as the [`PublicKey`] of the /// counter party's node to offer the establishment of a new contract in the /// channel. + #[cfg(not(feature = "async"))] pub fn renew_offer( &self, channel_id: &ChannelId, @@ -1121,6 +1336,52 @@ where Ok((msg, counter_party)) } + /// Returns a [`RenewOffer`] message as well as the [`PublicKey`] of the + /// counter party's node to offer the establishment of a new contract in the + /// channel. + #[cfg(feature = "async")] + pub async fn renew_offer( + &self, + channel_id: &ChannelId, + counter_payout: u64, + contract_input: &ContractInput, + ) -> Result<(RenewOffer, PublicKey), Error> { + let mut signed_channel = + get_channel_in_state!(self, channel_id, Signed, None as Option)?; + + // TODO: helper function + let oracle_announcements = contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::>() + .collect::, Error>>>() + .await + .into_iter() + .collect::, Error>>()?; + + let (msg, offered_contract) = crate::channel_updater::renew_offer( + &self.secp, + &mut signed_channel, + contract_input, + oracle_announcements, + counter_payout, + REFUND_DELAY, + PEER_TIMEOUT, + CET_NSEQUENCE, + &self.signer_provider, + &self.time, + )?; + + let counter_party = offered_contract.counter_party; + + self.store.upsert_channel( + Channel::Signed(signed_channel), + Some(Contract::Offered(offered_contract)), + )?; + + Ok((msg, counter_party)) + } /// Accept an offer to renew the contract in the channel. Returns the /// [`RenewAccept`] message to be sent to the peer with the returned @@ -1292,6 +1553,7 @@ where Ok(()) } + #[cfg(not(feature = "async"))] fn try_finalize_closing_established_channel( &self, signed_channel: SignedChannel, @@ -1353,6 +1615,69 @@ where Ok(()) } + #[cfg(feature = "async")] + async fn try_finalize_closing_established_channel( + &self, + signed_channel: SignedChannel, + ) -> Result<(), Error> { + let (buffer_tx, contract_id, &is_initiator) = get_signed_channel_state!( + signed_channel, + Closing, + buffer_transaction, + contract_id, + is_initiator + )?; + + if self + .blockchain + .get_transaction_confirmations(&buffer_tx.compute_txid())? + >= CET_NSEQUENCE + { + log::info!( + "Buffer transaction for contract {} has enough confirmations to spend from it", + serialize_hex(&contract_id) + ); + + let confirmed_contract = + get_contract_in_state!(self, &contract_id, Confirmed, None as Option)?; + + let (contract_info, adaptor_info, attestations) = self + .get_closable_contract_info(&confirmed_contract) + .await + .ok_or_else(|| { + Error::InvalidState("Could not get information to close contract".to_string()) + })?; + + let (signed_cet, closed_channel) = + crate::channel_updater::finalize_unilateral_close_settled_channel( + &self.secp, + &signed_channel, + &confirmed_contract, + contract_info, + &attestations, + adaptor_info, + &self.signer_provider, + is_initiator, + )?; + + let closed_contract = self.close_contract( + &confirmed_contract, + signed_cet, + attestations.iter().map(|x| &x.1).cloned().collect(), + )?; + + self.chain_monitor + .lock() + .unwrap() + .cleanup_channel(signed_channel.channel_id); + + self.store + .upsert_channel(closed_channel, Some(closed_contract))?; + } + + Ok(()) + } + fn on_offer_channel( &self, offer_channel: &OfferChannel, @@ -2079,6 +2404,25 @@ where Ok(()) } + #[cfg(feature = "async")] + async fn channel_checks(&self) -> Result<(), Error> { + let established_closing_channels = self + .store + .get_signed_channels(Some(SignedChannelStateType::Closing))?; + + for channel in established_closing_channels { + if let Err(e) = self.try_finalize_closing_established_channel(channel).await { + error!("Error trying to close established channel: {}", e); + } + } + + if let Err(e) = self.check_for_timed_out_channels() { + error!("Error checking timed out channels {}", e); + } + self.check_for_watched_tx() + } + + #[cfg(not(feature = "async"))] fn channel_checks(&self) -> Result<(), Error> { let established_closing_channels = self .store