Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions msg-socket/src/pub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::{
PubError, PubMessage, PubOptions, SocketState, session::SubscriberSession, trie::PrefixTrie,
};
use crate::{ConnectionHookErased, hooks};
use msg_transport::{Address, PeerAddress, Transport};
use msg_transport::{Address, MeteredIo, PeerAddress, Transport};
use msg_wire::pubsub;

/// The driver for the publisher socket. This is responsible for accepting incoming connections,
Expand All @@ -28,7 +28,7 @@ pub(crate) struct PubDriver<T: Transport<A>, A: Address> {
/// The publisher options (shared with the socket)
pub(super) options: Arc<PubOptions>,
/// The publisher socket state, shared with the socket front-end.
pub(crate) state: Arc<SocketState>,
pub(crate) state: Arc<SocketState<T::Stats>>,
/// Optional connection hook.
pub(super) hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
/// A set of pending incoming connections, represented by [`Transport::Accept`].
Expand Down Expand Up @@ -58,13 +58,15 @@ where
Ok((stream, _addr)) => {
info!("connection hook passed");

let framed = Framed::new(stream, pubsub::Codec::new());
let metered =
MeteredIo::new(stream, Arc::clone(&this.state.transport_stats));
let framed = Framed::new(metered, pubsub::Codec::new());

let session = SubscriberSession {
seq: 0,
session_id: this.id_counter,
from_socket_bcast: this.from_socket_bcast.resubscribe().into(),
state: Arc::clone(&this.state),
stats: this.state.stats.clone(),
pending_egress: None,
conn: framed,
topic_filter: PrefixTrie::new(),
Expand Down Expand Up @@ -158,13 +160,14 @@ where

self.hook_tasks.spawn(fut.with_span(span));
} else {
let framed = Framed::new(io, pubsub::Codec::new());
let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats));
let framed = Framed::new(metered, pubsub::Codec::new());

let session = SubscriberSession {
seq: 0,
session_id: self.id_counter,
from_socket_bcast: self.from_socket_bcast.resubscribe().into(),
state: Arc::clone(&self.state),
stats: self.state.stats.clone(),
pending_egress: None,
conn: framed,
topic_filter: PrefixTrie::new(),
Expand Down
22 changes: 18 additions & 4 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{io, time::Duration};
use std::{io, sync::Arc, time::Duration};

use bytes::Bytes;
use msg_common::constants::KiB;
Expand All @@ -13,6 +13,7 @@ pub use socket::*;

mod stats;
use crate::{Profile, stats::SocketStats};
use arc_swap::ArcSwap;
use stats::PubStats;

mod trie;
Expand Down Expand Up @@ -208,9 +209,22 @@ impl PubMessage {
}

/// The publisher socket state, shared between the backend task and the socket.
#[derive(Debug, Default)]
pub(crate) struct SocketState {
pub(crate) stats: SocketStats<PubStats>,
/// Generic over the transport-level stats type.
#[derive(Debug)]
pub(crate) struct SocketState<S: Default> {
pub(crate) stats: Arc<SocketStats<PubStats>>,
/// The transport-level stats. We wrap the inner stats in an `Arc`
/// for cheap clone on read.
pub(crate) transport_stats: Arc<ArcSwap<S>>,
}

impl<S: Default> Default for SocketState<S> {
fn default() -> Self {
Self {
stats: Arc::new(SocketStats::default()),
transport_stats: Arc::new(ArcSwap::from_pointee(S::default())),
}
}
}

#[cfg(test)]
Expand Down
13 changes: 8 additions & 5 deletions msg-socket/src/pub/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ use tokio_stream::wrappers::BroadcastStream;
use tokio_util::codec::Framed;
use tracing::{debug, error, trace, warn};

use super::{PubMessage, SocketState, trie::PrefixTrie};
use super::{PubMessage, trie::PrefixTrie};
use msg_wire::pubsub;

use super::stats::PubStats;
use crate::stats::SocketStats;

/// A subscriber session. This struct represents a single subscriber session, which is a
/// connection to a subscriber. This struct is responsible for handling incoming and outgoing
/// messages, as well as managing the connection state.
Expand All @@ -26,8 +29,8 @@ pub(super) struct SubscriberSession<Io> {
pub(super) from_socket_bcast: BroadcastStream<PubMessage>,
/// Messages queued to be sent on the connection
pub(super) pending_egress: Option<pubsub::Message>,
/// The socket state, shared between the backend task and the socket.
pub(super) state: Arc<SocketState>,
/// The socket stats.
pub(super) stats: Arc<SocketStats<PubStats>>,
/// The framed connection.
pub(super) conn: Framed<Io, pubsub::Codec>,
/// The topic filter (a prefix trie that works with strings)
Expand Down Expand Up @@ -76,7 +79,7 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> SubscriberSession<Io> {

impl<Io> Drop for SubscriberSession<Io> {
fn drop(&mut self) {
self.state.stats.specific.decrement_active_clients();
self.stats.specific.decrement_active_clients();
}
}

Expand Down Expand Up @@ -130,7 +133,7 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> Future for SubscriberSession<Io> {

match this.conn.start_send_unpin(msg) {
Ok(_) => {
this.state.stats.specific.increment_tx(msg_len);
this.stats.specific.increment_tx(msg_len);
}
Err(e) => {
error!(err = ?e, "Failed to send message to socket");
Expand Down
10 changes: 8 additions & 2 deletions msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{net::SocketAddr, path::PathBuf, sync::Arc};

use arc_swap::Guard;
use bytes::Bytes;
use futures::stream::FuturesUnordered;
use tokio::{
Expand Down Expand Up @@ -28,7 +29,7 @@ pub struct PubSocket<T: Transport<A>, A: Address> {
/// The reply socket options, shared with the driver.
options: Arc<PubOptions>,
/// The reply socket state, shared with the driver.
state: Arc<SocketState>,
state: Arc<SocketState<T::Stats>>,
/// The transport used by this socket. This value is temporary and will be moved
/// to the driver task once the socket is bound.
transport: Option<T>,
Expand Down Expand Up @@ -89,7 +90,7 @@ where
to_sessions_bcast: None,
options: Arc::new(options),
transport: Some(transport),
state: Arc::new(SocketState::default()),
state: Arc::new(SocketState::<T::Stats>::default()),
hook: None,
compressor: None,
}
Expand Down Expand Up @@ -212,6 +213,11 @@ where
&self.state.stats.specific
}

/// Get the latest transport-level stats snapshot.
pub fn transport_stats(&self) -> Guard<Arc<T::Stats>> {
self.state.transport_stats.load()
}

/// Returns the local address this socket is bound to. `None` if the socket is not bound.
pub fn local_addr(&self) -> Option<&A> {
self.local_addr.as_ref()
Expand Down
Loading