diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 5e7a68bc1c..64b7fa2de7 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -2615,6 +2615,12 @@ def __init__(self, cluster, hosts, keyspace=None):
         self._lock = RLock()
         self._pools = {}
+        # Tracks in-flight pool creation futures keyed by host, guarded by
+        # _lock. Used by add_or_renew_pool to detect and reuse concurrent
+        # creations so that update_created_pools does not schedule a duplicate
+        # run_add_or_renew_pool for a host whose pool creation is already
+        # in-flight (scylladb/python-driver#317).
+        self._pending_pool_futures = {}
         self._profile_manager = cluster.profile_manager
         self._metrics = cluster.metrics
         self._request_init_callbacks = []
@@ -3240,12 +3246,27 @@ def add_or_renew_pool(self, host, is_host_addition):
         if distance == HostDistance.IGNORED:
             return None
 
+        # Mutable one-element list so the outer code can upgrade the flag
+        # after the closure has been submitted but before it reads it. This
+        # fixes the coalescing race where an in-flight future created with
+        # is_host_addition=False is reused by a later on_add() call that
+        # needs is_host_addition=True: the closure then passes the wrong flag
+        # to signal_connection_failure(), causing _HostReconnectionHandler to
+        # call on_up() instead of on_add() on reconnect (scylladb/python-driver#317).
+        is_host_addition_cell = [is_host_addition]
+
+        # Unique token for this submission. The closure checks it before
+        # installing its pool so that a stale task (whose entry was replaced by
+        # remove_pool + a fresh add_or_renew_pool) discards its pool rather
+        # than overwriting the freshly-started one (scylladb/python-driver#317).
+        creation_id = object()
+
         def run_add_or_renew_pool():
             try:
                 new_pool = HostConnection(host, distance, self)
             except AuthenticationFailed as auth_exc:
                 conn_exc = ConnectionException(str(auth_exc), endpoint=host)
-                self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
+                self.cluster.signal_connection_failure(host, conn_exc, is_host_addition_cell[0])
                 return False
             except Exception as conn_exc:
                 log.warning("Failed to create connection pool for new host %s:",
@@ -3253,10 +3274,9 @@ def run_add_or_renew_pool():
                 # the host itself will still be marked down, so we need to pass
                 # a special flag to make sure the reconnector is created
                 self.cluster.signal_connection_failure(
-                    host, conn_exc, is_host_addition, expect_host_to_be_down=True)
+                    host, conn_exc, is_host_addition_cell[0], expect_host_to_be_down=True)
                 return False
 
-            previous = self._pools.get(host)
             with self._lock:
                 while new_pool._keyspace != self.keyspace:
                     self._lock.release()
@@ -3271,12 +3291,41 @@ def callback(pool, errors):
                     set_keyspace_event.wait(self.cluster.connect_timeout)
                     if not set_keyspace_event.is_set() or errors_returned:
                         log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned)
-                        self.cluster.on_down(host, is_host_addition)
+                        self.cluster.on_down(host, is_host_addition_cell[0])
                         new_pool.shutdown()
                         self._lock.acquire()
                         return False
                     self._lock.acquire()
-                self._pools[host] = new_pool
+
+                # Identity guard: if _pending_pool_futures no longer holds our
+                # creation_id it means remove_pool() ran (and possibly a fresh
+                # add_or_renew_pool was submitted) while we were connecting.
+                # Discard our pool so the fresher task can install its own
+                # (scylladb/python-driver#317).
+                entry = self._pending_pool_futures.get(host)
+                if entry is None or entry['creation_id'] is not creation_id:
+                    log.debug("Discarding stale connection pool for host %s "
+                              "(superseded by a newer creation)", host)
+                    discard_new_pool = True
+                else:
+                    # Read the current pool state inside the lock so the check
+                    # is atomic with the installation of our new pool.
+                    previous = self._pools.get(host)
+                    if previous is not None and not previous.is_shutdown:
+                        # A concurrent add_or_renew_pool already installed a
+                        # live pool for this host while we were connecting.
+                        # Discard ours to avoid replacing it and dropping
+                        # in-flight requests (scylladb/python-driver#317).
+                        log.debug("Discarding duplicate connection pool for host %s "
+                                  "(live pool already present)", host)
+                        discard_new_pool = True
+                    else:
+                        discard_new_pool = False
+                        self._pools[host] = new_pool
+
+            if discard_new_pool:
+                new_pool.shutdown()
+                return True
 
             log.debug("Added pool for host %s to session", host)
             if previous:
