Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,9 @@ def add_or_renew_pool(self, host, is_host_addition):
return None

def run_add_or_renew_pool():
with self._lock:
previous = self._pools.get(host)

try:
new_pool = HostConnection(host, distance, self)
except AuthenticationFailed as auth_exc:
Expand All @@ -3256,7 +3259,6 @@ def run_add_or_renew_pool():
host, conn_exc, is_host_addition, expect_host_to_be_down=True)
return False

previous = self._pools.get(host)
with self._lock:
while new_pool._keyspace != self.keyspace:
self._lock.release()
Expand All @@ -3276,7 +3278,20 @@ def callback(pool, errors):
self._lock.acquire()
return False
self._lock.acquire()
self._pools[host] = new_pool

pool_unchanged = self._pools.get(host) is previous
if not pool_unchanged:
# Another concurrent add_or_renew_pool changed this host
# while we were creating ours. Don't replace the existing
# pool because doing so would kill in-flight queries.
log.debug("Pool for host %s was already replaced by another "
"thread, discarding new pool", host)
else:
self._pools[host] = new_pool

if not pool_unchanged:
new_pool.shutdown()
return True

log.debug("Added pool for host %s to session", host)
if previous:
Expand All @@ -3287,7 +3302,8 @@ def callback(pool, errors):
return self.submit(run_add_or_renew_pool)

def remove_pool(self, host):
pool = self._pools.pop(host, None)
with self._lock:
pool = self._pools.pop(host, None)
if pool:
log.debug("Removed connection pool for %r", host)
return self.submit(pool.shutdown)
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import socket
import threading

from unittest.mock import patch, Mock
import uuid
Expand All @@ -23,7 +24,7 @@
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
from cassandra.pool import Host
from cassandra.pool import Host, HostConnection
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
from tests.unit.utils import mock_session_pools
Expand Down Expand Up @@ -339,6 +340,39 @@ def test_set_keyspace_escapes_quotes(self, *_):
assert query == 'USE simple_ks', (
"Simple keyspace names should not be quoted, got: %r" % query)


class SessionPoolRaceTest(unittest.TestCase):
def test_concurrent_add_or_renew_pool_no_double_replace(self):
"""Reproduces https://github.com/scylladb/python-driver/issues/317."""
host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())

session = Session.__new__(Session)
session.submit = lambda fn: Mock(result=lambda timeout=None: fn())
session.keyspace = None
session._lock = threading.RLock()
session._pools = {}
session._profile_manager = Mock()
session._profile_manager.distance.return_value = HostDistance.LOCAL

winner_pool = Mock()
created_pools = []

def fake_host_connection_init(pool, *_):
pool._keyspace = session.keyspace
pool.shutdown = Mock()
created_pools.append(pool)
log.info("Publishing competing pool while replacement pool is being created")
with session._lock:
session._pools[host] = winner_pool

with patch.object(HostConnection, '__init__', fake_host_connection_init):
result = session.add_or_renew_pool(host, is_host_addition=True).result()

assert result is True
assert session._pools[host] is winner_pool
created_pools[0].shutdown.assert_called_once()
winner_pool.shutdown.assert_not_called()

class ProtocolVersionTests(unittest.TestCase):

def test_protocol_downgrade_test(self):
Expand Down
Loading