diff --git a/msg-socket/src/pub/driver.rs b/msg-socket/src/pub/driver.rs index 6d4a9ca..e0cba67 100644 --- a/msg-socket/src/pub/driver.rs +++ b/msg-socket/src/pub/driver.rs @@ -14,7 +14,7 @@ use super::{ PubError, PubMessage, PubOptions, SocketState, session::SubscriberSession, trie::PrefixTrie, }; use crate::{ConnectionHookErased, hooks}; -use msg_transport::{Address, PeerAddress, Transport}; +use msg_transport::{Address, MeteredIo, PeerAddress, Transport}; use msg_wire::pubsub; /// The driver for the publisher socket. This is responsible for accepting incoming connections, @@ -28,7 +28,7 @@ pub(crate) struct PubDriver, A: Address> { /// The publisher options (shared with the socket) pub(super) options: Arc, /// The publisher socket state, shared with the socket front-end. - pub(crate) state: Arc, + pub(crate) state: Arc>, /// Optional connection hook. pub(super) hook: Option>>, /// A set of pending incoming connections, represented by [`Transport::Accept`]. @@ -58,13 +58,15 @@ where Ok((stream, _addr)) => { info!("connection hook passed"); - let framed = Framed::new(stream, pubsub::Codec::new()); + let metered = + MeteredIo::new(stream, Arc::clone(&this.state.transport_stats)); + let framed = Framed::new(metered, pubsub::Codec::new()); let session = SubscriberSession { seq: 0, session_id: this.id_counter, from_socket_bcast: this.from_socket_bcast.resubscribe().into(), - state: Arc::clone(&this.state), + stats: this.state.stats.clone(), pending_egress: None, conn: framed, topic_filter: PrefixTrie::new(), @@ -158,13 +160,14 @@ where self.hook_tasks.spawn(fut.with_span(span)); } else { - let framed = Framed::new(io, pubsub::Codec::new()); + let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats)); + let framed = Framed::new(metered, pubsub::Codec::new()); let session = SubscriberSession { seq: 0, session_id: self.id_counter, from_socket_bcast: self.from_socket_bcast.resubscribe().into(), - state: Arc::clone(&self.state), + stats: self.state.stats.clone(), pending_egress: None, conn: framed, topic_filter: PrefixTrie::new(), diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index 40f727d..d8c66fd 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -1,4 +1,4 @@ -use std::{io, time::Duration}; +use std::{io, sync::Arc, time::Duration}; use bytes::Bytes; use msg_common::constants::KiB; @@ -13,6 +13,7 @@ pub use socket::*; mod stats; use crate::{Profile, stats::SocketStats}; +use arc_swap::ArcSwap; use stats::PubStats; mod trie; @@ -208,9 +209,22 @@ impl PubMessage { } /// The publisher socket state, shared between the backend task and the socket. -#[derive(Debug, Default)] -pub(crate) struct SocketState { - pub(crate) stats: SocketStats, +/// Generic over the transport-level stats type. +#[derive(Debug)] +pub(crate) struct SocketState { + pub(crate) stats: Arc>, + /// The transport-level stats. We wrap the inner stats in an `Arc` + /// for cheap clone on read. + pub(crate) transport_stats: Arc>, +} + +impl Default for SocketState { + fn default() -> Self { + Self { + stats: Arc::new(SocketStats::default()), + transport_stats: Arc::new(ArcSwap::from_pointee(S::default())), + } + } } #[cfg(test)] diff --git a/msg-socket/src/pub/session.rs b/msg-socket/src/pub/session.rs index 64d230a..aacad00 100644 --- a/msg-socket/src/pub/session.rs +++ b/msg-socket/src/pub/session.rs @@ -11,9 +11,12 @@ use tokio_stream::wrappers::BroadcastStream; use tokio_util::codec::Framed; use tracing::{debug, error, trace, warn}; -use super::{PubMessage, SocketState, trie::PrefixTrie}; +use super::{PubMessage, trie::PrefixTrie}; use msg_wire::pubsub; +use super::stats::PubStats; +use crate::stats::SocketStats; + /// A subscriber session. This struct represents a single subscriber session, which is a /// connection to a subscriber. This struct is responsible for handling incoming and outgoing /// messages, as well as managing the connection state. @@ -26,8 +29,8 @@ pub(super) struct SubscriberSession { pub(super) from_socket_bcast: BroadcastStream, /// Messages queued to be sent on the connection pub(super) pending_egress: Option, - /// The socket state, shared between the backend task and the socket. - pub(super) state: Arc, + /// The socket stats. + pub(super) stats: Arc>, /// The framed connection. pub(super) conn: Framed, /// The topic filter (a prefix trie that works with strings) @@ -76,7 +79,7 @@ impl SubscriberSession { impl Drop for SubscriberSession { fn drop(&mut self) { - self.state.stats.specific.decrement_active_clients(); + self.stats.specific.decrement_active_clients(); } } @@ -130,7 +133,7 @@ impl Future for SubscriberSession { match this.conn.start_send_unpin(msg) { Ok(_) => { - this.state.stats.specific.increment_tx(msg_len); + this.stats.specific.increment_tx(msg_len); } Err(e) => { error!(err = ?e, "Failed to send message to socket"); diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 5982943..55c178b 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -1,5 +1,6 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; +use arc_swap::Guard; use bytes::Bytes; use futures::stream::FuturesUnordered; use tokio::{ @@ -28,7 +29,7 @@ pub struct PubSocket, A: Address> { /// The reply socket options, shared with the driver. options: Arc, /// The reply socket state, shared with the driver. - state: Arc, + state: Arc>, /// The transport used by this socket. This value is temporary and will be moved /// to the driver task once the socket is bound. transport: Option, @@ -89,7 +90,7 @@ where to_sessions_bcast: None, options: Arc::new(options), transport: Some(transport), - state: Arc::new(SocketState::default()), + state: Arc::new(SocketState::::default()), hook: None, compressor: None, } @@ -212,6 +213,11 @@ where &self.state.stats.specific } + /// Get the latest transport-level stats snapshot. + pub fn transport_stats(&self) -> Guard> { + self.state.transport_stats.load() + } + /// Returns the local address this socket is bound to. `None` if the socket is not bound. pub fn local_addr(&self) -> Option<&A> { self.local_addr.as_ref()