Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor connection creation #192

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ features = ["managed"]
[dependencies.tokio-rustls]
version = "0.26.0"
default-features = false
features = ["tls12"]
features = ["tls12", "ring"]


[dev-dependencies]
Expand Down
248 changes: 125 additions & 123 deletions lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,19 @@
use crate::bolt::{ExpectedResponse, Message, MessageResponse};
use crate::{
auth::ClientCertificate,
connection::stream::ConnectionStream,
errors::{Error, Result},
messages::{BoltRequest, BoltResponse, HelloBuilder},
version::Version,
BoltMap,
};
use bytes::{Bytes, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use log::warn;
use std::fs::File;
use std::io::BufReader;
use std::{mem, sync::Arc};
use stream::ConnectionStream;
use std::{fs::File, io::BufReader, mem, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufStream},
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream},
net::TcpStream,
};
use tokio_rustls::client::TlsStream;
use tokio_rustls::{
rustls::{
pki_types::{IpAddr, Ipv4Addr, Ipv6Addr, ServerName},
Expand All @@ -36,114 +33,60 @@ pub struct Connection {
}

impl Connection {
pub(crate) async fn new(info: &ConnectionInfo) -> Result<Connection> {
let mut hello_builder = HelloBuilder::new(&*info.user, &*info.password);
if let Routing::Yes(routing) = &info.routing {
hello_builder.with_routing(routing.clone());
};
pub(crate) async fn new(info: &ConnectionInfo) -> Result<Self> {
let mut connection = Self::prepare(info).await?;
let hello = info.to_hello(connection.version);
connection.hello(hello).await?;
Ok(connection)
}

let stream = match &info.host {
pub(crate) async fn prepare(info: &ConnectionInfo) -> Result<Self> {
let mut stream = match &info.host {
Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?,
Host::Ipv4(ip) => TcpStream::connect((*ip, info.port)).await?,
Host::Ipv6(ip) => TcpStream::connect((*ip, info.port)).await?,
};

match info.encryption {
Encryption::No => Self::new_unencrypted(stream, hello_builder).await,
Encryption::Tls => {
if let Some(certificate) = info.client_certificate.as_ref() {
Self::new_tls_with_certificate(stream, &info.host, hello_builder, certificate)
.await
} else {
Self::new_tls(stream, &info.host, hello_builder).await
}
Ok(match &info.encryption {
Some((connector, domain)) => {
let mut stream = connector.connect(domain.clone(), stream).await?;
let version = Self::init(&mut stream).await?;
Self::create(stream, version)
}
}
}

async fn new_unencrypted(stream: TcpStream, hello_builder: HelloBuilder) -> Result<Connection> {
Self::init(hello_builder, stream).await
}

async fn new_tls<T: AsRef<str>>(
stream: TcpStream,
host: &Host<T>,
hello_builder: HelloBuilder,
) -> Result<Connection> {
let root_cert_store = Self::build_cert_store();
let stream = Self::build_stream(stream, host, root_cert_store).await?;

Self::init(hello_builder, stream).await
None => {
let version = Self::init(&mut stream).await?;
Self::create(stream, version)
}
})
}

async fn new_tls_with_certificate<T: AsRef<str>>(
stream: TcpStream,
host: &Host<T>,
hello_builder: HelloBuilder,
certificate: &ClientCertificate,
) -> Result<Connection> {
let mut root_cert_store = Self::build_cert_store();

let cert_file = File::open(certificate.cert_file.as_os_str())?;
let mut reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut reader).flatten();
root_cert_store.add_parsable_certificates(certs);
async fn init<A: AsyncWrite + AsyncRead + Unpin>(stream: &mut A) -> Result<Version> {
stream.write_all_buf(&mut Self::init_msg()).await?;
stream.flush().await?;

let stream = Self::build_stream(stream, host, root_cert_store).await?;
Self::init(hello_builder, stream).await
let mut response = [0, 0, 0, 0];
stream.read_exact(&mut response).await?;
let version = Version::parse(response)?;
Ok(version)
}

fn build_cert_store() -> RootCertStore {
let mut root_cert_store = RootCertStore::empty();
match rustls_native_certs::load_native_certs() {
Ok(certs) => {
root_cert_store.add_parsable_certificates(certs);
}
Err(e) => {
warn!("Failed to load native certificates: {e}");
}
fn create(stream: impl Into<ConnectionStream>, version: Version) -> Connection {
Connection {
version,
stream: BufStream::new(stream.into()),
}
root_cert_store
}

async fn build_stream<T: AsRef<str>>(
stream: TcpStream,
host: &Host<T>,
root_cert_store: RootCertStore,
) -> Result<TlsStream<TcpStream>, Error> {
let config = ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();

let config = Arc::new(config);
let connector = TlsConnector::from(config);

let domain = match host {
Host::Domain(domain) => ServerName::try_from(domain.as_ref().to_owned())
.map_err(|_| Error::InvalidDnsName(domain.as_ref().to_owned()))?,
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(*ip))),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(*ip))),
};

let stream = connector.connect(domain, stream).await?;
Ok(stream)
fn init_msg() -> Bytes {
let mut init = BytesMut::with_capacity(20);
init.put_slice(&[0x60, 0x60, 0xB0, 0x17]);
Version::add_supported_versions(&mut init);
init.freeze()
}

async fn init(
hello_builder: HelloBuilder,
stream: impl Into<ConnectionStream>,
) -> Result<Connection> {
let mut stream = BufStream::new(stream.into());
stream.write_all(&[0x60, 0x60, 0xB0, 0x17]).await?;
stream.write_all(&Version::supported_versions()).await?;
stream.flush().await?;
let mut response = [0, 0, 0, 0];
stream.read_exact(&mut response).await?;
let version = Version::parse(response)?;
let mut connection = Connection { version, stream };
let hello = hello_builder.version(version).build();
match connection.send_recv(hello).await? {
BoltResponse::Success(_msg) => Ok(connection),
async fn hello(&mut self, req: BoltRequest) -> Result<()> {
match self.send_recv(req).await? {
BoltResponse::Success(_msg) => Ok(()),
BoltResponse::Failure(msg) => {
Err(Error::AuthenticationError(msg.get("message").unwrap()))
}
Expand Down Expand Up @@ -255,29 +198,43 @@ impl Connection {
}
}

#[derive(Debug)]
pub(crate) struct ConnectionInfo {
user: Arc<str>,
password: Arc<str>,
host: Host<Arc<str>>,
port: u16,
encryption: Encryption,
routing: Routing,
client_certificate: Option<ClientCertificate>,
encryption: Option<(TlsConnector, ServerName<'static>)>,
}

#[derive(Debug)]
enum Encryption {
No,
Tls,
impl std::fmt::Debug for ConnectionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnectionInfo")
.field("user", &self.user)
.field("password", &"***")
.field("host", &self.host)
.field("port", &self.port)
.field("routing", &self.routing)
.field("encryption", &self.encryption.is_some())
.finish_non_exhaustive()
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) enum Routing {
No,
Yes(BoltMap),
}

impl From<Routing> for Option<BoltMap> {
fn from(routing: Routing) -> Self {
match routing {
Routing::No => None,
Routing::Yes(routing) => Some(routing),
}
}
}

impl ConnectionInfo {
pub(crate) fn new(
uri: &str,
Expand All @@ -287,25 +244,20 @@ impl ConnectionInfo {
) -> Result<Self> {
let mut url = NeoUrl::parse(uri)?;

let host = url.host();
let host = match host {
Host::Domain(s) => Host::Domain(Arc::<str>::from(s)),
Host::Ipv4(d) => Host::Ipv4(d),
Host::Ipv6(d) => Host::Ipv6(d),
};

let port = url.port();

let (routing, encryption) = match url.scheme() {
"bolt" | "" => (false, Encryption::No),
"bolt+s" => (false, Encryption::Tls),
"bolt+ssc" => (false, Encryption::Tls),
"neo4j" => (true, Encryption::No),
"neo4j+s" => (true, Encryption::Tls),
"neo4j+ssc" => (true, Encryption::Tls),
"bolt" | "" => (false, false),
"bolt+s" => (false, true),
"bolt+ssc" => (false, true),
"neo4j" => (true, false),
"neo4j+s" => (true, true),
"neo4j+ssc" => (true, true),
otherwise => return Err(Error::UnsupportedScheme(otherwise.to_owned())),
};

let encryption = encryption
.then(|| Self::tls_connector(url.host(), client_certificate))
.transpose()?;

let routing = if routing {
log::warn!(concat!(
"This driver does not yet implement client-side routing. ",
Expand All @@ -318,16 +270,66 @@ impl ConnectionInfo {

url.warn_on_unexpected_components();

let host = match url.host() {
Host::Domain(s) => Host::Domain(Arc::<str>::from(s)),
Host::Ipv4(d) => Host::Ipv4(d),
Host::Ipv6(d) => Host::Ipv6(d),
};

Ok(Self {
user: user.into(),
password: password.into(),
host,
port,
port: url.port(),
encryption,
routing,
client_certificate: client_certificate.cloned(),
})
}

fn tls_connector(
host: Host<&str>,
certificate: Option<&ClientCertificate>,
) -> Result<(TlsConnector, ServerName<'static>)> {
let mut root_cert_store = RootCertStore::empty();
match rustls_native_certs::load_native_certs() {
Ok(certs) => {
root_cert_store.add_parsable_certificates(certs);
}
Err(e) => {
warn!("Failed to load native certificates: {e}");
}
}

if let Some(certificate) = certificate {
let cert_file = File::open(&certificate.cert_file)?;
let mut reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut reader).flatten();
root_cert_store.add_parsable_certificates(certs);
}

let config = ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();

let config = Arc::new(config);
let connector = TlsConnector::from(config);

let domain = match host {
Host::Domain(domain) => ServerName::try_from(domain.to_owned())
.map_err(|_| Error::InvalidDnsName(domain.to_owned()))?,
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from(ip))),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(Ipv6Addr::from(ip))),
};

Ok((connector, domain))
}

pub(crate) fn to_hello(&self, version: Version) -> BoltRequest {
HelloBuilder::new(&*self.user, &*self.password)
.with_routing(self.routing.clone())
.with_version(version)
.build()
}
}

struct NeoUrl(Url);
Expand Down
9 changes: 6 additions & 3 deletions lib/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ impl HelloBuilder {
}
}

pub fn with_routing(&mut self, routing: BoltMap) {
self.routing = Some(routing);
pub fn with_routing(self, routing: impl Into<Option<BoltMap>>) -> Self {
Self {
routing: routing.into(),
..self
}
}

pub fn version(self, version: Version) -> Self {
pub fn with_version(self, version: Version) -> Self {
Self { version, ..self }
}

Expand Down
15 changes: 7 additions & 8 deletions lib/src/version.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::errors::{Error, Result};
use bytes::{BufMut, Bytes, BytesMut};
use bytes::{BufMut, BytesMut};
use std::cmp::PartialEq;
use std::fmt::Debug;

Expand All @@ -11,13 +11,12 @@ pub enum Version {
}

impl Version {
pub fn supported_versions() -> Bytes {
let mut bytes = BytesMut::with_capacity(16);
let versions: [u32; 4] = [0x0104, 0x0004, 0, 0];
for version in versions.iter() {
bytes.put_u32(*version);
}
bytes.freeze()
pub fn add_supported_versions(bytes: &mut BytesMut) {
bytes.reserve(16);
bytes.put_u32(0x0104); // V4_1
bytes.put_u32(0x0004); // V4
bytes.put_u32(0);
bytes.put_u32(0);
}

pub fn parse(version_bytes: [u8; 4]) -> Result<Version> {
Expand Down