@@ -3284,10 +3333,69 @@ def callback(pool, errors):
 
             return True
 
-        return self.submit(run_add_or_renew_pool)
+        with self._lock:
+            if self.is_shutdown:
+                return None
+            # If there is already an in-flight pool creation for this host,
+            # return that future instead of scheduling a duplicate. This
+            # prevents update_created_pools from creating a second pool when
+            # the first one has not yet finished connecting
+            # (scylladb/python-driver#317).
+            entry = self._pending_pool_futures.get(host)
+            if entry is not None and entry['future'] is not None and not entry['future'].done():
+                if distance == entry['distance']:
+                    # Same distance: safe to coalesce. Upgrade is_host_addition
+                    # in the shared cell if the new caller needs the stricter
+                    # on_add() reconnect path (scylladb/python-driver#317).
+                    if is_host_addition:
+                        entry['is_host_addition_cell'][0] = True
+                    log.debug("Reusing in-flight pool creation for host %s", host)
+                    return entry['future']
+                # Distance changed: the in-flight HostConnection was constructed
+                # with the old distance (e.g. REMOTE with connect_to_remote_hosts
+                # =False => no connections). Submit a fresh task; the creation_id
+                # guard below ensures it wins over the stale one
+                # (scylladb/python-driver#317).
+                log.debug("Distance changed for host %s while pool creation was "
+                          "in-flight; submitting fresh creation", host)
+            # Store the entry BEFORE calling submit so the closure always
+            # finds a valid creation_id in _pending_pool_futures, even when
+            # the executor runs the task synchronously
+            # (scylladb/python-driver#317).
+            new_entry = {
+                'future': None,  # filled in immediately after submit returns
+                'creation_id': creation_id,
+                'distance': distance,
+                'is_host_addition_cell': is_host_addition_cell,
+            }
+            self._pending_pool_futures[host] = new_entry
+            future = self.submit(run_add_or_renew_pool)
+            if future is None:
+                # Session is shutting down; clean up the placeholder entry.
+                self._pending_pool_futures.pop(host, None)
+                return None
+            new_entry['future'] = future
+            # Remove the entry once the future finishes, regardless of how
+            # run_add_or_renew_pool exits (including unhandled exceptions).
+            # The callback acquires _lock and only clears the entry if it
+            # still holds *this* creation_id, so a concurrent remove_pool
+            # followed by a new add_or_renew_pool is not affected
+            # (scylladb/python-driver#317).
+            def _clear_pending(f, _host=host, _creation_id=creation_id):
+                with self._lock:
+                    e = self._pending_pool_futures.get(_host)
+                    if e is not None and e['creation_id'] is _creation_id:
+                        self._pending_pool_futures.pop(_host, None)
+            future.add_done_callback(_clear_pending)
+            return future
 
     def remove_pool(self, host):
-        pool = self._pools.pop(host, None)
+        with self._lock:
+            pool = self._pools.pop(host, None)
+            # Invalidate any in-flight pool creation for this host so that a
+            # subsequent update_created_pools call can schedule a fresh one if
+            # needed (scylladb/python-driver#317).
+            self._pending_pool_futures.pop(host, None)
         if pool:
             log.debug("Removed connection pool for %r", host)
             return self.submit(pool.shutdown)
diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py
index a4f0ebc4d3..803ef8329b 100644
--- a/tests/unit/test_cluster.py
+++ b/tests/unit/test_cluster.py
@@ -15,6 +15,8 @@
 import logging
 import socket
+from concurrent.futures import Future
+from threading import RLock
 from unittest.mock import patch, Mock
 import uuid
