Skip to content

Commit cbe130d

Browse files
4.0 subclass signatures fix (#404)
* aligned acquire method for IOPool * more align * alignment
1 parent 9c616e1 commit cbe130d

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

neo4j/io/__init__.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@
8585
PoolConfig,
8686
WorkspaceConfig,
8787
)
88+
from neo4j.api import (
89+
READ_ACCESS,
90+
WRITE_ACCESS,
91+
)
8892

8993
# Set up logger
9094
log = getLogger("neo4j")
@@ -413,11 +417,12 @@ def time_remaining():
413417
raise ClientError("Failed to obtain a connection from pool "
414418
"within {!r}s".format(timeout))
415419

416-
def acquire(self, access_mode=None, timeout=None):
420+
def acquire(self, access_mode=None, timeout=None, database=None):
417421
""" Acquire a connection to a server that can satisfy a set of parameters.
418422
419423
:param access_mode:
420424
:param timeout:
425+
:param database:
421426
"""
422427

423428
def release(self, *connections):
@@ -459,7 +464,7 @@ def deactivate(self, address):
459464
if not connections:
460465
self.remove(address)
461466

462-
def on_write_failure(self, *, address):
467+
def on_write_failure(self, address):
463468
raise WriteServiceUnavailable("No write service available for pool {}".format(self))
464469

465470
def remove(self, address):
@@ -488,7 +493,15 @@ def close(self):
488493
class BoltPool(IOPool):
489494

490495
@classmethod
491-
def open(cls, address, *, auth=None, pool_config, workspace_config):
496+
def open(cls, address, *, auth, pool_config, workspace_config):
497+
"""Create a new BoltPool
498+
499+
:param address:
500+
:param auth:
501+
:param pool_config:
502+
:param workspace_config:
503+
:return: BoltPool
504+
"""
492505

493506
def opener(addr, timeout):
494507
return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)
@@ -505,7 +518,7 @@ def __init__(self, opener, pool_config, workspace_config, address):
505518
def __repr__(self):
506519
return "<{} address={!r}>".format(self.__class__.__name__, self.address)
507520

508-
def acquire(self, *, access_mode=None, timeout=None, database=None):
521+
def acquire(self, access_mode=None, timeout=None, database=None):
509522
# The access_mode and database is not needed for a direct connection, its just there for consistency.
510523
return self._acquire(self.address, timeout)
511524

@@ -515,7 +528,16 @@ class Neo4jPool(IOPool):
515528
"""
516529

517530
@classmethod
518-
def open(cls, *addresses, auth=None, routing_context=None, pool_config=None, workspace_config=None):
531+
def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=None):
532+
"""Create a new Neo4jPool
533+
534+
:param addresses: one or more address as positional argument
535+
:param auth:
536+
:param pool_config:
537+
:param workspace_config:
538+
:param routing_context:
539+
:return: Neo4jPool
540+
"""
519541

520542
def opener(addr, timeout):
521543
return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)
@@ -842,7 +864,12 @@ def _select_address(self, *, access_mode, database):
842864
raise WriteServiceUnavailable("No write service currently available")
843865
return choice(addresses_by_usage[min(addresses_by_usage)])
844866

845-
def acquire(self, *, access_mode, timeout, database):
867+
def acquire(self, access_mode=None, timeout=None, database=None):
868+
if access_mode not in (WRITE_ACCESS, READ_ACCESS):
869+
raise ClientError("Non valid 'access_mode'; {}".format(access_mode))
870+
if not timeout:
871+
raise ClientError("'timeout' must be a float larger than 0; {}".format(timeout))
872+
846873
from neo4j.api import check_access_mode
847874
access_mode = check_access_mode(access_mode)
848875
while True:
@@ -859,7 +886,7 @@ def acquire(self, *, access_mode, timeout, database):
859886
else:
860887
return connection
861888

862-
def deactivate(self, *, address):
889+
def deactivate(self, address):
863890
""" Deactivate an address from the connection pool,
864891
if present, remove from the routing table and also closing
865892
all idle connections to that address.
@@ -874,7 +901,7 @@ def deactivate(self, *, address):
874901
log.debug("[#0000] C: <ROUTING> table=%r", self.routing_tables)
875902
super(Neo4jPool, self).deactivate(address)
876903

877-
def on_write_failure(self, *, address):
904+
def on_write_failure(self, address):
878905
""" Remove a writer address from the routing table, if present.
879906
"""
880907
log.debug("[#0000] C: <ROUTING> Removing writer %r", address)

neo4j/io/_bolt3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def send_all(self):
272272
self.server_info.address,
273273
"; ".join(map(repr, error.args))))
274274
if self.pool:
275-
self.pool.deactivate(self.unresolved_address)
275+
self.pool.deactivate(address=self.unresolved_address)
276276
raise
277277

278278
def fetch_message(self):
@@ -302,7 +302,7 @@ def fetch_message(self):
302302
self.server_info.address,
303303
"; ".join(map(repr, error.args))))
304304
if self.pool:
305-
self.pool.deactivate(self.unresolved_address)
305+
self.pool.deactivate(address=self.unresolved_address)
306306
raise
307307

308308
if details:

neo4j/io/_bolt4x0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def send_all(self):
276276
self.server_info.address,
277277
"; ".join(map(repr, error.args))))
278278
if self.pool:
279-
self.pool.deactivate(self.unresolved_address)
279+
self.pool.deactivate(address=self.unresolved_address)
280280
raise
281281

282282
def fetch_message(self):
@@ -306,7 +306,7 @@ def fetch_message(self):
306306
self.server_info.address,
307307
"; ".join(map(repr, error.args))))
308308
if self.pool:
309-
self.pool.deactivate(self.unresolved_address)
309+
self.pool.deactivate(address=self.unresolved_address)
310310
raise
311311

312312
if details:

tests/unit/io/test_direct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def opener(addr, timeout):
9090
super().__init__(opener, self.pool_config, self.workspace_config)
9191
self.address = address
9292

93-
def acquire(self, access_mode=None, timeout=None):
93+
def acquire(self, access_mode=None, timeout=None, database=None):
9494
return self._acquire(self.address, timeout)
9595

9696

0 commit comments

Comments
 (0)