Skip to content

Commit

Permalink
zhttpsocket: for server mode, add ability to respond via router socket (
Browse files Browse the repository at this point in the history
  • Loading branch information
jkarneges authored Sep 30, 2024
1 parent fe718eb commit 5d54f4f
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 29 deletions.
9 changes: 7 additions & 2 deletions src/connmgr/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,12 @@ impl Worker {
id, count
);

match select_2(pin!(stream_handle.send(msg)), shutdown_timeout.elapsed()).await {
match select_2(
pin!(stream_handle.send(None, msg)),
shutdown_timeout.elapsed(),
)
.await
{
Select2::R1(r) => r.unwrap(),
Select2::R2(_) => break 'outer,
}
Expand Down Expand Up @@ -1252,7 +1257,7 @@ impl Worker {
Select6::R1(_) => break,
// receiver_recv
Select6::R2(result) => match result {
Ok(msg) => handle_send.set(Some(stream_handle.send(msg))),
Ok(msg) => handle_send.set(Some(stream_handle.send(None, msg))),
Err(e) => panic!("zstream_out_receiver channel error: {}", e),
},
// handle_send
Expand Down
135 changes: 108 additions & 27 deletions src/connmgr/zhttpsocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ struct ServerReqPipeEnd {
struct ServerStreamPipeEnd {
sender_any: channel::Sender<(arena::Arc<zmq::Message>, Session)>,
sender_direct: channel::Sender<arena::Arc<zmq::Message>>,
receiver: channel::Receiver<zmq::Message>,
receiver: channel::Receiver<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
}

struct AsyncServerReqPipeEnd {
Expand All @@ -440,7 +440,7 @@ struct AsyncServerReqPipeEnd {
struct AsyncServerStreamPipeEnd {
sender_any: AsyncSender<(arena::Arc<zmq::Message>, Session)>,
sender_direct: AsyncSender<arena::Arc<zmq::Message>>,
receiver: AsyncReceiver<zmq::Message>,
receiver: AsyncReceiver<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
}

enum ServerControlRequest {
Expand Down Expand Up @@ -1064,7 +1064,7 @@ enum StreamHandlesSendError {
struct ServerStreamHandles {
nodes: Slab<list::Node<ServerStreamPipe>>,
list: list::List,
recv_scratch: RefCell<RecvScratch<zmq::Message>>,
recv_scratch: RefCell<RecvScratch<(Option<ArrayVec<u8, 64>>, zmq::Message)>>,
check_send_any_scratch: RefCell<CheckSendScratch<(arena::Arc<zmq::Message>, Session)>>,
send_direct_scratch: RefCell<Vec<bool>>,
need_cleanup: Cell<bool>,
Expand Down Expand Up @@ -1102,7 +1102,7 @@ impl ServerStreamHandles {
}

#[allow(clippy::await_holding_refcell_ref)]
async fn recv(&self) -> zmq::Message {
async fn recv(&self) -> (Option<ArrayVec<u8, 64>>, zmq::Message) {
let mut scratch = self.recv_scratch.borrow_mut();

let (mut tasks, slice_scratch) = scratch.get();
Expand All @@ -1127,7 +1127,7 @@ impl ServerStreamHandles {

loop {
match select_slice(&mut tasks, slice_scratch).await {
(_, (_, Ok(msg))) => return msg,
(_, (_, Ok(ret))) => return ret,
(pos, (nkey, Err(mpsc::RecvError))) => {
tasks.remove(pos);

Expand Down Expand Up @@ -1715,10 +1715,14 @@ impl ClientSocketManager {
}
// stream_handles_recv_addr
Select9::R7((addr, msg)) => {
let h = vec![zmq::Message::from(addr.as_ref())];
let h = vec![zmq::Message::from(addr.as_slice())];

if log_enabled!(log::Level::Trace) {
trace!("OUT stream to {}", packet_to_string(&msg));
trace!(
"OUT stream to={} {}",
String::from_utf8_lossy(addr.as_slice()),
packet_to_string(&msg)
);
}

stream_out_stream_send = Some(client_stream.out_stream.send_to(h, msg));
Expand Down Expand Up @@ -1879,6 +1883,22 @@ impl Drop for ClientSocketManager {
}
}

enum ZmqFuture<'a> {
Send(ZmqSendFuture<'a>),
SendTo(ZmqSendToFuture<'a>),
}

impl Future for ZmqFuture<'_> {
type Output = Result<(), zmq::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match &mut *self {
Self::Send(fut) => Pin::new(fut).poll(cx),
Self::SendTo(fut) => Pin::new(fut).poll(cx),
}
}
}

pub struct ServerSocketManager {
handle_bound: usize,
thread: Option<thread::JoinHandle<()>>,
Expand Down Expand Up @@ -2077,6 +2097,12 @@ impl ServerSocketManager {
.inner()
.set_rcvhwm(init_hwm as i32)
.unwrap();
stream_socks
.in_stream
.inner()
.inner()
.set_sndhwm(other_hwm as i32)
.unwrap();
stream_socks
.in_stream
.inner()
Expand All @@ -2090,6 +2116,21 @@ impl ServerSocketManager {
.set_sndhwm(other_hwm as i32)
.unwrap();

stream_socks
.in_stream
.inner()
.inner()
.set_router_mandatory(true)
.unwrap();

// a ROUTER socket may still be writable after returning EAGAIN, which
// could mean that a different peer than the one we tried to write to
// is writable. there's no way to know when the desired peer will be
// writable, so we'll keep trying again after a delay
stream_socks
.in_stream
.set_retry_timeout(Some(STREAM_OUT_STREAM_DELAY));

stream_socks
.in_stream
.inner()
Expand All @@ -2101,7 +2142,7 @@ impl ServerSocketManager {
let mut stream_handles = ServerStreamHandles::new(HANDLES_MAX, sessions_max);

let mut req_send: Option<ZmqSendToFuture> = None;
let mut stream_out_send: Option<ZmqSendFuture> = None;
let mut stream_out_send: Option<ZmqFuture> = None;

let mut req_in_msg = None;
let mut stream_in_msg = None;
Expand Down Expand Up @@ -2285,12 +2326,27 @@ impl ServerSocketManager {
Err(e) => error!("server stream next zmq recv: {}", e),
},
// stream_handles_recv
Select10::R8(msg) => {
if log_enabled!(log::Level::Trace) {
trace!("OUT server stream {}", packet_to_string(&msg));
}
Select10::R8((addr, msg)) => {
if let Some(addr) = &addr {
let h = vec![zmq::Message::from(addr.as_ref())];

stream_out_send = Some(stream_socks.out.send(msg));
if log_enabled!(log::Level::Trace) {
trace!(
"OUT server stream to={} {}",
String::from_utf8_lossy(addr),
packet_to_string(&msg)
);
}

stream_out_send =
Some(ZmqFuture::SendTo(stream_socks.in_stream.send_to(h, msg)));
} else {
if log_enabled!(log::Level::Trace) {
trace!("OUT server stream {}", packet_to_string(&msg));
}

stream_out_send = Some(ZmqFuture::Send(stream_socks.out.send(msg)));
}
}
// stream_out_send
Select10::R9(result) => {
Expand Down Expand Up @@ -2528,12 +2584,12 @@ impl ClientStreamHandle {
}

pub fn send_to_addr(&self, addr: &[u8], msg: zmq::Message) -> Result<(), SendError> {
let mut a = ArrayVec::new();
if a.try_extend_from_slice(addr).is_err() {
return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput)));
}
let addr = match ArrayVec::try_from(addr) {
Ok(a) => a,
Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))),
};

match self.sender_addr.try_send((a, msg)) {
match self.sender_addr.try_send((addr, msg)) {
Ok(_) => Ok(()),
Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Disconnected(_)) => {
Expand Down Expand Up @@ -2648,7 +2704,7 @@ impl AsyncServerReqHandle {
}

pub struct ServerStreamHandle {
sender: channel::Sender<zmq::Message>,
sender: channel::Sender<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
receiver_any: channel::Receiver<(arena::Arc<zmq::Message>, Session)>,
receiver_direct: channel::Receiver<arena::Arc<zmq::Message>>,
}
Expand Down Expand Up @@ -2686,10 +2742,18 @@ impl ServerStreamHandle {
}
}

pub fn send(&self, msg: zmq::Message) -> Result<(), SendError> {
match self.sender.try_send(msg) {
pub fn send(&self, addr: Option<&[u8]>, msg: zmq::Message) -> Result<(), SendError> {
let addr = match addr {
Some(a) => match ArrayVec::try_from(a) {
Ok(a) => Some(a),
Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))),
},
None => None,
};

match self.sender.try_send((addr, msg)) {
Ok(_) => Ok(()),
Err(mpsc::TrySendError::Full(msg)) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Disconnected(_)) => {
Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe)))
}
Expand All @@ -2698,7 +2762,7 @@ impl ServerStreamHandle {
}

pub struct AsyncServerStreamHandle {
sender: AsyncSender<zmq::Message>,
sender: AsyncSender<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
receiver_any: AsyncReceiver<(arena::Arc<zmq::Message>, Session)>,
receiver_direct: AsyncReceiver<arena::Arc<zmq::Message>>,
}
Expand Down Expand Up @@ -2726,8 +2790,12 @@ impl AsyncServerStreamHandle {
}
}

pub async fn send(&self, msg: zmq::Message) -> Result<(), io::Error> {
match self.sender.send(msg).await {
pub async fn send(
&self,
addr: Option<ArrayVec<u8, 64>>,
msg: zmq::Message,
) -> Result<(), io::Error> {
match self.sender.send((addr, msg)).await {
Ok(_) => Ok(()),
Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
Expand Down Expand Up @@ -3347,6 +3415,7 @@ mod tests {
out_sock.connect("inproc://test-server-in").unwrap();

let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap();
out_stream_sock.set_identity(b"test-handler").unwrap();
out_stream_sock
.connect("inproc://test-server-in-stream")
.unwrap();
Expand Down Expand Up @@ -3400,13 +3469,25 @@ mod tests {
};
assert_eq!(rdata.body, b"hello");

h1.send(zmq::Message::from("test-handler world a".as_bytes()))
h1.send(None, zmq::Message::from("test-handler world a".as_bytes()))
.unwrap();

let parts = in_sock.recv_multipart(0).unwrap();
assert_eq!(parts.len(), 1);
assert_eq!(parts[0], b"test-handler world a");

// send via router
h1.send(
Some(b"test-handler"),
zmq::Message::from("world a2".as_bytes()),
)
.unwrap();

let parts = out_stream_sock.recv_multipart(0).unwrap();
assert_eq!(parts.len(), 3);
assert!(parts[1].is_empty());
assert_eq!(parts[2], b"world a2");

let req = {
let mut rdata = RequestData::new();
rdata.body = b"hello";
Expand Down Expand Up @@ -3449,7 +3530,7 @@ mod tests {
};
assert_eq!(rdata.body, b"hello");

h2.send(zmq::Message::from("test-handler world b".as_bytes()))
h2.send(None, zmq::Message::from("test-handler world b".as_bytes()))
.unwrap();

let parts = in_sock.recv_multipart(0).unwrap();
Expand Down

0 comments on commit 5d54f4f

Please sign in to comment.