Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zhttpsocket: for server mode, add ability to respond via router socket #48073

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/connmgr/client.rs
Original file line number Diff line number Diff line change
@@ -944,7 +944,12 @@ impl Worker {
id, count
);

match select_2(pin!(stream_handle.send(msg)), shutdown_timeout.elapsed()).await {
match select_2(
pin!(stream_handle.send(None, msg)),
shutdown_timeout.elapsed(),
)
.await
{
Select2::R1(r) => r.unwrap(),
Select2::R2(_) => break 'outer,
}
@@ -1238,7 +1243,7 @@ impl Worker {
Select6::R1(_) => break,
// receiver_recv
Select6::R2(result) => match result {
Ok(msg) => handle_send.set(Some(stream_handle.send(msg))),
Ok(msg) => handle_send.set(Some(stream_handle.send(None, msg))),
Err(e) => panic!("zstream_out_receiver channel error: {}", e),
},
// handle_send
135 changes: 108 additions & 27 deletions src/connmgr/zhttpsocket.rs
Original file line number Diff line number Diff line change
@@ -429,7 +429,7 @@ struct ServerReqPipeEnd {
struct ServerStreamPipeEnd {
sender_any: channel::Sender<(arena::Arc<zmq::Message>, Session)>,
sender_direct: channel::Sender<arena::Arc<zmq::Message>>,
receiver: channel::Receiver<zmq::Message>,
receiver: channel::Receiver<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
}

struct AsyncServerReqPipeEnd {
@@ -440,7 +440,7 @@ struct AsyncServerReqPipeEnd {
struct AsyncServerStreamPipeEnd {
sender_any: AsyncSender<(arena::Arc<zmq::Message>, Session)>,
sender_direct: AsyncSender<arena::Arc<zmq::Message>>,
receiver: AsyncReceiver<zmq::Message>,
receiver: AsyncReceiver<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
}

enum ServerControlRequest {
@@ -1064,7 +1064,7 @@ enum StreamHandlesSendError {
struct ServerStreamHandles {
nodes: Slab<list::Node<ServerStreamPipe>>,
list: list::List,
recv_scratch: RefCell<RecvScratch<zmq::Message>>,
recv_scratch: RefCell<RecvScratch<(Option<ArrayVec<u8, 64>>, zmq::Message)>>,
check_send_any_scratch: RefCell<CheckSendScratch<(arena::Arc<zmq::Message>, Session)>>,
send_direct_scratch: RefCell<Vec<bool>>,
need_cleanup: Cell<bool>,
@@ -1102,7 +1102,7 @@ impl ServerStreamHandles {
}

#[allow(clippy::await_holding_refcell_ref)]
async fn recv(&self) -> zmq::Message {
async fn recv(&self) -> (Option<ArrayVec<u8, 64>>, zmq::Message) {
let mut scratch = self.recv_scratch.borrow_mut();

let (mut tasks, slice_scratch) = scratch.get();
@@ -1127,7 +1127,7 @@ impl ServerStreamHandles {

loop {
match select_slice(&mut tasks, slice_scratch).await {
(_, (_, Ok(msg))) => return msg,
(_, (_, Ok(ret))) => return ret,
(pos, (nkey, Err(mpsc::RecvError))) => {
tasks.remove(pos);

@@ -1715,10 +1715,14 @@ impl ClientSocketManager {
}
// stream_handles_recv_addr
Select9::R7((addr, msg)) => {
let h = vec![zmq::Message::from(addr.as_ref())];
let h = vec![zmq::Message::from(addr.as_slice())];

if log_enabled!(log::Level::Trace) {
trace!("OUT stream to {}", packet_to_string(&msg));
trace!(
"OUT stream to={} {}",
String::from_utf8_lossy(addr.as_slice()),
packet_to_string(&msg)
);
}

stream_out_stream_send = Some(client_stream.out_stream.send_to(h, msg));
@@ -1879,6 +1883,22 @@ impl Drop for ClientSocketManager {
}
}

enum ZmqFuture<'a> {
Send(ZmqSendFuture<'a>),
SendTo(ZmqSendToFuture<'a>),
}

impl Future for ZmqFuture<'_> {
type Output = Result<(), zmq::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match &mut *self {
Self::Send(fut) => Pin::new(fut).poll(cx),
Self::SendTo(fut) => Pin::new(fut).poll(cx),
}
}
}

pub struct ServerSocketManager {
handle_bound: usize,
thread: Option<thread::JoinHandle<()>>,
@@ -2077,6 +2097,12 @@ impl ServerSocketManager {
.inner()
.set_rcvhwm(init_hwm as i32)
.unwrap();
stream_socks
.in_stream
.inner()
.inner()
.set_sndhwm(other_hwm as i32)
.unwrap();
stream_socks
.in_stream
.inner()
@@ -2090,6 +2116,21 @@ impl ServerSocketManager {
.set_sndhwm(other_hwm as i32)
.unwrap();

stream_socks
.in_stream
.inner()
.inner()
.set_router_mandatory(true)
.unwrap();

// a ROUTER socket may still be writable after returning EAGAIN, which
// could mean that a different peer than the one we tried to write to
// is writable. there's no way to know when the desired peer will be
// writable, so we'll keep trying again after a delay
stream_socks
.in_stream
.set_retry_timeout(Some(STREAM_OUT_STREAM_DELAY));

stream_socks
.in_stream
.inner()
@@ -2101,7 +2142,7 @@ impl ServerSocketManager {
let mut stream_handles = ServerStreamHandles::new(HANDLES_MAX, sessions_max);

let mut req_send: Option<ZmqSendToFuture> = None;
let mut stream_out_send: Option<ZmqSendFuture> = None;
let mut stream_out_send: Option<ZmqFuture> = None;

let mut req_in_msg = None;
let mut stream_in_msg = None;
@@ -2285,12 +2326,27 @@ impl ServerSocketManager {
Err(e) => error!("server stream next zmq recv: {}", e),
},
// stream_handles_recv
Select10::R8(msg) => {
if log_enabled!(log::Level::Trace) {
trace!("OUT server stream {}", packet_to_string(&msg));
}
Select10::R8((addr, msg)) => {
if let Some(addr) = &addr {
let h = vec![zmq::Message::from(addr.as_ref())];

stream_out_send = Some(stream_socks.out.send(msg));
if log_enabled!(log::Level::Trace) {
trace!(
"OUT server stream to={} {}",
String::from_utf8_lossy(addr),
packet_to_string(&msg)
);
}

stream_out_send =
Some(ZmqFuture::SendTo(stream_socks.in_stream.send_to(h, msg)));
} else {
if log_enabled!(log::Level::Trace) {
trace!("OUT server stream {}", packet_to_string(&msg));
}

stream_out_send = Some(ZmqFuture::Send(stream_socks.out.send(msg)));
}
}
// stream_out_send
Select10::R9(result) => {
@@ -2528,12 +2584,12 @@ impl ClientStreamHandle {
}

pub fn send_to_addr(&self, addr: &[u8], msg: zmq::Message) -> Result<(), SendError> {
let mut a = ArrayVec::new();
if a.try_extend_from_slice(addr).is_err() {
return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput)));
}
let addr = match ArrayVec::try_from(addr) {
Ok(a) => a,
Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))),
};

match self.sender_addr.try_send((a, msg)) {
match self.sender_addr.try_send((addr, msg)) {
Ok(_) => Ok(()),
Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Disconnected(_)) => {
@@ -2648,7 +2704,7 @@ impl AsyncServerReqHandle {
}

pub struct ServerStreamHandle {
sender: channel::Sender<zmq::Message>,
sender: channel::Sender<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
receiver_any: channel::Receiver<(arena::Arc<zmq::Message>, Session)>,
receiver_direct: channel::Receiver<arena::Arc<zmq::Message>>,
}
@@ -2686,10 +2742,18 @@ impl ServerStreamHandle {
}
}

pub fn send(&self, msg: zmq::Message) -> Result<(), SendError> {
match self.sender.try_send(msg) {
pub fn send(&self, addr: Option<&[u8]>, msg: zmq::Message) -> Result<(), SendError> {
let addr = match addr {
Some(a) => match ArrayVec::try_from(a) {
Ok(a) => Some(a),
Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))),
},
None => None,
};

match self.sender.try_send((addr, msg)) {
Ok(_) => Ok(()),
Err(mpsc::TrySendError::Full(msg)) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)),
Err(mpsc::TrySendError::Disconnected(_)) => {
Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe)))
}
@@ -2698,7 +2762,7 @@ impl ServerStreamHandle {
}

pub struct AsyncServerStreamHandle {
sender: AsyncSender<zmq::Message>,
sender: AsyncSender<(Option<ArrayVec<u8, 64>>, zmq::Message)>,
receiver_any: AsyncReceiver<(arena::Arc<zmq::Message>, Session)>,
receiver_direct: AsyncReceiver<arena::Arc<zmq::Message>>,
}
@@ -2726,8 +2790,12 @@ impl AsyncServerStreamHandle {
}
}

pub async fn send(&self, msg: zmq::Message) -> Result<(), io::Error> {
match self.sender.send(msg).await {
pub async fn send(
&self,
addr: Option<ArrayVec<u8, 64>>,
msg: zmq::Message,
) -> Result<(), io::Error> {
match self.sender.send((addr, msg)).await {
Ok(_) => Ok(()),
Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
@@ -3347,6 +3415,7 @@ mod tests {
out_sock.connect("inproc://test-server-in").unwrap();

let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap();
out_stream_sock.set_identity(b"test-handler").unwrap();
out_stream_sock
.connect("inproc://test-server-in-stream")
.unwrap();
@@ -3400,13 +3469,25 @@ mod tests {
};
assert_eq!(rdata.body, b"hello");

h1.send(zmq::Message::from("test-handler world a".as_bytes()))
h1.send(None, zmq::Message::from("test-handler world a".as_bytes()))
.unwrap();

let parts = in_sock.recv_multipart(0).unwrap();
assert_eq!(parts.len(), 1);
assert_eq!(parts[0], b"test-handler world a");

// send via router
h1.send(
Some(b"test-handler"),
zmq::Message::from("world a2".as_bytes()),
)
.unwrap();

let parts = out_stream_sock.recv_multipart(0).unwrap();
assert_eq!(parts.len(), 3);
assert!(parts[1].is_empty());
assert_eq!(parts[2], b"world a2");

let req = {
let mut rdata = RequestData::new();
rdata.body = b"hello";
@@ -3449,7 +3530,7 @@ mod tests {
};
assert_eq!(rdata.body, b"hello");

h2.send(zmq::Message::from("test-handler world b".as_bytes()))
h2.send(None, zmq::Message::from("test-handler world b".as_bytes()))
.unwrap();

let parts = in_sock.recv_multipart(0).unwrap();