From 51625dd98ef043be73c96e4ad3ff50128d0281dc Mon Sep 17 00:00:00 2001 From: Paul Horn <dev@knutwalker.engineer> Date: Wed, 31 Jul 2024 18:54:09 +0200 Subject: [PATCH 1/4] Refactor Connection constructor methods --- lib/src/connection.rs | 115 ++++++++++++++++-------------------------- lib/src/pool.rs | 2 +- 2 files changed, 44 insertions(+), 73 deletions(-) diff --git a/lib/src/connection.rs b/lib/src/connection.rs index eaeaada..10bc8e0 100644 --- a/lib/src/connection.rs +++ b/lib/src/connection.rs @@ -2,6 +2,7 @@ use crate::bolt::{ExpectedResponse, Message, MessageResponse}; use crate::{ auth::ClientCertificate, + connection::stream::ConnectionStream, errors::{Error, Result}, messages::{BoltRequest, BoltResponse, HelloBuilder}, version::Version, @@ -12,17 +13,13 @@ use log::warn; use std::fs::File; use std::io::BufReader; use std::{mem, sync::Arc}; -use stream::ConnectionStream; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufStream}, net::TcpStream, }; -use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName}; use tokio_rustls::{ - rustls::{ - pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName}, - ClientConfig, RootCertStore, - }, + rustls::{ClientConfig, RootCertStore}, TlsConnector, }; use url::{Host, Url}; @@ -36,64 +33,42 @@ pub struct Connection { } impl Connection { - pub(crate) async fn new(info: &ConnectionInfo) -> Result<Connection> { + pub(crate) fn new( + info: &ConnectionInfo, + ) -> Result<impl std::future::Future<Output = Result<Connection>>> { let mut hello_builder = HelloBuilder::new(&*info.user, &*info.password); if let Routing::Yes(routing) = &info.routing { hello_builder.with_routing(routing.clone()); }; - - let stream = match &info.host { - Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?, - Host::Ipv4(ip) => TcpStream::connect((*ip, info.port)).await?, - Host::Ipv6(ip) => TcpStream::connect((*ip, info.port)).await?, + let host = info.host.clone(); + let port = info.port; + let encryption_connector = match info.encryption { + Encryption::No => None, + Encryption::Tls => Some(Self::tls_connector( + &info.host, + info.client_certificate.as_ref(), + )?), }; - match info.encryption { - Encryption::No => Self::new_unencrypted(stream, hello_builder).await, - Encryption::Tls => { - if let Some(certificate) = info.client_certificate.as_ref() { - Self::new_tls_with_certificate(stream, &info.host, hello_builder, certificate) - .await - } else { - Self::new_tls(stream, &info.host, hello_builder).await - } - } - } - } - - async fn new_unencrypted(stream: TcpStream, hello_builder: HelloBuilder) -> Result<Connection> { - Self::init(hello_builder, stream).await - } - - async fn new_tls<T: AsRef<str>>( - stream: TcpStream, - host: &Host<T>, - hello_builder: HelloBuilder, - ) -> Result<Connection> { - let root_cert_store = Self::build_cert_store(); - let stream = Self::build_stream(stream, host, root_cert_store).await?; - - Self::init(hello_builder, stream).await + Ok(async move { + let stream = match host { + Host::Domain(domain) => TcpStream::connect((&*domain, port)).await?, + Host::Ipv4(ip) => TcpStream::connect((ip, port)).await?, + Host::Ipv6(ip) => TcpStream::connect((ip, port)).await?, + }; + + let stream: ConnectionStream = match encryption_connector { + Some((connector, domain)) => connector.connect(domain, stream).await?.into(), + None => stream.into(), + }; + Self::init(hello_builder, stream).await + }) } - async fn new_tls_with_certificate<T: AsRef<str>>( - stream: TcpStream, + fn tls_connector<T: AsRef<str>>( host: &Host<T>, - hello_builder: HelloBuilder, - certificate: &ClientCertificate, - ) -> Result<Connection> { - let mut root_cert_store = Self::build_cert_store(); - - let cert_file = File::open(certificate.cert_file.as_os_str())?; - let mut reader = BufReader::new(cert_file); - let certs = rustls_pemfile::certs(&mut reader).flatten(); - root_cert_store.add_parsable_certificates(certs); - - let stream = Self::build_stream(stream, host, root_cert_store).await?; - Self::init(hello_builder, stream).await - } - - fn build_cert_store() -> RootCertStore { + certificate: Option<&ClientCertificate>, + ) -> Result<(TlsConnector, ServerName<'static>)> { let mut root_cert_store = RootCertStore::empty(); match rustls_native_certs::load_native_certs() { Ok(certs) => { @@ -103,14 +78,14 @@ impl Connection { warn!("Failed to load native certificates: {e}"); } } - root_cert_store - } - async fn build_stream<T: AsRef<str>>( - stream: TcpStream, - host: &Host<T>, - root_cert_store: RootCertStore, - ) -> Result<TlsStream<TcpStream>, Error> { + if let Some(certificate) = certificate { + let cert_file = File::open(&certificate.cert_file)?; + let mut reader = BufReader::new(cert_file); + let certs = rustls_pemfile::certs(&mut reader).flatten(); + root_cert_store.add_parsable_certificates(certs); + } + let config = ClientConfig::builder() .with_root_certificates(root_cert_store) .with_no_client_auth(); @@ -119,21 +94,17 @@ impl Connection { let connector = TlsConnector::from(config); let domain = match host { - Host::Domain(domain) => ServerName::try_from(domain.as_ref().to_owned()) - .map_err(|_| Error::InvalidDnsName(domain.as_ref().to_owned()))?, + Host::Domain(domain) => ServerName::try_from(String::from(domain.as_ref())) + .map_err(|_| Error::InvalidDnsName(String::from(domain.as_ref())))?, Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(*ip))), Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(*ip))), }; - let stream = connector.connect(domain, stream).await?; - Ok(stream) + Ok((connector, domain)) } - async fn init( - hello_builder: HelloBuilder, - stream: impl Into<ConnectionStream>, - ) -> Result<Connection> { - let mut stream = BufStream::new(stream.into()); + async fn init(hello_builder: HelloBuilder, stream: ConnectionStream) -> Result<Connection> { + let mut stream = BufStream::new(stream); stream.write_all(&[0x60, 0x60, 0xB0, 0x17]).await?; stream.write_all(&Version::supported_versions()).await?; stream.flush().await?; @@ -266,7 +237,7 @@ pub(crate) struct ConnectionInfo { client_certificate: Option<ClientCertificate>, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] enum Encryption { No, Tls, diff --git a/lib/src/pool.rs b/lib/src/pool.rs index 39eda8e..ed369dc 100644 --- a/lib/src/pool.rs +++ b/lib/src/pool.rs @@ -46,7 +46,7 @@ impl Manager for ConnectionManager { async fn create(&self) -> Result<Self::Type, Self::Error> { info!("creating new connection..."); - Connection::new(&self.info).await + Connection::new(&self.info)?.await } async fn recycle(&self, obj: &mut Self::Type, _: &Metrics) -> RecycleResult<Self::Error> { From e2daa7eda19b7a5033c2bfde26dffb215eba9d21 Mon Sep 17 00:00:00 2001 From: Paul Horn <dev@knutwalker.engineer> Date: Wed, 31 Jul 2024 19:26:20 +0200 Subject: [PATCH 2/4] Build tls connector once per pool --- lib/src/connection.rs | 165 ++++++++++++++++++++++-------------------- lib/src/messages.rs | 9 ++- lib/src/pool.rs | 2 +- 3 files changed, 93 insertions(+), 83 deletions(-) diff --git a/lib/src/connection.rs b/lib/src/connection.rs index 10bc8e0..85ade6d 100644 --- a/lib/src/connection.rs +++ b/lib/src/connection.rs @@ -35,72 +35,27 @@ pub struct Connection { impl Connection { pub(crate) fn new( info: &ConnectionInfo, - ) -> Result<impl std::future::Future<Output = Result<Connection>>> { - let mut hello_builder = HelloBuilder::new(&*info.user, &*info.password); - if let Routing::Yes(routing) = &info.routing { - hello_builder.with_routing(routing.clone()); - }; + ) -> impl std::future::Future<Output = Result<Connection>> { + // we do this setup outside of the async block so that the returned future + // does not borrow the info struct and can be Send + let hello_builder = + HelloBuilder::new(&*info.user, &*info.password).with_routing(info.routing.clone()); + let encryption = info.encryption.clone(); let host = info.host.clone(); let port = info.port; - let encryption_connector = match info.encryption { - Encryption::No => None, - Encryption::Tls => Some(Self::tls_connector( - &info.host, - info.client_certificate.as_ref(), - )?), - }; - - Ok(async move { + async move { let stream = match host { Host::Domain(domain) => TcpStream::connect((&*domain, port)).await?, Host::Ipv4(ip) => TcpStream::connect((ip, port)).await?, Host::Ipv6(ip) => TcpStream::connect((ip, port)).await?, }; - let stream: ConnectionStream = match encryption_connector { + let stream: ConnectionStream = match encryption { Some((connector, domain)) => connector.connect(domain, stream).await?.into(), None => stream.into(), }; Self::init(hello_builder, stream).await - }) - } - - fn tls_connector<T: AsRef<str>>( - host: &Host<T>, - certificate: Option<&ClientCertificate>, - ) -> Result<(TlsConnector, ServerName<'static>)> { - let mut root_cert_store = RootCertStore::empty(); - match rustls_native_certs::load_native_certs() { - Ok(certs) => { - root_cert_store.add_parsable_certificates(certs); - } - Err(e) => { - warn!("Failed to load native certificates: {e}"); - } - } - - if let Some(certificate) = certificate { - let cert_file = File::open(&certificate.cert_file)?; - let mut reader = BufReader::new(cert_file); - let certs = rustls_pemfile::certs(&mut reader).flatten(); - root_cert_store.add_parsable_certificates(certs); } - - let config = ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - - let config = Arc::new(config); - let connector = TlsConnector::from(config); - - let domain = match host { - Host::Domain(domain) => ServerName::try_from(String::from(domain.as_ref())) - .map_err(|_| Error::InvalidDnsName(String::from(domain.as_ref())))?, - Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(*ip))), - Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(*ip))), - }; - - Ok((connector, domain)) } async fn init(hello_builder: HelloBuilder, stream: ConnectionStream) -> Result<Connection> { @@ -112,7 +67,7 @@ impl Connection { stream.read_exact(&mut response).await?; let version = Version::parse(response)?; let mut connection = Connection { version, stream }; - let hello = hello_builder.version(version).build(); + let hello = hello_builder.with_version(version).build(); match connection.send_recv(hello).await? { BoltResponse::Success(_msg) => Ok(connection), BoltResponse::Failure(msg) => { @@ -226,29 +181,43 @@ impl Connection { } } -#[derive(Debug)] pub(crate) struct ConnectionInfo { user: Arc<str>, password: Arc<str>, host: Host<Arc<str>>, port: u16, - encryption: Encryption, routing: Routing, - client_certificate: Option<ClientCertificate>, + encryption: Option<(TlsConnector, ServerName<'static>)>, } -#[derive(Debug, Clone, Copy)] -enum Encryption { - No, - Tls, +impl std::fmt::Debug for ConnectionInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectionInfo") + .field("user", &self.user) + .field("password", &"***") + .field("host", &self.host) + .field("port", &self.port) + .field("routing", &self.routing) + .field("encryption", &self.encryption.is_some()) + .finish_non_exhaustive() + } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum Routing { No, Yes(BoltMap), } +impl From<Routing> for Option<BoltMap> { + fn from(routing: Routing) -> Self { + match routing { + Routing::No => None, + Routing::Yes(routing) => Some(routing), + } + } +} + impl ConnectionInfo { pub(crate) fn new( uri: &str, @@ -258,25 +227,20 @@ impl ConnectionInfo { ) -> Result<Self> { let mut url = NeoUrl::parse(uri)?; - let host = url.host(); - let host = match host { - Host::Domain(s) => Host::Domain(Arc::<str>::from(s)), - Host::Ipv4(d) => Host::Ipv4(d), - Host::Ipv6(d) => Host::Ipv6(d), - }; - - let port = url.port(); - let (routing, encryption) = match url.scheme() { - "bolt" | "" => (false, Encryption::No), - "bolt+s" => (false, Encryption::Tls), - "bolt+ssc" => (false, Encryption::Tls), - "neo4j" => (true, Encryption::No), - "neo4j+s" => (true, Encryption::Tls), - "neo4j+ssc" => (true, Encryption::Tls), + "bolt" | "" => (false, false), + "bolt+s" => (false, true), + "bolt+ssc" => (false, true), + "neo4j" => (true, false), + "neo4j+s" => (true, true), + "neo4j+ssc" => (true, true), otherwise => return Err(Error::UnsupportedScheme(otherwise.to_owned())), }; + let encryption = encryption + .then(|| Self::tls_connector(url.host(), client_certificate)) + .transpose()?; + let routing = if routing { log::warn!(concat!( "This driver does not yet implement client-side routing. ", @@ -289,16 +253,59 @@ impl ConnectionInfo { url.warn_on_unexpected_components(); + let host = match url.host() { + Host::Domain(s) => Host::Domain(Arc::<str>::from(s)), + Host::Ipv4(d) => Host::Ipv4(d), + Host::Ipv6(d) => Host::Ipv6(d), + }; + Ok(Self { user: user.into(), password: password.into(), host, - port, + port: url.port(), encryption, routing, - client_certificate: client_certificate.cloned(), }) } + + fn tls_connector( + host: Host<&str>, + certificate: Option<&ClientCertificate>, + ) -> Result<(TlsConnector, ServerName<'static>)> { + let mut root_cert_store = RootCertStore::empty(); + match rustls_native_certs::load_native_certs() { + Ok(certs) => { + root_cert_store.add_parsable_certificates(certs); + } + Err(e) => { + warn!("Failed to load native certificates: {e}"); + } + } + + if let Some(certificate) = certificate { + let cert_file = File::open(&certificate.cert_file)?; + let mut reader = BufReader::new(cert_file); + let certs = rustls_pemfile::certs(&mut reader).flatten(); + root_cert_store.add_parsable_certificates(certs); + } + + let config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let config = Arc::new(config); + let connector = TlsConnector::from(config); + + let domain = match host { + Host::Domain(domain) => ServerName::try_from(domain.to_owned()) + .map_err(|_| Error::InvalidDnsName(domain.to_owned()))?, + Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(ip))), + Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(ip))), + }; + + Ok((connector, domain)) + } } struct NeoUrl(Url); diff --git a/lib/src/messages.rs b/lib/src/messages.rs index a6987fe..e09e0d5 100644 --- a/lib/src/messages.rs +++ b/lib/src/messages.rs @@ -68,11 +68,14 @@ impl HelloBuilder { } } - pub fn with_routing(&mut self, routing: BoltMap) { - self.routing = Some(routing); + pub fn with_routing(self, routing: impl Into<Option<BoltMap>>) -> Self { + Self { + routing: routing.into(), + ..self + } } - pub fn version(self, version: Version) -> Self { + pub fn with_version(self, version: Version) -> Self { Self { version, ..self } } diff --git a/lib/src/pool.rs b/lib/src/pool.rs index ed369dc..39eda8e 100644 --- a/lib/src/pool.rs +++ b/lib/src/pool.rs @@ -46,7 +46,7 @@ impl Manager for ConnectionManager { async fn create(&self) -> Result<Self::Type, Self::Error> { info!("creating new connection..."); - Connection::new(&self.info)?.await + Connection::new(&self.info).await } async fn recycle(&self, obj: &mut Self::Type, _: &Metrics) -> RecycleResult<Self::Error> { From 2b3cef295ba933a853e1b0f2f60720d059ffa228 Mon Sep 17 00:00:00 2001 From: Paul Horn <dev@knutwalker.engineer> Date: Wed, 31 Jul 2024 19:26:27 +0200 Subject: [PATCH 3/4] Use ring as crypto provider --- lib/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Cargo.toml b/lib/Cargo.toml index c7ca59b..a4eec6d 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -60,7 +60,7 @@ features = ["managed"] [dependencies.tokio-rustls] version = "0.26.0" default-features = false -features = ["tls12"] +features = ["tls12", "ring"] [dev-dependencies] From 725f8edd4046df05d577b259899d1c9755630af2 Mon Sep 17 00:00:00 2001 From: Paul Horn <dev@knutwalker.engineer> Date: Wed, 14 Aug 2024 19:10:30 +0200 Subject: [PATCH 4/4] Refactor connection creation to split up async stuff into chunks --- lib/src/connection.rs | 100 ++++++++++++++++++++++++++---------------- lib/src/version.rs | 15 +++---- 2 files changed, 69 insertions(+), 46 deletions(-) diff --git a/lib/src/connection.rs b/lib/src/connection.rs index 85ade6d..6c98201 100644 --- a/lib/src/connection.rs +++ b/lib/src/connection.rs @@ -8,18 +8,18 @@ use crate::{ version::Version, BoltMap, }; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use log::warn; -use std::fs::File; -use std::io::BufReader; -use std::{mem, sync::Arc}; +use std::{fs::File, io::BufReader, mem, sync::Arc}; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt, BufStream}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream}, net::TcpStream, }; -use tokio_rustls::rustls::pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName}; use tokio_rustls::{ - rustls::{ClientConfig, RootCertStore}, + rustls::{ + pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName}, + ClientConfig, RootCertStore, + }, TlsConnector, }; use url::{Host, Url}; @@ -33,43 +33,60 @@ pub struct Connection { } impl Connection { - pub(crate) fn new( - info: &ConnectionInfo, - ) -> impl std::future::Future<Output = Result<Connection>> { - // we do this setup outside of the async block so that the returned future - // does not borrow the info struct and can be Send - let hello_builder = - HelloBuilder::new(&*info.user, &*info.password).with_routing(info.routing.clone()); - let encryption = info.encryption.clone(); - let host = info.host.clone(); - let port = info.port; - async move { - let stream = match host { - Host::Domain(domain) => TcpStream::connect((&*domain, port)).await?, - Host::Ipv4(ip) => TcpStream::connect((ip, port)).await?, - Host::Ipv6(ip) => TcpStream::connect((ip, port)).await?, - }; - - let stream: ConnectionStream = match encryption { - Some((connector, domain)) => connector.connect(domain, stream).await?.into(), - None => stream.into(), - }; - Self::init(hello_builder, stream).await - } + pub(crate) async fn new(info: &ConnectionInfo) -> Result<Self> { + let mut connection = Self::prepare(info).await?; + let hello = info.to_hello(connection.version); + connection.hello(hello).await?; + Ok(connection) + } + + pub(crate) async fn prepare(info: &ConnectionInfo) -> Result<Self> { + let mut stream = match &info.host { + Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?, + Host::Ipv4(ip) => TcpStream::connect((*ip, info.port)).await?, + Host::Ipv6(ip) => TcpStream::connect((*ip, info.port)).await?, + }; + + Ok(match &info.encryption { + Some((connector, domain)) => { + let mut stream = connector.connect(domain.clone(), stream).await?; + let version = Self::init(&mut stream).await?; + Self::create(stream, version) + } + None => { + let version = Self::init(&mut stream).await?; + Self::create(stream, version) + } + }) } - async fn init(hello_builder: HelloBuilder, stream: ConnectionStream) -> Result<Connection> { - let mut stream = BufStream::new(stream); - stream.write_all(&[0x60, 0x60, 0xB0, 0x17]).await?; - stream.write_all(&Version::supported_versions()).await?; + async fn init<A: AsyncWrite + AsyncRead + Unpin>(stream: &mut A) -> Result<Version> { + stream.write_all_buf(&mut Self::init_msg()).await?; stream.flush().await?; + let mut response = [0, 0, 0, 0]; stream.read_exact(&mut response).await?; let version = Version::parse(response)?; - let mut connection = Connection { version, stream }; - let hello = hello_builder.with_version(version).build(); - match connection.send_recv(hello).await? { - BoltResponse::Success(_msg) => Ok(connection), + Ok(version) + } + + fn create(stream: impl Into<ConnectionStream>, version: Version) -> Connection { + Connection { + version, + stream: BufStream::new(stream.into()), + } + } + + fn init_msg() -> Bytes { + let mut init = BytesMut::with_capacity(20); + init.put_slice(&[0x60, 0x60, 0xB0, 0x17]); + Version::add_supported_versions(&mut init); + init.freeze() + } + + async fn hello(&mut self, req: BoltRequest) -> Result<()> { + match self.send_recv(req).await? { + BoltResponse::Success(_msg) => Ok(()), BoltResponse::Failure(msg) => { Err(Error::AuthenticationError(msg.get("message").unwrap())) } @@ -306,6 +323,13 @@ impl ConnectionInfo { Ok((connector, domain)) } + + pub(crate) fn to_hello(&self, version: Version) -> BoltRequest { + HelloBuilder::new(&*self.user, &*self.password) + .with_routing(self.routing.clone()) + .with_version(version) + .build() + } } struct NeoUrl(Url); diff --git a/lib/src/version.rs b/lib/src/version.rs index be5a9cf..b6cca9b 100644 --- a/lib/src/version.rs +++ b/lib/src/version.rs @@ -1,5 +1,5 @@ use crate::errors::{Error, Result}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; use std::cmp::PartialEq; use std::fmt::Debug; @@ -11,13 +11,12 @@ pub enum Version { } impl Version { - pub fn supported_versions() -> Bytes { - let mut bytes = BytesMut::with_capacity(16); - let versions: [u32; 4] = [0x0104, 0x0004, 0, 0]; - for version in versions.iter() { - bytes.put_u32(*version); - } - bytes.freeze() + pub fn add_supported_versions(bytes: &mut BytesMut) { + bytes.reserve(16); + bytes.put_u32(0x0104); // V4_1 + bytes.put_u32(0x0004); // V4 + bytes.put_u32(0); + bytes.put_u32(0); } pub fn parse(version_bytes: [u8; 4]) -> Result<Version> {