@@ -339,6 +341,224 @@ def test_set_keyspace_escapes_quotes(self, *_):
         assert query == 'USE simple_ks', (
             "Simple keyspace names should not be quoted, got: %r" % query)
 
+
+class PoolRenewalRaceTest(unittest.TestCase):
+    """
+    Regression tests for scylladb/python-driver#317: connection pool renewal
+    after concurrent node bootstraps causes double statement execution.
+    """
+
+    def _make_session(self):
+        """
+        Return a minimal Session with the attributes needed to exercise
+        add_or_renew_pool / remove_pool, without actually opening any network
+        connections.
+        """
+        s = object.__new__(Session)
+        s._lock = RLock()
+        s._pools = {}
+        s._pending_pool_futures = {}
+        s.is_shutdown = False
+        s.keyspace = None
+        s._profile_manager = Mock()
+        s._profile_manager.distance.return_value = HostDistance.LOCAL
+        s.cluster = Mock()
+        s.cluster.executor = Mock()
+        # submit() delegates to cluster.executor.submit; return a done future
+        # by default so callers that inspect the result don't hang.
+        done_future = Future()
+        done_future.set_result(True)
+        s.cluster.executor.submit.return_value = done_future
+        return s
+
+    def test_add_or_renew_pool_reuses_inflight_future(self):
+        """
+        When add_or_renew_pool is called for a host that already has an
+        in-flight pool creation (tracked in _pending_pool_futures), it must
+        return the existing future instead of submitting a duplicate task.
+        Without this fix, a concurrent call from update_created_pools would
+        create a second HostConnection pool, then shut down the first one
+        while requests were still in-flight, causing those requests to be
+        retried and executed twice on the server side.
+        """
+        s = self._make_session()
+        host = Mock()
+        host.is_up = True
+
+        # Simulate an in-flight pool creation by placing a pending (not-yet-
+        # resolved) future directly in _pending_pool_futures using the
+        # dict-based entry format introduced by the coalescing fix.
+        inflight_future = Future()  # not set_result yet → still in-flight
+        s._pending_pool_futures[host] = {
+            'future': inflight_future,
+            'creation_id': object(),
+            'distance': HostDistance.LOCAL,
+            'is_host_addition_cell': [False],
+        }
+
+        returned = s.add_or_renew_pool(host, is_host_addition=False)
+
+        # The call must reuse the existing in-flight future, not submit a new one.
+        assert returned is inflight_future, (
+            "add_or_renew_pool should return the existing in-flight future, "
+            "not create a duplicate pool creation task"
+        )
+        s.cluster.executor.submit.assert_not_called()
+
+    def test_add_or_renew_pool_discards_duplicate_when_live_pool_exists(self):
+        """
+        Defense-in-depth for scylladb/python-driver#317.
+
+        When run_add_or_renew_pool finishes creating a new pool but finds that
+        a live pool has already been installed for the host by a concurrent
+        creation, the new pool must be discarded (shut down) rather than
+        replacing the live one. Replacing a live pool would close it while
+        requests are still in-flight, causing server-side double execution.
+
+        This test exercises the real production code path by stubbing
+        HostConnection and running the submitted callable synchronously.
+        """
+        s = self._make_session()
+        host = Mock()
+        host.is_up = True
+
+        # Pre-install a live pool for this host to simulate the state left by
+        # a concurrent add_or_renew_pool that finished first.
+        live_pool = Mock()
+        live_pool.is_shutdown = False
+        s._pools[host] = live_pool
+
+        # Make the executor run the submitted callable synchronously so the
+        # test does not need threads.
+        def sync_submit(fn, *args, **kwargs):
+            result = fn(*args, **kwargs)
+            f = Future()
+            f.set_result(result)
+            return f
+        s.cluster.executor.submit = sync_submit
+
+        # Stub HostConnection so no real TCP connection is opened.
+        # _keyspace must equal s.keyspace (None) so the keyspace-sync loop
+        # inside run_add_or_renew_pool is skipped.
+        stub_pool = Mock()
+        stub_pool._keyspace = None
+
+        with patch('cassandra.cluster.HostConnection', return_value=stub_pool):
+            s.add_or_renew_pool(host, is_host_addition=False)
+
+        # The pre-installed live pool must not have been replaced.
+        assert s._pools[host] is live_pool, (
+            "add_or_renew_pool must not replace a live pool that is already "
+            "present when the new connection finishes"
+        )
+        # The newly created pool stub must have been shut down.
+        stub_pool.shutdown.assert_called_once()
+
+    def test_remove_pool_clears_pending_future(self):
+        """
+        remove_pool must clear _pending_pool_futures for the host so that a
+        subsequent update_created_pools call can schedule a fresh pool
+        creation if needed (instead of reusing a now-stale in-flight future
+        for a host that has been removed and re-added).
+        """
+        s = self._make_session()
+        host = Mock()
+
+        stale_future = Future()
+        s._pending_pool_futures[host] = {
+            'future': stale_future,
+            'creation_id': object(),
+            'distance': HostDistance.LOCAL,
+            'is_host_addition_cell': [False],
+        }
+
+        pool = Mock()
+        s._pools[host] = pool
+
+        s.remove_pool(host)
+
+        assert host not in s._pending_pool_futures, (
+            "remove_pool must clear _pending_pool_futures so the next "
+            "add_or_renew_pool call submits a fresh task"
+        )
+
+    def test_done_callback_clears_pending_future(self):
+        """
+        The done-callback registered by add_or_renew_pool must remove the host
+        entry from _pending_pool_futures once the future completes, so that
+        update_created_pools can schedule a fresh creation on the next call
+        rather than treating a stale done future as in-flight.
+        """
+        s = self._make_session()
+        host = Mock()
+        host.is_up = True
+
+        returned = s.add_or_renew_pool(host, is_host_addition=False)
+        assert returned is not None
+
+        # The future submitted by add_or_renew_pool is already done (our mock
+        # executor returns a pre-resolved future), so the done-callback has
+        # already fired.
+        assert host not in s._pending_pool_futures, (
+            "done-callback should have cleared _pending_pool_futures once "
+            "the future completed"
+        )
+
+    def test_done_callback_does_not_clear_newer_future(self):
+        """
+        The done-callback must only clear _pending_pool_futures[host] if the
+        entry still points at the *same* future it was registered on. If a
+        newer future has been installed in the meantime (e.g. after remove_pool
+        + add_or_renew_pool), the callback must leave the new entry alone.
+        """
+        s = self._make_session()
+        host = Mock()
+        host.is_up = True
+
+        # Place an entry manually and register the callback as the real code
+        # would, but keep the future pending so the callback has not fired yet.
+        old_future = Future()
+        new_future = Future()
+        old_creation_id = object()
+        new_creation_id = object()
+
+        with s._lock:
+            s._pending_pool_futures[host] = {
+                'future': old_future,
+                'creation_id': old_creation_id,
+                'distance': HostDistance.LOCAL,
+                'is_host_addition_cell': [False],
+            }
+
+        def _clear_pending(f, _host=host, _creation_id=old_creation_id):
+            with s._lock:
+                e = s._pending_pool_futures.get(_host)
+                if e is not None and e['creation_id'] is _creation_id:
+                    s._pending_pool_futures.pop(_host, None)
+
+        old_future.add_done_callback(_clear_pending)
+
+        # Simulate remove_pool + a new add_or_renew_pool: replace with a newer
+        # entry (new creation_id) before old_future completes.
+        with s._lock:
+            s._pending_pool_futures[host] = {
+                'future': new_future,
+                'creation_id': new_creation_id,
+                'distance': HostDistance.LOCAL,
+                'is_host_addition_cell': [False],
+            }
+
+        # Now complete the old future; its callback must not evict the new entry.
+        old_future.set_result(True)
+
+        with s._lock:
+            entry = s._pending_pool_futures.get(host)
+            assert entry is not None and entry['future'] is new_future, (
+                "done-callback of an old future must not remove a newer "
+                "pending entry from _pending_pool_futures"
+            )
+
+
 class ProtocolVersionTests(unittest.TestCase):
 
     def test_protocol_downgrade_test(self):