diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..f48f36cfcf 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3241,6 +3241,9 @@ def add_or_renew_pool(self, host, is_host_addition): return None def run_add_or_renew_pool(): + with self._lock: + previous = self._pools.get(host) + try: new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: @@ -3256,7 +3259,6 @@ def run_add_or_renew_pool(): host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False - previous = self._pools.get(host) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() @@ -3276,7 +3278,20 @@ def callback(pool, errors): self._lock.acquire() return False self._lock.acquire() - self._pools[host] = new_pool + + pool_unchanged = self._pools.get(host) is previous + if not pool_unchanged: + # Another concurrent add_or_renew_pool changed this host + # while we were creating ours. Don't replace the existing + # pool because doing so would kill in-flight queries. + log.debug("Pool for host %s was already replaced by another " + "thread, discarding new pool", host) + else: + self._pools[host] = new_pool + + if not pool_unchanged: + new_pool.shutdown() + return True log.debug("Added pool for host %s to session", host) if previous: @@ -3287,7 +3302,8 @@ def callback(pool, errors): return self.submit(run_add_or_renew_pool) def remove_pool(self, host): - pool = self._pools.pop(host, None) + with self._lock: + pool = self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..a25606c188 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,6 +15,7 @@ import logging import socket +import threading from unittest.mock import patch, Mock import uuid @@ -23,7 +24,7 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.pool import Host +from cassandra.pool import Host, HostConnection from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools @@ -339,6 +340,39 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + +class SessionPoolRaceTest(unittest.TestCase): + def test_concurrent_add_or_renew_pool_no_double_replace(self): + """Reproduces https://github.com/scylladb/python-driver/issues/317.""" + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + session = Session.__new__(Session) + session.submit = lambda fn: Mock(result=lambda timeout=None: fn()) + session.keyspace = None + session._lock = threading.RLock() + session._pools = {} + session._profile_manager = Mock() + session._profile_manager.distance.return_value = HostDistance.LOCAL + + winner_pool = Mock() + created_pools = [] + + def fake_host_connection_init(pool, *_): + pool._keyspace = session.keyspace + pool.shutdown = Mock() + created_pools.append(pool) + log.info("Publishing competing pool while replacement pool is being created") + with session._lock: + session._pools[host] = winner_pool + + with patch.object(HostConnection, '__init__', fake_host_connection_init): + result = session.add_or_renew_pool(host, is_host_addition=True).result() + + assert result is True + assert session._pools[host] is winner_pool + created_pools[0].shutdown.assert_called_once() + winner_pool.shutdown.assert_not_called() + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self):