Skip to content

Commit 2651092

Browse files
committed
- Fix silly bug in pool creation
- Remove unused function - Fix the way a connection is created by passing the same scheme as the main one
1 parent 6d4cd66 commit 2651092

File tree

5 files changed

+32
-42
lines changed

5 files changed

+32
-42
lines changed

lib/src/connection.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ impl NeoUrl {
455455
Ok(Self(url))
456456
}
457457

458-
fn scheme(&self) -> &str {
458+
pub(crate) fn scheme(&self) -> &str {
459459
self.0.scheme()
460460
}
461461

lib/src/graph.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,7 @@ impl Graph {
7474
&config.tls_config,
7575
)?;
7676
if matches!(info.routing, Routing::Yes(_)) {
77-
let pool = Routed(
78-
RoutedConnectionManager::new(&config, Arc::new(RoundRobinStrategy::default()))
79-
.await?,
80-
);
77+
let pool = Routed(RoutedConnectionManager::new(&config).await?);
8178
Ok(Graph {
8279
config: config.into_live_config(),
8380
pool,

lib/src/routing/connection_registry.rs

+14-18
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ pub(crate) struct ConnectionRegistry {
4646
}
4747

4848
impl ConnectionRegistry {
49-
pub(crate) async fn new(config: &Config) -> Result<Self, Error> {
50-
let connections = Self::build_registry(config, &[]).await?;
51-
Ok(ConnectionRegistry {
49+
pub(crate) fn new(config: &Config) -> Self {
50+
ConnectionRegistry {
5251
config: config.clone(),
5352
creation_time: Arc::new(Mutex::new(
5453
std::time::SystemTime::now()
@@ -57,20 +56,8 @@ impl ConnectionRegistry {
5756
.as_secs(),
5857
)),
5958
ttl: Arc::new(AtomicU64::new(0)),
60-
connections,
61-
})
62-
}
63-
64-
async fn build_registry(config: &Config, servers: &[BoltServer]) -> Result<Registry, Error> {
65-
let registry = DashMap::new();
66-
for server in servers.iter() {
67-
let server_config = Config {
68-
uri: format!("{}:{}", server.address, server.port), // build a config for each server in the routing table
69-
..config.clone() // but keep the information about tls and other settings
70-
};
71-
registry.insert(server.clone(), create_pool(&server_config).await?);
59+
connections: DashMap::new(),
7260
}
73-
Ok(registry)
7461
}
7562

7663
pub(crate) async fn update_if_expired<F, R>(&self, f: F) -> Result<(), Error>
@@ -90,11 +77,20 @@ impl ConnectionRegistry {
9077
debug!("Routing table refreshed: {:?}", routing_table);
9178
let registry = &self.connections;
9279
let servers = routing_table.resolve();
80+
let url = NeoUrl::parse(self.config.uri.as_str())?;
81+
9382
for server in servers.iter() {
9483
if registry.contains_key(server) {
9584
continue;
9685
}
97-
registry.insert(server.clone(), create_pool(&self.config).await?);
86+
registry.insert(
87+
server.clone(),
88+
create_pool(&Config {
89+
uri: format!("{}://{}:{}", url.scheme(), server.address, server.port),
90+
..self.config.clone()
91+
})
92+
.await?,
93+
);
9894
}
9995
registry.retain(|k, _| servers.contains(k));
10096
let _ = self
@@ -185,7 +181,7 @@ mod tests {
185181
fetch_size: 0,
186182
tls_config: ConnectionTLSConfig::None,
187183
};
188-
let registry = ConnectionRegistry::new(&config).await.unwrap();
184+
let registry = ConnectionRegistry::new(&config);
189185
registry
190186
.update_if_expired(|| async { Ok(cluster_routing_table) })
191187
.await

lib/src/routing/routed_connection_manager.rs

+11-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::connection::{Connection, ConnectionInfo};
22
use crate::pool::ManagedConnection;
33
use crate::routing::connection_registry::{BoltServer, ConnectionRegistry};
44
use crate::routing::load_balancing::LoadBalancingStrategy;
5+
use crate::routing::RoundRobinStrategy;
56
#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
67
use crate::routing::{RouteBuilder, RoutingTable};
78
use crate::{Config, Error, Operation};
@@ -14,19 +15,16 @@ use std::time::Duration;
1415
#[derive(Clone)]
1516
pub struct RoutedConnectionManager {
1617
load_balancing_strategy: Arc<dyn LoadBalancingStrategy>,
17-
registry: Arc<ConnectionRegistry>,
18+
connection_registry: Arc<ConnectionRegistry>,
1819
#[allow(dead_code)]
1920
bookmarks: Arc<Mutex<Vec<String>>>,
2021
backoff: Arc<ExponentialBackoff>,
2122
config: Config,
2223
}
2324

2425
impl RoutedConnectionManager {
25-
pub async fn new(
26-
config: &Config,
27-
load_balancing_strategy: Arc<dyn LoadBalancingStrategy>,
28-
) -> Result<Self, Error> {
29-
let registry = Arc::new(ConnectionRegistry::new(config).await?);
26+
pub async fn new(config: &Config) -> Result<Self, Error> {
27+
let registry = Arc::new(ConnectionRegistry::new(config));
3028
let backoff = Arc::new(
3129
ExponentialBackoffBuilder::new()
3230
.with_initial_interval(Duration::from_millis(1))
@@ -37,8 +35,8 @@ impl RoutedConnectionManager {
3735
);
3836

3937
Ok(RoutedConnectionManager {
40-
load_balancing_strategy,
41-
registry,
38+
load_balancing_strategy: Arc::new(RoundRobinStrategy::default()),
39+
connection_registry: registry,
4240
bookmarks: Arc::new(Mutex::new(vec![])),
4341
backoff,
4442
config: config.clone(),
@@ -70,7 +68,7 @@ impl RoutedConnectionManager {
7068
) -> Result<ManagedConnection, Error> {
7169
// We probably need to do this in a more efficient way, since this will block the request of a connection
7270
// while we refresh the routing table. We should probably have a separate thread that refreshes the routing
73-
self.registry
71+
self.connection_registry
7472
.update_if_expired(|| self.refresh_routing_table())
7573
.await?;
7674

@@ -80,15 +78,15 @@ impl RoutedConnectionManager {
8078
_ => self.select_reader(),
8179
} {
8280
debug!("requesting connection for server: {:?}", server);
83-
if let Some(pool) = self.registry.get_pool(&server) {
81+
if let Some(pool) = self.connection_registry.get_pool(&server) {
8482
match pool.get().await {
8583
Ok(connection) => return Ok(connection),
8684
Err(e) => {
8785
error!(
8886
"Failed to get connection from pool for server `{}`: {}",
8987
server.address, e
9088
);
91-
self.registry.mark_unavailable(&server);
89+
self.connection_registry.mark_unavailable(&server);
9290
continue;
9391
}
9492
}
@@ -115,11 +113,11 @@ impl RoutedConnectionManager {
115113

116114
fn select_reader(&self) -> Option<BoltServer> {
117115
self.load_balancing_strategy
118-
.select_reader(&self.registry.servers())
116+
.select_reader(&self.connection_registry.servers())
119117
}
120118

121119
fn select_writer(&self) -> Option<BoltServer> {
122120
self.load_balancing_strategy
123-
.select_writer(&self.registry.servers())
121+
.select_writer(&self.connection_registry.servers())
124122
}
125123
}

lib/tests/use_default_db.rs

+5-6
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ async fn use_default_db() {
4040
if default_db != dbname {
4141
eprintln!(
4242
concat!(
43-
"Skipping test: The test must run against a testcontainer ",
44-
"or have `{}` configured as the default database"
43+
"Skipping test: The test must run against a testcontainer ",
44+
"or have `{}` configured as the default database"
4545
),
4646
dbname
4747
);
@@ -63,16 +63,15 @@ async fn use_default_db() {
6363
if default_db != dbname {
6464
eprintln!(
6565
concat!(
66-
"Skipping test: The test must run against a testcontainer ",
67-
"or have `{}` configured as the default database"
66+
"Skipping test: The test must run against a testcontainer ",
67+
"or have `{}` configured as the default database"
6868
),
6969
dbname
7070
);
7171
return;
7272
}
7373
}
7474

75-
7675
let id = uuid::Uuid::new_v4();
7776
graph
7877
.run(query("CREATE (:Node { uuid: $uuid })").param("uuid", id.to_string()))
@@ -103,7 +102,7 @@ async fn use_default_db() {
103102
.execute_on(
104103
dbname.as_str(),
105104
query("MATCH (n:Node {uuid: $uuid}) RETURN count(n) AS result")
106-
.param("uuid", id.to_string())
105+
.param("uuid", id.to_string()),
107106
)
108107
.await
109108
.unwrap()

0 commit comments

Comments
 (0)