Skip to content

Commit

Permalink
connmgr: add ability to respond via router instead of pub
Browse files Browse the repository at this point in the history
  • Loading branch information
jkarneges committed Oct 2, 2024
1 parent 5d60a06 commit 61d214c
Show file tree
Hide file tree
Showing 2 changed files with 405 additions and 80 deletions.
204 changes: 141 additions & 63 deletions src/connmgr/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

use crate::connmgr::batch::{Batch, BatchKey};
use crate::connmgr::connection::{
client_req_connection, client_stream_connection, ConnectionPool, StreamSharedData,
client_req_connection, client_stream_connection, make_zhttp_response, ConnectionPool,
StreamSharedData,
};
use crate::connmgr::counter::Counter;
use crate::connmgr::resolver::Resolver;
Expand Down Expand Up @@ -381,14 +382,18 @@ impl Connections {
None => return Err(()),
};

let bkey = items.batch.add(addr, false, ckey)?;
let bkey = items.batch.add(addr, cshared.router_resp(), ckey)?;

ci.batch_key = Some(bkey);

Ok(())
}

fn next_batch_message(&self, from: &str, btype: BatchType) -> Option<(usize, zmq::Message)> {
fn next_batch_message(
&self,
from: &str,
btype: BatchType,
) -> Option<(usize, Option<ArrayVec<u8, 64>>, zmq::Message)> {
let items = &mut *self.items.borrow_mut();
let nodes = &mut items.nodes;
let batch = &mut items.batch;
Expand Down Expand Up @@ -420,46 +425,30 @@ impl Connections {

assert!(count <= zhttppacket::IDS_MAX);

let zreq = zhttppacket::Request {
let zresp = zhttppacket::Response {
from: from.as_bytes(),
ids: group.ids(),
multi: true,
ptype: match btype {
BatchType::KeepAlive => zhttppacket::RequestPacket::KeepAlive,
BatchType::Cancel => zhttppacket::RequestPacket::Cancel,
BatchType::KeepAlive => zhttppacket::ResponsePacket::KeepAlive,
BatchType::Cancel => zhttppacket::ResponsePacket::Cancel,
},
ptype_str: "",
};

let mut data = [0; BULK_PACKET_SIZE_MAX];

let size = match zreq.serialize(&mut data) {
Ok(size) => size,
Err(e) => {
error!(
"failed to serialize keep-alive packet with {} ids: {}",
zreq.ids.len(),
e
);
continue;
}
};

let data = &data[..size];
let mut scratch = [0; BULK_PACKET_SIZE_MAX];

let addr = group.addr();

let msg = {
let mut v = vec![0; addr.len() + 1 + data.len()];

v[..addr.len()].copy_from_slice(addr);
v[addr.len()] = b' ';
let pos = addr.len() + 1;
v[pos..(pos + data.len())].copy_from_slice(data);

// this takes over the vec's memory without copying
zmq::Message::from(v)
};
let (addr, msg) =
match make_zhttp_response(group.addr(), group.use_router(), zresp, &mut scratch) {
Ok(resp) => resp,
Err(e) => {
error!(
"failed to serialize keep-alive packet with {} ids: {}",
count, e
);
continue;
}
};

drop(group);

Expand All @@ -471,7 +460,7 @@ impl Connections {
ci.batch_key = None;
}

return Some((count, msg));
return Some((count, addr, msg));
}

None
Expand Down Expand Up @@ -791,7 +780,7 @@ impl Worker {
}
}

while let Some((count, msg)) =
while let Some((count, addr, msg)) =
stream_conns.next_batch_message(&instance_id, BatchType::Cancel)
{
debug!(
Expand All @@ -800,7 +789,7 @@ impl Worker {
);

match select_2(
pin!(stream_handle.send(None, msg)),
pin!(stream_handle.send(addr, msg)),
shutdown_timeout.elapsed(),
)
.await
Expand Down Expand Up @@ -1541,7 +1530,7 @@ impl Worker {
};

// there could be no message if items removed or message construction failed
let (count, msg) =
let (count, addr, msg) =
match conns.next_batch_message(&instance_id, BatchType::KeepAlive) {
Some(ret) => ret,
None => continue,
Expand All @@ -1552,7 +1541,7 @@ impl Worker {
id, count
);

if let Err(e) = send.try_send((None, msg)) {
if let Err(e) = send.try_send((addr, msg)) {
error!("zhttp write error: {}", e);
}
}
Expand Down Expand Up @@ -1925,13 +1914,19 @@ impl TestClient {
}

pub fn do_stream_http(&self, addr: std::net::SocketAddr) {
let msg = self.make_stream_message(addr, false).unwrap();
let msg = self.make_stream_message(addr, false, false).unwrap();

self.control.send(ControlMessage::Stream(msg)).unwrap();
}

pub fn do_stream_http_router_resp(&self, addr: std::net::SocketAddr) {
let msg = self.make_stream_message(addr, true, false).unwrap();

self.control.send(ControlMessage::Stream(msg)).unwrap();
}

pub fn do_stream_ws(&self, addr: std::net::SocketAddr) {
let msg = self.make_stream_message(addr, true).unwrap();
let msg = self.make_stream_message(addr, false, true).unwrap();

self.control.send(ControlMessage::Stream(msg)).unwrap();
}
Expand Down Expand Up @@ -1999,6 +1994,7 @@ impl TestClient {
fn make_stream_message(
&self,
addr: std::net::SocketAddr,
router_resp: bool,
ws: bool,
) -> Result<zmq::Message, io::Error> {
let mut dest = [0; 1024];
Expand Down Expand Up @@ -2058,6 +2054,11 @@ impl TestClient {
w.write_string(b"credits")?;
w.write_int(1024)?;

if router_resp {
w.write_string(b"router-resp")?;
w.write_bool(true)?;
}

w.end_map()?;

w.flush()?;
Expand Down Expand Up @@ -2131,6 +2132,7 @@ impl TestClient {
out_sock.connect("inproc://client-test-out").unwrap();

let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap();
out_stream_sock.set_identity(b"handler").unwrap();
out_stream_sock
.connect("inproc://client-test-out-stream")
.unwrap();
Expand Down Expand Up @@ -2164,14 +2166,22 @@ impl TestClient {

poller
.register(
&mut SourceFd(&in_sock.get_fd().unwrap()),
&mut SourceFd(&out_stream_sock.get_fd().unwrap()),
mio::Token(3),
mio::Interest::READABLE,
)
.unwrap();

let mut req_events = req_sock.get_events().unwrap();
poller
.register(
&mut SourceFd(&in_sock.get_fd().unwrap()),
mio::Token(4),
mio::Interest::READABLE,
)
.unwrap();

let mut req_events = req_sock.get_events().unwrap();
let mut out_stream_events = out_stream_sock.get_events().unwrap();
let mut in_events = in_sock.get_events().unwrap();

'main: loop {
Expand Down Expand Up @@ -2221,6 +2231,8 @@ impl TestClient {
}
}

debug!("received req message");

assert_eq!(ptype, "");
assert_eq!(code, 200);
assert_eq!(reason, "OK");
Expand All @@ -2229,32 +2241,55 @@ impl TestClient {
status.send(StatusMessage::ReqFinished).unwrap();
}

while in_events.contains(zmq::POLLIN) {
let parts = match in_sock.recv_multipart(zmq::DONTWAIT) {
Ok(parts) => parts,
Err(zmq::Error::EAGAIN) => {
in_events = in_sock.get_events().unwrap();
break;
while out_stream_events.contains(zmq::POLLIN) || in_events.contains(zmq::POLLIN) {
let mut msg_and_pos = None;

if out_stream_events.contains(zmq::POLLIN) {
match out_stream_sock.recv_multipart(zmq::DONTWAIT) {
Ok(mut parts) => {
out_stream_events = out_stream_sock.get_events().unwrap();

assert_eq!(parts.len(), 3);

msg_and_pos = Some((parts.remove(2), 0));
}
Err(zmq::Error::EAGAIN) => {
out_stream_events = out_stream_sock.get_events().unwrap();
}
Err(e) => panic!("recv error: {:?}", e),
}
Err(e) => panic!("recv error: {:?}", e),
};
}

in_events = in_sock.get_events().unwrap();
if msg_and_pos.is_none() && in_events.contains(zmq::POLLIN) {
match in_sock.recv_multipart(zmq::DONTWAIT) {
Ok(mut parts) => {
in_events = in_sock.get_events().unwrap();

assert_eq!(parts.len(), 1);
assert_eq!(parts.len(), 1);

let buf = &parts[0];
let buf = &parts[0];

let mut pos = None;
for (i, b) in buf.iter().enumerate() {
if *b == b' ' {
pos = Some(i);
break;
}
let mut pos = None;
for (i, b) in buf.iter().enumerate() {
if *b == b' ' {
pos = Some(i);
break;
}
}

msg_and_pos = Some((parts.remove(0), pos.unwrap() + 1));
}
Err(zmq::Error::EAGAIN) => {
in_events = in_sock.get_events().unwrap();
}
Err(e) => panic!("recv error: {:?}", e),
};
}

let pos = pos.unwrap();
let msg = &buf[(pos + 1)..];
let (msg, from_router) = match &msg_and_pos {
Some((msg, pos)) => (&msg[*pos..], *pos == 0),
None => break,
};

assert_eq!(msg[0], b'T');

Expand Down Expand Up @@ -2309,6 +2344,11 @@ impl TestClient {

let seq = seq.unwrap() + 1;

debug!(
"received stream message from_router={} id={} seq={}",
from_router, id, seq
);

// as a hack to make the test server stateless, respond to every message
// using the received sequence number. for messages we don't care about,
// respond with keep-alive in order to keep the sequencing going
Expand Down Expand Up @@ -2415,7 +2455,8 @@ impl TestClient {
}
}
mio::Token(2) => req_events = req_sock.get_events().unwrap(),
mio::Token(3) => in_events = in_sock.get_events().unwrap(),
mio::Token(3) => out_stream_events = out_stream_sock.get_events().unwrap(),
mio::Token(4) => in_events = in_sock.get_events().unwrap(),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -2566,6 +2607,43 @@ pub mod tests {

client.wait_stream();

// stream (http) with responses via router

client.do_stream_http_router_resp(addr);
let (mut stream, _) = listener.accept().unwrap();

let mut buf = Vec::new();
let mut req_end = 0;

while req_end == 0 {
let mut chunk = [0; 1024];
let size = stream.read(&mut chunk).unwrap();
buf.extend_from_slice(&chunk[..size]);

for i in 0..(buf.len() - 3) {
if &buf[i..(i + 4)] == b"\r\n\r\n" {
req_end = i + 4;
break;
}
}
}

let expected = format!(
concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"),
addr
);

assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected);

stream
.write(
b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n",
)
.unwrap();
drop(stream);

client.wait_stream();

// stream (ws)

client.do_stream_ws(addr);
Expand Down
Loading

0 comments on commit 61d214c

Please sign in to comment.