diff --git a/pingora-core/src/protocols/l4/socket.rs b/pingora-core/src/protocols/l4/socket.rs index 186fbecb3..f5da48c28 100644 --- a/pingora-core/src/protocols/l4/socket.rs +++ b/pingora-core/src/protocols/l4/socket.rs @@ -174,14 +174,20 @@ impl std::str::FromStr for SocketAddr { type Err = Box; // This is very basic parsing logic, it might treat invalid IP:PORT str as UDS path - // TODO: require UDS to have some prefix fn from_str(s: &str) -> Result { - match StdSockAddr::from_str(s) { - Ok(addr) => Ok(SocketAddr::Inet(addr)), - Err(_) => { - let uds_socket = StdUnixSockAddr::from_pathname(s) - .or_err(crate::BindError, "invalid UDS path")?; - Ok(SocketAddr::Unix(uds_socket)) + if s.starts_with("unix:") { + let path = s.trim_start_matches("unix:"); + let uds_socket = StdUnixSockAddr::from_pathname(path) + .or_err(crate::BindError, "invalid UDS path")?; + Ok(SocketAddr::Unix(uds_socket)) + } else { + match StdSockAddr::from_str(s) { + Ok(addr) => Ok(SocketAddr::Inet(addr)), + Err(_) => { + let uds_socket = StdUnixSockAddr::from_pathname(s) + .or_err(crate::BindError, "invalid UDS path")?; + Ok(SocketAddr::Unix(uds_socket)) + } } } } diff --git a/pingora-ketama/Cargo.toml b/pingora-ketama/Cargo.toml index 1e467b59d..4bd54a122 100644 --- a/pingora-ketama/Cargo.toml +++ b/pingora-ketama/Cargo.toml @@ -11,6 +11,7 @@ keywords = ["hash", "hashing", "consistent", "pingora"] [dependencies] crc32fast = "1.3" +pingora-core = { version = "0.1.0", path = "../pingora-core" } [dev-dependencies] criterion = "0.4" diff --git a/pingora-ketama/examples/health_aware_selector.rs b/pingora-ketama/examples/health_aware_selector.rs index 938c7d703..315a65a52 100644 --- a/pingora-ketama/examples/health_aware_selector.rs +++ b/pingora-ketama/examples/health_aware_selector.rs @@ -1,7 +1,8 @@ use log::info; use pingora_ketama::{Bucket, Continuum}; use std::collections::HashMap; -use std::net::SocketAddr; +// use std::net::SocketAddr; +use pingora_core::protocols::l4::socket::SocketAddr; // A repository for node healthiness, emulating a health checker. struct NodeHealthRepository { @@ -50,7 +51,7 @@ impl<'a> HealthAwareNodeSelector<'a> { } if self.node_health_repo.node_is_healthy(node) { - return Some(*node); + return Some(node.clone()); } } @@ -83,9 +84,19 @@ fn main() { for i in 0..5 { let key = format!("key_{i}"); match health_aware_selector.try_select(&key) { - Some(node) => { - info!("{key}: {}:{}", node.ip(), node.port()); - } + Some(node) => match node { + SocketAddr::Inet(socket_addr) => { + info!("{key}: {}:{}", socket_addr.ip(), socket_addr.port()); + } + SocketAddr::Unix(uds) => { + if let Some(path) = uds.as_pathname() { + let path_str = path.to_string_lossy(); + info!("{key}: {}", path_str); + } else { + info!("{key}: {}", ""); + } + } + }, None => { info!("{key}: no healthy node found!"); } diff --git a/pingora-ketama/src/lib.rs b/pingora-ketama/src/lib.rs index 07877fe3a..98b7c96b1 100644 --- a/pingora-ketama/src/lib.rs +++ b/pingora-ketama/src/lib.rs @@ -26,6 +26,7 @@ //! //! ``` //! use pingora_ketama::{Bucket, Continuum}; +//! use pingora_core::protocols::l4::socket::SocketAddr; //! //! # #[allow(clippy::needless_doctest_main)] //! fn main() { @@ -39,8 +40,11 @@ //! // Let's see what the result is for a few keys: //! for key in &["some_key", "another_key", "last_key"] { //! let node = ring.node(key.as_bytes()).unwrap(); -//! println!("{}: {}:{}", key, node.ip(), node.port()); -//! } +//! match node { +//! SocketAddr::Inet(addr) => println!("{}: {}:{}", key, addr.ip(), addr.port()), +//! _ => panic!("Expected Inet address"), +//! } +//! } //! } //! ``` //! @@ -59,10 +63,11 @@ use std::cmp::Ordering; use std::io::Write; -use std::net::SocketAddr; + use std::usize; use crc32fast::Hasher; +use pingora_core::protocols::l4::socket::SocketAddr; /// A [Bucket] represents a server for consistent hashing /// @@ -70,7 +75,6 @@ use crc32fast::Hasher; #[derive(Clone, Debug, Eq, PartialEq, PartialOrd)] pub struct Bucket { // The node name. - // TODO: UDS node: SocketAddr, // The weight associated with a node. A higher weight indicates that this node should @@ -88,7 +92,6 @@ impl Bucket { /// This will panic if the weight is zero. pub fn new(node: SocketAddr, weight: u32) -> Self { assert!(weight != 0, "weight must be at least one"); - Bucket { node, weight } } } @@ -151,26 +154,41 @@ impl Continuum { for bucket in buckets { let mut hasher = Hasher::new(); - // We only do the following for backwards compatibility with nginx/memcache: - // - Convert SocketAddr to string - // - The hash input is as follows "HOST EMPTY PORT PREVIOUS_HASH". Spaces are only added - // for readability. - // TODO: remove this logic and hash the literal SocketAddr once we no longer - // need backwards compatibility + let hash_input = match &bucket.node { + SocketAddr::Inet(socket_addr) => { + // We only do the following for backwards compatibility with nginx/memcache: + // - Convert SocketAddr to string + // - The hash input is as follows "HOST EMPTY PORT PREVIOUS_HASH". Spaces are only added + // for readability. + // TODO: remove this logic and hash the literal SocketAddr once we no longer + // need backwards compatibility + + // with_capacity = max_len(ipv6)(39) + len(null)(1) + max_len(port)(5) + // 39 for IPv6, 1 for separator, 5 for port + let mut hash_bytes = Vec::with_capacity(39 + 1 + 5); + write!(&mut hash_bytes, "{}", socket_addr.ip()).unwrap(); + write!(&mut hash_bytes, "\0").unwrap(); + write!(&mut hash_bytes, "{}", socket_addr.port()).unwrap(); + hash_bytes + } + SocketAddr::Unix(uds) => { + if let Some(path) = uds.as_pathname() { + let path_str = path.to_string_lossy(); + path_str.into_owned().into_bytes() + } else { + panic!("Unable to handle Unix socket address without a valid path"); + } + } + }; - // with_capacity = max_len(ipv6)(39) + len(null)(1) + max_len(port)(5) - let mut hash_bytes = Vec::with_capacity(39 + 1 + 5); - write!(&mut hash_bytes, "{}", bucket.node.ip()).unwrap(); - write!(&mut hash_bytes, "\0").unwrap(); - write!(&mut hash_bytes, "{}", bucket.node.port()).unwrap(); - hasher.update(hash_bytes.as_ref()); + hasher.update(hash_input.as_ref()); // A higher weight will add more points for this node. let num_points = bucket.weight * POINT_MULTIPLE; // This is appended to the crc32 hash for each point. let mut prev_hash: u32 = 0; - addrs.push(bucket.node); + addrs.push(bucket.node.clone()); let node = addrs.len() - 1; for _ in 0..num_points { let mut hasher = hasher.clone(); @@ -212,10 +230,10 @@ impl Continuum { } /// Hash the given `hash_key` to the server address. - pub fn node(&self, hash_key: &[u8]) -> Option { + pub fn node(&self, hash_key: &[u8]) -> Option<&SocketAddr> { self.ring - .get(self.node_idx(hash_key)) // should we unwrap here? - .map(|p| self.addrs[p.node as usize]) + .get(self.node_idx(hash_key)) + .map(|p| &self.addrs[p.node as usize]) } /// Get an iterator of nodes starting at the original hashed node of the `hash_key`. @@ -231,11 +249,13 @@ impl Continuum { pub fn get_addr(&self, idx: &mut usize) -> Option<&SocketAddr> { let point = self.ring.get(*idx); - if point.is_some() { + if let Some(point) = point { // only update idx for non-empty ring otherwise we will panic on modulo 0 *idx = (*idx + 1) % self.ring.len(); + Some(&self.addrs[point.node as usize]) + } else { + None } - point.map(|p| &self.addrs[p.node as usize]) } } @@ -255,20 +275,35 @@ impl<'a> Iterator for NodeIterator<'a> { #[cfg(test)] mod tests { - use std::net::SocketAddr; + use std::path::Path; - use super::{Bucket, Continuum}; + use super::{Bucket, Continuum, SocketAddr}; fn get_sockaddr(ip: &str) -> SocketAddr { ip.parse().unwrap() } + fn get_uds_addr(path: &str) -> SocketAddr { + path.parse().unwrap() + } + #[test] fn consistency_after_adding_host() { fn assert_hosts(c: &Continuum) { - assert_eq!(c.node(b"a"), Some(get_sockaddr("127.0.0.10:6443"))); - assert_eq!(c.node(b"b"), Some(get_sockaddr("127.0.0.5:6443"))); + match c.node(b"a") { + Some(SocketAddr::Inet(addr)) => { + assert_eq!(*addr, "127.0.0.10:6443".parse().unwrap()) + } + _ => panic!("Expected Inet addr"), + } + + match c.node(b"b") { + Some(SocketAddr::Inet(addr)) => { + assert_eq!(*addr, "127.0.0.5:6443".parse().unwrap()) + } + _ => panic!("Expected Inet addr"), + } } let buckets: Vec<_> = (1..11) @@ -286,6 +321,23 @@ mod tests { assert_hosts(&c); } + #[test] + fn parse_ip() { + let ip: SocketAddr = "127.0.0.1:80".parse().unwrap(); + assert!(ip.as_inet().is_some()); + } + + #[test] + fn test_uds() { + let _ = "unix:/tmp/sock".parse::().unwrap().to_string(); + + assert_eq!( + "unix:/tmp/sock".parse::().unwrap().to_string(), + "/tmp/sock" + ); + assert!("/tmp/sock".parse::().is_ok()); + } + #[test] fn matches_nginx_sample() { let upstream_hosts = ["127.0.0.1:7777", "127.0.0.1:7778"]; @@ -298,23 +350,58 @@ mod tests { let c = Continuum::new(&buckets); - assert_eq!(c.node(b"/some/path"), Some(get_sockaddr("127.0.0.1:7778"))); + assert_eq!(c.node(b"/some/path"), Some(&get_sockaddr("127.0.0.1:7778"))); assert_eq!( c.node(b"/some/longer/path"), - Some(get_sockaddr("127.0.0.1:7777")) + Some(&get_sockaddr("127.0.0.1:7777")) ); assert_eq!( c.node(b"/sad/zaidoon"), - Some(get_sockaddr("127.0.0.1:7778")) + Some(&get_sockaddr("127.0.0.1:7778")) ); - assert_eq!(c.node(b"/g"), Some(get_sockaddr("127.0.0.1:7777"))); + assert_eq!(c.node(b"/g"), Some(&get_sockaddr("127.0.0.1:7777"))); assert_eq!( c.node(b"/pingora/team/is/cool/and/this/is/a/long/uri"), - Some(get_sockaddr("127.0.0.1:7778")) + Some(&get_sockaddr("127.0.0.1:7778")) ); assert_eq!( c.node(b"/i/am/not/confident/in/this/code"), - Some(get_sockaddr("127.0.0.1:7777")) + Some(&get_sockaddr("127.0.0.1:7777")) + ); + } + + #[test] + fn matches_nginx_sample_uds() { + let upstream_hosts = ["unix:/tmp/uds1.sock", "unix:/tmp/uds2.sock"]; + let upstream_hosts = upstream_hosts.iter().map(|&path| get_uds_addr(path)); + + let mut buckets = Vec::new(); + for upstream in upstream_hosts { + buckets.push(Bucket::new(upstream, 1)); + } + + let c = Continuum::new(&buckets); + + // Assuming Continuum::node has been implemented to return the correct nodes + // for UDS paths. These assertions need to be adjusted based on the actual + // hash distribution logic implemented in your Continuum. + assert_eq!(c.node(b"/some/path"), Some(&get_uds_addr("/tmp/uds2.sock"))); + assert_eq!( + c.node(b"/some/longer/path"), + Some(&get_uds_addr("/tmp/uds1.sock")) + ); + assert_eq!( + c.node(b"/sad/zaidoon"), + Some(&get_uds_addr("/tmp/uds2.sock")) + ); + assert_eq!(c.node(b"/g"), Some(&get_uds_addr("/tmp/uds1.sock"))); + assert_eq!( + c.node(b"/pingora/team/is/cool/and/this/is/a/long/uri"), + Some(&get_uds_addr("/tmp/uds2.sock")) + ); + assert_eq!( + c.node(b"/i/am/not/confident/in/this/code/and/this/is/a/long/uri"), + Some(&get_uds_addr("/tmp/uds1.sock")) ); } @@ -355,7 +442,7 @@ mod tests { let upstream = pair.get(1).unwrap(); let got = c.node(uri.as_bytes()).unwrap(); - assert_eq!(got, get_sockaddr(upstream)); + assert_eq!(got, &get_sockaddr(upstream)); } } diff --git a/pingora-load-balancing/src/selection/consistent.rs b/pingora-load-balancing/src/selection/consistent.rs index 9c627260c..ed04fb14d 100644 --- a/pingora-load-balancing/src/selection/consistent.rs +++ b/pingora-load-balancing/src/selection/consistent.rs @@ -32,14 +32,7 @@ impl BackendSelection for KetamaHashing { fn build(backends: &BTreeSet) -> Self { let buckets: Vec<_> = backends .iter() - .filter_map(|b| { - // FIXME: ketama only supports Inet addr, UDS addrs are ignored here - if let SocketAddr::Inet(addr) = b.addr { - Some(Bucket::new(addr, b.weight as u32)) - } else { - None - } - }) + .map(|b| Bucket::new(b.addr.clone(), b.weight as u32)) .collect(); let new_backends = backends .iter() @@ -67,10 +60,10 @@ pub struct OwnedNodeIterator { impl BackendIter for OwnedNodeIterator { fn next(&mut self) -> Option<&Backend> { - self.ring.ring.get_addr(&mut self.idx).and_then(|addr| { - let addr = SocketAddr::Inet(*addr); - self.ring.backends.get(&addr) - }) + self.ring + .ring + .get_addr(&mut self.idx) + .and_then(|addr| self.ring.backends.get(addr)) } }