diff --git a/lib/Cargo.toml b/lib/Cargo.toml index c7ca59b..a4eec6d 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -60,7 +60,7 @@ features = ["managed"] [dependencies.tokio-rustls] version = "0.26.0" default-features = false -features = ["tls12"] +features = ["tls12", "ring"] [dev-dependencies] diff --git a/lib/src/connection.rs b/lib/src/connection.rs index eaeaada..6c98201 100644 --- a/lib/src/connection.rs +++ b/lib/src/connection.rs @@ -2,22 +2,19 @@ use crate::bolt::{ExpectedResponse, Message, MessageResponse}; use crate::{ auth::ClientCertificate, + connection::stream::ConnectionStream, errors::{Error, Result}, messages::{BoltRequest, BoltResponse, HelloBuilder}, version::Version, BoltMap, }; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use log::warn; -use std::fs::File; -use std::io::BufReader; -use std::{mem, sync::Arc}; -use stream::ConnectionStream; +use std::{fs::File, io::BufReader, mem, sync::Arc}; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt, BufStream}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream}, net::TcpStream, }; -use tokio_rustls::client::TlsStream; use tokio_rustls::{ rustls::{ pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName}, @@ -36,114 +33,60 @@ pub struct Connection { } impl Connection { - pub(crate) async fn new(info: &ConnectionInfo) -> Result { - let mut hello_builder = HelloBuilder::new(&*info.user, &*info.password); - if let Routing::Yes(routing) = &info.routing { - hello_builder.with_routing(routing.clone()); - }; + pub(crate) async fn new(info: &ConnectionInfo) -> Result { + let mut connection = Self::prepare(info).await?; + let hello = info.to_hello(connection.version); + connection.hello(hello).await?; + Ok(connection) + } - let stream = match &info.host { + pub(crate) async fn prepare(info: &ConnectionInfo) -> Result { + let mut stream = match &info.host { Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?, Host::Ipv4(ip) => TcpStream::connect((*ip, info.port)).await?, Host::Ipv6(ip) => TcpStream::connect((*ip, info.port)).await?, }; - match info.encryption { - Encryption::No => Self::new_unencrypted(stream, hello_builder).await, - Encryption::Tls => { - if let Some(certificate) = info.client_certificate.as_ref() { - Self::new_tls_with_certificate(stream, &info.host, hello_builder, certificate) - .await - } else { - Self::new_tls(stream, &info.host, hello_builder).await - } + Ok(match &info.encryption { + Some((connector, domain)) => { + let mut stream = connector.connect(domain.clone(), stream).await?; + let version = Self::init(&mut stream).await?; + Self::create(stream, version) } - } - } - - async fn new_unencrypted(stream: TcpStream, hello_builder: HelloBuilder) -> Result { - Self::init(hello_builder, stream).await - } - - async fn new_tls>( - stream: TcpStream, - host: &Host, - hello_builder: HelloBuilder, - ) -> Result { - let root_cert_store = Self::build_cert_store(); - let stream = Self::build_stream(stream, host, root_cert_store).await?; - - Self::init(hello_builder, stream).await + None => { + let version = Self::init(&mut stream).await?; + Self::create(stream, version) + } + }) } - async fn new_tls_with_certificate>( - stream: TcpStream, - host: &Host, - hello_builder: HelloBuilder, - certificate: &ClientCertificate, - ) -> Result { - let mut root_cert_store = Self::build_cert_store(); - - let cert_file = File::open(certificate.cert_file.as_os_str())?; - let mut reader = BufReader::new(cert_file); - let certs = rustls_pemfile::certs(&mut reader).flatten(); - root_cert_store.add_parsable_certificates(certs); + async fn init(stream: &mut A) -> Result { + stream.write_all_buf(&mut Self::init_msg()).await?; + stream.flush().await?; - let stream = Self::build_stream(stream, host, root_cert_store).await?; - Self::init(hello_builder, stream).await + let mut response = [0, 0, 0, 0]; + stream.read_exact(&mut response).await?; + let version = Version::parse(response)?; + Ok(version) } - fn build_cert_store() -> RootCertStore { - let mut root_cert_store = RootCertStore::empty(); - match rustls_native_certs::load_native_certs() { - Ok(certs) => { - root_cert_store.add_parsable_certificates(certs); - } - Err(e) => { - warn!("Failed to load native certificates: {e}"); - } + fn create(stream: impl Into, version: Version) -> Connection { + Connection { + version, + stream: BufStream::new(stream.into()), } - root_cert_store } - async fn build_stream>( - stream: TcpStream, - host: &Host, - root_cert_store: RootCertStore, - ) -> Result, Error> { - let config = ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - - let config = Arc::new(config); - let connector = TlsConnector::from(config); - - let domain = match host { - Host::Domain(domain) => ServerName::try_from(domain.as_ref().to_owned()) - .map_err(|_| Error::InvalidDnsName(domain.as_ref().to_owned()))?, - Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(*ip))), - Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(*ip))), - }; - - let stream = connector.connect(domain, stream).await?; - Ok(stream) + fn init_msg() -> Bytes { + let mut init = BytesMut::with_capacity(20); + init.put_slice(&[0x60, 0x60, 0xB0, 0x17]); + Version::add_supported_versions(&mut init); + init.freeze() } - async fn init( - hello_builder: HelloBuilder, - stream: impl Into, - ) -> Result { - let mut stream = BufStream::new(stream.into()); - stream.write_all(&[0x60, 0x60, 0xB0, 0x17]).await?; - stream.write_all(&Version::supported_versions()).await?; - stream.flush().await?; - let mut response = [0, 0, 0, 0]; - stream.read_exact(&mut response).await?; - let version = Version::parse(response)?; - let mut connection = Connection { version, stream }; - let hello = hello_builder.version(version).build(); - match connection.send_recv(hello).await? { - BoltResponse::Success(_msg) => Ok(connection), + async fn hello(&mut self, req: BoltRequest) -> Result<()> { + match self.send_recv(req).await? { + BoltResponse::Success(_msg) => Ok(()), BoltResponse::Failure(msg) => { Err(Error::AuthenticationError(msg.get("message").unwrap())) } @@ -255,29 +198,43 @@ impl Connection { } } -#[derive(Debug)] pub(crate) struct ConnectionInfo { user: Arc, password: Arc, host: Host>, port: u16, - encryption: Encryption, routing: Routing, - client_certificate: Option, + encryption: Option<(TlsConnector, ServerName<'static>)>, } -#[derive(Debug)] -enum Encryption { - No, - Tls, +impl std::fmt::Debug for ConnectionInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectionInfo") + .field("user", &self.user) + .field("password", &"***") + .field("host", &self.host) + .field("port", &self.port) + .field("routing", &self.routing) + .field("encryption", &self.encryption.is_some()) + .finish_non_exhaustive() + } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum Routing { No, Yes(BoltMap), } +impl From for Option { + fn from(routing: Routing) -> Self { + match routing { + Routing::No => None, + Routing::Yes(routing) => Some(routing), + } + } +} + impl ConnectionInfo { pub(crate) fn new( uri: &str, @@ -287,25 +244,20 @@ impl ConnectionInfo { ) -> Result { let mut url = NeoUrl::parse(uri)?; - let host = url.host(); - let host = match host { - Host::Domain(s) => Host::Domain(Arc::::from(s)), - Host::Ipv4(d) => Host::Ipv4(d), - Host::Ipv6(d) => Host::Ipv6(d), - }; - - let port = url.port(); - let (routing, encryption) = match url.scheme() { - "bolt" | "" => (false, Encryption::No), - "bolt+s" => (false, Encryption::Tls), - "bolt+ssc" => (false, Encryption::Tls), - "neo4j" => (true, Encryption::No), - "neo4j+s" => (true, Encryption::Tls), - "neo4j+ssc" => (true, Encryption::Tls), + "bolt" | "" => (false, false), + "bolt+s" => (false, true), + "bolt+ssc" => (false, true), + "neo4j" => (true, false), + "neo4j+s" => (true, true), + "neo4j+ssc" => (true, true), otherwise => return Err(Error::UnsupportedScheme(otherwise.to_owned())), }; + let encryption = encryption + .then(|| Self::tls_connector(url.host(), client_certificate)) + .transpose()?; + let routing = if routing { log::warn!(concat!( "This driver does not yet implement client-side routing. ", @@ -318,16 +270,66 @@ impl ConnectionInfo { url.warn_on_unexpected_components(); + let host = match url.host() { + Host::Domain(s) => Host::Domain(Arc::::from(s)), + Host::Ipv4(d) => Host::Ipv4(d), + Host::Ipv6(d) => Host::Ipv6(d), + }; + Ok(Self { user: user.into(), password: password.into(), host, - port, + port: url.port(), encryption, routing, - client_certificate: client_certificate.cloned(), }) } + + fn tls_connector( + host: Host<&str>, + certificate: Option<&ClientCertificate>, + ) -> Result<(TlsConnector, ServerName<'static>)> { + let mut root_cert_store = RootCertStore::empty(); + match rustls_native_certs::load_native_certs() { + Ok(certs) => { + root_cert_store.add_parsable_certificates(certs); + } + Err(e) => { + warn!("Failed to load native certificates: {e}"); + } + } + + if let Some(certificate) = certificate { + let cert_file = File::open(&certificate.cert_file)?; + let mut reader = BufReader::new(cert_file); + let certs = rustls_pemfile::certs(&mut reader).flatten(); + root_cert_store.add_parsable_certificates(certs); + } + + let config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let config = Arc::new(config); + let connector = TlsConnector::from(config); + + let domain = match host { + Host::Domain(domain) => ServerName::try_from(domain.to_owned()) + .map_err(|_| Error::InvalidDnsName(domain.to_owned()))?, + Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(ip))), + Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(ip))), + }; + + Ok((connector, domain)) + } + + pub(crate) fn to_hello(&self, version: Version) -> BoltRequest { + HelloBuilder::new(&*self.user, &*self.password) + .with_routing(self.routing.clone()) + .with_version(version) + .build() + } } struct NeoUrl(Url); diff --git a/lib/src/messages.rs b/lib/src/messages.rs index a6987fe..e09e0d5 100644 --- a/lib/src/messages.rs +++ b/lib/src/messages.rs @@ -68,11 +68,14 @@ impl HelloBuilder { } } - pub fn with_routing(&mut self, routing: BoltMap) { - self.routing = Some(routing); + pub fn with_routing(self, routing: impl Into>) -> Self { + Self { + routing: routing.into(), + ..self + } } - pub fn version(self, version: Version) -> Self { + pub fn with_version(self, version: Version) -> Self { Self { version, ..self } } diff --git a/lib/src/version.rs b/lib/src/version.rs index be5a9cf..b6cca9b 100644 --- a/lib/src/version.rs +++ b/lib/src/version.rs @@ -1,5 +1,5 @@ use crate::errors::{Error, Result}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; use std::cmp::PartialEq; use std::fmt::Debug; @@ -11,13 +11,12 @@ pub enum Version { } impl Version { - pub fn supported_versions() -> Bytes { - let mut bytes = BytesMut::with_capacity(16); - let versions: [u32; 4] = [0x0104, 0x0004, 0, 0]; - for version in versions.iter() { - bytes.put_u32(*version); - } - bytes.freeze() + pub fn add_supported_versions(bytes: &mut BytesMut) { + bytes.reserve(16); + bytes.put_u32(0x0104); // V4_1 + bytes.put_u32(0x0004); // V4 + bytes.put_u32(0); + bytes.put_u32(0); } pub fn parse(version_bytes: [u8; 4]) -> Result {