From ccf262c4d28e100f6a0df8666a347c31e21dcb6a Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sat, 2 May 2026 09:14:40 -0400 Subject: [PATCH 01/29] cluster: introduce event fence state --- cassandra/cluster.py | 99 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..f7147be593 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -251,6 +251,105 @@ def _discard_cluster_shutdown(cluster): _clusters_for_shutdown.discard(cluster) +class _EventFenceState(object): + """ + Tracks a monotonically increasing epoch and named in-flight events. + + The caller owns any domain-specific locking around this state. The epoch + represents the latest observed event for a key, while named events record + which epoch currently owns a deferred side effect. + """ + + __slots__ = ("epoch", "_events") + + def __init__(self): + self.epoch = 0 + self._events = {} + + def advance(self): + self.epoch += 1 + return self.epoch + + def get_event(self, event): + return self._events.get(event) + + def set_event(self, event, epoch=None): + if epoch is None: + epoch = self.epoch + self._events[event] = epoch + return epoch + + def clear_event(self, event, epoch=None): + if epoch is not None and self._events.get(event) != epoch: + return False + self._events.pop(event, None) + return True + + def event_is_current(self, event, epoch): + return self._events.get(event) == epoch and self.epoch == epoch + + def _set_or_clear_event(self, event, epoch): + if epoch is None: + self.clear_event(event) + else: + self.set_event(event, epoch) + + +class _IdentityWeakKeyDictionary(object): + """ + Weak mapping that uses object identity instead of ``__hash__``/``__eq__``. + Host hashes are endpoint-based, and endpoints can change during ring refresh. + """ + + def __init__(self): + self._items = {} + + def get(self, key, default=None): + key_id = id(key) + item = self._items.get(key_id) + if item is None: + return default + + key_ref, value = item + if key_ref() is key: + return value + + self._items.pop(key_id, None) + return default + + def __setitem__(self, key, value): + key_id = id(key) + self_ref = weakref.ref(self) + + def remove(ref): + self = self_ref() + if self is not None: + item = self._items.get(key_id) + if item is not None and item[0] is ref: + self._items.pop(key_id, None) + + self._items[key_id] = (weakref.ref(key, remove), value) + + +class _EventFenceMap(object): + """ + Identity-keyed store for per-object event fence state. + """ + + def __init__(self, state_factory=_EventFenceState): + self._lock = Lock() + self._states = _IdentityWeakKeyDictionary() + self._state_factory = state_factory + + def get_state(self, key): + with self._lock: + state = self._states.get(key) + if state is None: + state = self._state_factory() + self._states[key] = state + return state + + def _shutdown_clusters(): clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" for cluster in clusters: From c3ad1ec92d68c2ec5ef25f1da9657c4dbd01f785 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sat, 2 May 2026 09:14:52 -0400 Subject: [PATCH 02/29] cluster: replay up events after down handling --- cassandra/cluster.py | 309 +++++++++++-- cassandra/pool.py | 4 - tests/unit/test_cluster.py | 885 ++++++++++++++++++++++++++++++++++++- 3 files changed, 1165 insertions(+), 33 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index f7147be593..526f971f8c 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -234,6 +234,7 @@ def new_f(self, *args, **kwargs): try: future = self.executor.submit(f, self, *args, **kwargs) future.add_done_callback(_future_completed) + return future except Exception: log.exception("Failed to submit task to executor") @@ -295,6 +296,38 @@ def _set_or_clear_event(self, event, epoch): self.set_event(event, epoch) +class _HostLivenessState(_EventFenceState): + _UP = "up" + _DOWN = "down" + _PENDING_UP = "pending_up" + + __slots__ = () + + @property + def up_epoch(self): + return self.get_event(self._UP) + + @up_epoch.setter + def up_epoch(self, epoch): + self._set_or_clear_event(self._UP, epoch) + + @property + def down_epoch(self): + return self.get_event(self._DOWN) + + @down_epoch.setter + def down_epoch(self, epoch): + self._set_or_clear_event(self._DOWN, epoch) + + @property + def pending_up_epoch(self): + return self.get_event(self._PENDING_UP) + + @pending_up_epoch.setter + def pending_up_epoch(self, epoch): + self._set_or_clear_event(self._PENDING_UP, epoch) + + class _IdentityWeakKeyDictionary(object): """ Weak mapping that uses object identity instead of ``__hash__``/``__eq__``. @@ -1588,6 +1621,7 @@ def __init__(self, self.sessions = WeakSet() self.metadata = Metadata() self.control_connection = None + self._host_liveness = _EventFenceMap(_HostLivenessState) self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() @@ -1965,11 +1999,86 @@ def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): - session.remove_pool(host) + session.remove_pool(host, expected_host=host) self._start_reconnector(host, is_host_addition=False) - def _on_up_future_completed(self, host, futures, results, lock, finished_future): + def _get_host_liveness_state(self, host): + try: + fences = self._host_liveness + except AttributeError: + fences = self._host_liveness = _EventFenceMap(_HostLivenessState) + return fences.get_state(host) + + def _up_handling_was_superseded(self, host, up_epoch): + state = self._get_host_liveness_state(host) + return not state.event_is_current(_HostLivenessState._UP, up_epoch) + + def _up_handling_is_superseded(self, host, up_epoch): + with host.lock: + superseded = self._up_handling_was_superseded(host, up_epoch) + if superseded: + log.debug("Ignoring superseded up handling for node %s", host) + return superseded + + def _get_reconnector_for_current_up_handling(self, host, up_epoch): + with host.lock: + if self._up_handling_was_superseded(host, up_epoch): + log.debug("Ignoring superseded up handling for node %s", host) + return None, True + reconnector = host._reconnection_handler + host._reconnection_handler = None + return reconnector, False + + def _clear_up_handling(self, host, up_epoch=None): + state = self._get_host_liveness_state(host) + return state.clear_event(_HostLivenessState._UP, up_epoch) + + def _cleanup_superseded_up_handling(self, host): + for session in tuple(self.sessions): + session.remove_pool(host, expected_host=host) + + def _pop_pending_node_up_if_ready(self, host): + state = self._get_host_liveness_state(host) + if state.pending_up_epoch is None: + return None + if host.is_up: + state.pending_up_epoch = None + return None + if state.up_epoch is not None or state.down_epoch is not None: + return None + + pending_up_epoch = state.pending_up_epoch + # Leave the pending marker in place until on_up() reacquires host.lock so + # a newer down signal can still invalidate this replay. + return pending_up_epoch + + def _handle_pending_node_up(self, host, pending_up_epoch): + if pending_up_epoch is not None: + log.debug("Handling queued up status of node %s", host) + self._on_up(host, expected_epoch=pending_up_epoch) + + def _clear_down_handling(self, host, down_epoch=None): + state = self._get_host_liveness_state(host) + return state.clear_event(_HostLivenessState._DOWN, down_epoch) + + def _finish_superseded_up_handling(self, host, up_epoch): + self._cleanup_superseded_up_handling(host) + + pending_up_epoch = None + with host.lock: + if self._clear_up_handling(host, up_epoch): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + + self._handle_pending_node_up(host, pending_up_epoch) + + def _finish_up_if_superseded(self, host, up_epoch): + if self._up_handling_is_superseded(host, up_epoch): + self._finish_superseded_up_handling(host, up_epoch) + return True + return False + + def _on_up_future_completed(self, host, up_handling_revision, futures, results, lock, finished_future): with lock: futures.discard(finished_future) @@ -1983,6 +2092,9 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) try: # all futures have completed at this point + if self._finish_up_if_superseded(host, up_handling_revision): + return + for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) self._cleanup_failed_on_up_handling(host) @@ -1995,18 +2107,35 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners - host.set_up() + superseded = False + with host.lock: + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Ignoring superseded up handling for node %s", host) + superseded = True + else: + host.set_up() + self._clear_up_handling(host, up_handling_revision) + if superseded: + self._finish_superseded_up_handling(host, up_handling_revision) + return for listener in self.listeners: listener.on_up(host) finally: + pending_up_epoch = None with host.lock: - host._currently_handling_node_up = False + if self._clear_up_handling(host, up_handling_revision): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + self._handle_pending_node_up(host, pending_up_epoch) # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() + return def on_up(self, host): + return self._on_up(host) + + def _on_up(self, host, expected_epoch=None): """ Intended for internal use only. """ @@ -2015,15 +2144,35 @@ def on_up(self, host): log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: - if host._currently_handling_node_up: - log.debug("Another thread is already handling up status of node %s", host) + state = self._get_host_liveness_state(host) + if (expected_epoch is not None and + (state.epoch != expected_epoch or state.pending_up_epoch != expected_epoch)): + log.debug("Ignoring stale queued up handling for node %s", host) + return + + if state.down_epoch is not None: + log.debug("Down status is being handled for node %s; queueing up handling", host) + state.pending_up_epoch = state.epoch + return + + if state.up_epoch is not None: + up_handling_revision = state.up_epoch + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Superseded up handling is still finishing for node %s; " + "queueing up handling", host) + state.pending_up_epoch = state.epoch + else: + log.debug("Another thread is already handling up status of node %s", host) return if host.is_up: log.debug("Host %s was already marked up", host) + state.pending_up_epoch = None return - host._currently_handling_node_up = True + state.pending_up_epoch = None + up_handling_revision = state.epoch + state.up_epoch = up_handling_revision log.debug("Starting to handle up status of node %s", host) have_future = False @@ -2031,28 +2180,44 @@ def on_up(self, host): try: log.info("Host %s may be up; will prepare queries and open connection pool", host) - reconnector = host.get_and_set_reconnection_handler(None) + reconnector, superseded = self._get_reconnector_for_current_up_handling( + host, up_handling_revision) + if superseded: + self._finish_superseded_up_handling(host, up_handling_revision) + return futures if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() if self.profile_manager.distance(host) != HostDistance.IGNORED: self._prepare_all_queries(host) + if self._finish_up_if_superseded(host, up_handling_revision): + return futures log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): - session.remove_pool(host) + session.remove_pool(host, expected_host=host) + + if self._finish_up_if_superseded(host, up_handling_revision): + return futures log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) + if self._finish_up_if_superseded(host, up_handling_revision): + return futures + log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) + if self._finish_up_if_superseded(host, up_handling_revision): + return futures + log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] - callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback = partial(self._on_up_future_completed, host, up_handling_revision, + futures, futures_results, futures_lock) for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: @@ -2066,19 +2231,29 @@ def on_up(self, host): self._cleanup_failed_on_up_handling(host) + pending_up_epoch = None with host.lock: - host._currently_handling_node_up = False + if self._clear_up_handling(host, up_handling_revision): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + self._handle_pending_node_up(host, pending_up_epoch) raise else: if not have_future: + superseded = False with host.lock: - host.set_up() - host._currently_handling_node_up = False + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Ignoring superseded up handling for node %s", host) + superseded = True + else: + host.set_up() + self._clear_up_handling(host, up_handling_revision) + if superseded: + self._finish_superseded_up_handling(host, up_handling_revision) # for testing purposes return futures - def _start_reconnector(self, host, is_host_addition): + def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None): if self.profile_manager.distance(host) == HostDistance.IGNORED: return @@ -2094,7 +2269,15 @@ def _start_reconnector(self, host, is_host_addition): self.scheduler, schedule, host.get_and_set_reconnection_handler, new_handler=None) - old_reconnector = host.get_and_set_reconnection_handler(reconnector) + with host.lock: + if expected_down_epoch is not None: + state = self._get_host_liveness_state(host) + if state.down_epoch != expected_down_epoch: + log.debug("Not starting reconnector for host %s; down handling is no longer current", host) + return + + old_reconnector = host._reconnection_handler + host._reconnection_handler = reconnector if old_reconnector: log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() @@ -2103,16 +2286,45 @@ def _start_reconnector(self, host, is_host_addition): reconnector.start() @run_in_executor - def on_down_potentially_blocking(self, host, is_host_addition): - self.profile_manager.on_down(host) - self.control_connection.on_down(host) - for session in tuple(self.sessions): - session.on_down(host) + def on_down_potentially_blocking( + self, host: Host, is_host_addition: bool, + down_epoch: Optional[int] = None) -> Any: + pending_up_epoch = None + with host.lock: + state = self._get_host_liveness_state(host) + owns_reserved_down_handling = down_epoch is not None and state.down_epoch == down_epoch + if down_epoch is None: + if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + log.debug("Ignoring stale down handling for host %s", host) + return + down_epoch = state.epoch + state.down_epoch = down_epoch + elif not owns_reserved_down_handling: + log.debug("Ignoring stale down handling for host %s", host) + return - for listener in self.listeners: - listener.on_down(host) + try: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + for session in tuple(self.sessions): + session.on_down(host) + + for listener in self.listeners: + listener.on_down(host) + + with host.lock: + start_reconnector = self._get_host_liveness_state(host).down_epoch == down_epoch + if start_reconnector: + self._start_reconnector(host, is_host_addition, expected_down_epoch=down_epoch) + else: + log.debug("Not starting reconnector for removed host %s", host) + finally: + pending_up_epoch = None + with host.lock: + if self._clear_down_handling(host, down_epoch): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) - self._start_reconnector(host, is_host_addition) + self._handle_pending_node_up(host, pending_up_epoch) def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ @@ -2123,6 +2335,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): with host.lock: was_up = host.is_up + state = self._get_host_liveness_state(host) # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated @@ -2136,12 +2349,34 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if connected: return + if not expect_host_to_be_down: + if was_up is False: + if state.pending_up_epoch is not None: + state.advance() + state.pending_up_epoch = None + host.set_down() + return + + state.advance() + state.pending_up_epoch = None host.set_down() - if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + down_epoch = state.epoch + if state.down_epoch is not None: + return + if (host.is_currently_reconnecting() and + state.up_epoch is None): return + state.down_epoch = down_epoch log.warning("Host %s has been marked down", host) - self.on_down_potentially_blocking(host, is_host_addition) + future = self.on_down_potentially_blocking( + host, is_host_addition, down_epoch) + if future is None: + pending_up_epoch = None + with host.lock: + if self._clear_down_handling(host, down_epoch): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + self._handle_pending_node_up(host, pending_up_epoch) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2218,7 +2453,12 @@ def on_remove(self, host): return log.debug("[cluster] Removing host %s", host) - host.set_down() + with host.lock: + state = self._get_host_liveness_state(host) + state.advance() + state.pending_up_epoch = None + state.down_epoch = None + host.set_down() self.profile_manager.on_remove(host) for session in tuple(self.sessions): session.on_remove(host) @@ -2230,9 +2470,16 @@ def on_remove(self, host): if reconnection_handler: reconnection_handler.cancel() + @staticmethod + def _is_authentication_failure(connection_exc): + return (isinstance(connection_exc, AuthenticationFailed) or + isinstance(getattr(connection_exc, "__cause__", None), AuthenticationFailed)) + def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: + if host.is_up is None and self._is_authentication_failure(connection_exc): + return is_down self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down @@ -3344,6 +3591,7 @@ def run_add_or_renew_pool(): new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc.__cause__ = auth_exc self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: @@ -3385,8 +3633,13 @@ def callback(pool, errors): return self.submit(run_add_or_renew_pool) - def remove_pool(self, host): - pool = self._pools.pop(host, None) + def remove_pool(self, host, expected_host=None): + with self._lock: + pool = self._pools.get(host) + if expected_host is not None and pool is not None and pool.host is not expected_host: + return None + if pool is not None: + self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..fe600a1ad7 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -163,8 +163,6 @@ class Host(object): _reconnection_handler = None lock = None - _currently_handling_node_up = False - sharding_info = None def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): @@ -920,5 +918,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..f1ebcfe1d2 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -16,13 +16,16 @@ import logging import socket -from unittest.mock import patch, Mock +from concurrent.futures import Future +from threading import Lock +from unittest.mock import patch, Mock, ANY import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import ConnectionException, DefaultEndPoint from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -33,6 +36,37 @@ log = logging.getLogger(__name__) + +class _ImmediateExecutor(object): + + def submit(self, fn, *args, **kwargs): + future = Future() + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + +class _QueuedExecutor(object): + + def __init__(self): + self.submissions = [] + + def submit(self, fn, *args, **kwargs): + future = Future() + self.submissions.append((future, fn, args, kwargs)) + return future + + def run_next(self): + future, fn, args, kwargs = self.submissions.pop(0) + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -246,6 +280,855 @@ def test_event_delay_timing(self, *_): sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" +class HostStateRaceTest(unittest.TestCase): + + @staticmethod + def _make_host(): + return Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + @staticmethod + def _make_cluster(session=None, listener=None): + cluster = Cluster.__new__(Cluster) + cluster.is_shutdown = False + cluster.profile_manager = Mock() + cluster.control_connection = Mock() + cluster.sessions = set([session] if session else []) + cluster._listeners = set([listener] if listener else []) + cluster._listener_lock = Lock() + cluster.executor = _ImmediateExecutor() + cluster._start_reconnector = Mock() + cluster._discount_down_events = False + return cluster + + @staticmethod + def _make_session_with_pool(host, pool): + session = Session.__new__(Session) + session._lock = Lock() + session._pools = {host: pool} + session.submit = _ImmediateExecutor().submit + return session + + @staticmethod + def _state(cluster, host): + return cluster._get_host_liveness_state(host) + + @classmethod + def _reserve_down_handling(cls, cluster, host): + with host.lock: + state = cls._state(cluster, host) + state.epoch += 1 + state.pending_up_epoch = None + host.set_down() + state.down_epoch = state.epoch + return state.down_epoch + + def test_stale_down_handling_is_ignored_after_host_is_up(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + with host.lock: + state = self._state(cluster, host) + state.down_epoch = None + state.epoch += 1 + host.set_up() + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_unreserved_down_handling_is_ignored_during_host_up_handling(self): + session = Mock() + cluster = self._make_cluster(session=session) + host = self._make_host() + host.set_down() + self._state(cluster, host).up_epoch = 0 + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False) + + session.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_noop_down_during_up_handling_does_not_supersede_up(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + up_epoch = state.up_epoch + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert state.epoch == up_epoch + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + assert host.is_up + assert state.up_epoch is None + + def test_newer_forced_down_during_up_handling_is_preserved(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + assert session.remove_pool.call_count == 1 + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert state.epoch > first_up_epoch + assert state.up_epoch == first_up_epoch + assert not host.is_up + + pool_future.set_result(True) + + listener.on_up.assert_not_called() + assert session.remove_pool.call_count == 2 + assert not host.is_up + assert state.up_epoch is None + + def test_stale_failed_up_callback_does_not_cleanup_newer_down(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + pool_future.set_exception(RuntimeError("pool failed after newer down")) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + listener.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + + def test_forced_down_during_up_handling_is_not_hidden_by_reconnector(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_down() + old_reconnector = Mock() + host._reconnection_handler = old_reconnector + original_get_reconnector = Cluster._get_reconnector_for_current_up_handling + + def force_down_before_reconnector_is_cleared(h, up_epoch): + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + return original_get_reconnector(cluster, h, up_epoch) + + cluster._get_reconnector_for_current_up_handling = Mock( + side_effect=force_down_before_reconnector_is_cleared) + + Cluster.on_up(cluster, host) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + old_reconnector.cancel.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_newer_down_before_up_side_effects_suppresses_stale_up(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + original_superseded_check = Cluster._up_handling_was_superseded + checked = [] + + def force_down_before_first_superseded_check(h, up_epoch): + if not checked: + checked.append(True) + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + return original_superseded_check(cluster, h, up_epoch) + + cluster._up_handling_was_superseded = Mock( + side_effect=force_down_before_first_superseded_check) + + Cluster.on_up(cluster, host) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_up_during_down_superseding_in_flight_up_is_replayed(self): + first_pool_future = Future() + second_pool_future = Future() + session = Mock() + session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + assert state.epoch > first_up_epoch + assert state.up_epoch == first_up_epoch + assert state.pending_up_epoch == state.epoch + + first_pool_future.set_result(True) + + assert session.add_or_renew_pool.call_count == 2 + assert state.up_epoch == state.epoch + assert state.pending_up_epoch is None + listener.on_up.assert_not_called() + + second_pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + assert host.is_up + assert state.up_epoch is None + + def test_superseded_up_cleanup_precedes_replayed_up_pool_creation(self): + first_pool_future = Future() + second_pool_future = Future() + session = Mock() + session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cleanup_calls = [] + + def signal_up_during_first_cleanup(h, **kwargs): + if cleanup_calls: + return None + cleanup_calls.append(h) + Cluster.on_up(cluster, h) + assert session.add_or_renew_pool.call_count == 1 + assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch + return None + + session.remove_pool.side_effect = signal_up_during_first_cleanup + + first_pool_future.set_result(True) + + assert cleanup_calls == [host] + assert session.add_or_renew_pool.call_count == 2 + assert self._state(cluster, host).up_epoch == self._state(cluster, host).epoch + assert self._state(cluster, host).pending_up_epoch is None + + second_pool_future.set_result(True) + + assert host.is_up + assert self._state(cluster, host).up_epoch is None + + def test_sync_up_failure_replays_queued_up(self): + session = Mock() + session.add_or_renew_pool.return_value = None + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + fail_first_on_up = [True] + + def queue_up_then_fail(h): + if not fail_first_on_up[0]: + return + fail_first_on_up[0] = False + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + Cluster.on_up(cluster, h) + state = self._state(cluster, h) + assert state.pending_up_epoch == state.epoch + raise RuntimeError("up failed") + + cluster.profile_manager.on_up.side_effect = queue_up_then_fail + + with pytest.raises(RuntimeError): + Cluster.on_up(cluster, host) + + assert cluster.profile_manager.on_up.call_count == 2 + cluster.control_connection.on_up.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with( + host, is_host_addition=False) + assert host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_old_up_callback_does_not_clear_replayed_up_handling(self): + first_pool_future = Future() + second_pool_future = Future() + session = Mock() + session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + listener.on_up.side_effect = lambda h: Cluster.on_down( + cluster, h, is_host_addition=False) + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + + first_pool_future.set_result(True) + + assert session.add_or_renew_pool.call_count == 2 + assert not host.is_up + assert state.up_epoch != first_up_epoch + + listener.on_up.side_effect = None + second_pool_future.set_result(True) + + assert listener.on_up.call_count == 2 + assert host.is_up + assert state.up_epoch is None + + def test_superseded_up_cleanup_preserves_replacement_host_pool(self): + stale_host = self._make_host() + replacement_host = self._make_host() + replacement_pool = Mock(host=replacement_host) + session = self._make_session_with_pool(replacement_host, replacement_pool) + cluster = self._make_cluster(session=session) + + assert stale_host == replacement_host + assert stale_host.host_id != replacement_host.host_id + + cluster._cleanup_superseded_up_handling(stale_host) + + assert session._pools.get(replacement_host) is replacement_pool + replacement_pool.shutdown.assert_not_called() + + def test_superseded_up_cleanup_removes_matching_host_pool(self): + host = self._make_host() + pool = Mock(host=host) + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + + cluster._cleanup_superseded_up_handling(host) + + assert session._pools == {} + pool.shutdown.assert_called_once_with() + + def test_down_during_up_listener_is_handled(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + listener.on_up.side_effect = lambda h: Cluster.on_down( + cluster, h, is_host_addition=False) + + Cluster.on_up(cluster, host) + assert self._state(cluster, host).up_epoch is not None + + pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_current_down_handling_still_removes_pools_and_reconnects(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert self._state(cluster, host).down_epoch is None + + def test_remove_during_down_listener_does_not_start_reconnector(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.metadata = Mock() + cluster.metadata.remove_host.return_value = True + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + listener.on_down.side_effect = lambda h: Cluster.remove_host(cluster, h) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.metadata.remove_host.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + listener.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_queued_down_handling_after_remove_does_not_start_reconnector(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.metadata = Mock() + cluster.metadata.remove_host.return_value = True + host = self._make_host() + host.set_up() + + Cluster.on_down(cluster, host, is_host_addition=False) + assert len(executor.submissions) == 1 + + Cluster.remove_host(cluster, host) + executor.run_next() + + cluster.metadata.remove_host.assert_called_once_with(host) + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_start_reconnector_rechecks_down_epoch_before_installing_handler(self): + cluster = self._make_cluster() + cluster.reconnection_policy = Mock() + cluster.reconnection_policy.new_schedule.return_value = iter([0]) + cluster.scheduler = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + state = self._state(cluster, host) + state.down_epoch = 1 + + def clear_down_epoch(h): + with h.lock: + self._state(cluster, h).down_epoch = None + return Mock() + + cluster._make_connection_factory = Mock(side_effect=clear_down_epoch) + Cluster._start_reconnector( + cluster, host, is_host_addition=False, expected_down_epoch=1) + + assert host._reconnection_handler is None + cluster.scheduler.schedule.assert_not_called() + + def test_on_up_queues_after_down_is_submitted_before_worker_runs(self): + executor = _QueuedExecutor() + session = Mock() + session.add_or_renew_pool.return_value = None + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + Cluster.on_down(cluster, host, is_host_addition=False) + state = self._state(cluster, host) + + assert len(executor.submissions) == 1 + assert state.down_epoch == state.epoch + + Cluster.on_up(cluster, host) + + assert state.pending_up_epoch == state.epoch + assert state.up_epoch is None + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + + executor.run_next() + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): + executor = _QueuedExecutor() + session = Mock() + session.add_or_renew_pool.return_value = None + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + state = self._state(cluster, host) + + host.endpoint = DefaultEndPoint("127.0.0.2") + + assert self._state(cluster, host) is state + + Cluster.on_up(cluster, host) + + assert state.down_epoch == state.epoch + assert state.pending_up_epoch == state.epoch + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + + executor.run_next() + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_up_signal_waits_until_submitted_down_handling_finishes(self): + executor = _QueuedExecutor() + events = [] + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + cluster.profile_manager.on_down.side_effect = lambda h: events.append("profile_down") + cluster.control_connection.on_down.side_effect = lambda h: events.append("control_down") + session.on_down.side_effect = lambda h: events.append("session_down") + listener.on_down.side_effect = lambda h: events.append("listener_down") + cluster._start_reconnector.side_effect = lambda h, is_host_addition, **kwargs: events.append("reconnector") + session.remove_pool.side_effect = lambda h, **kwargs: events.append("remove_pool") + cluster.profile_manager.on_up.side_effect = lambda h: events.append("profile_up") + cluster.control_connection.on_up.side_effect = lambda h: events.append("control_up") + session.add_or_renew_pool.side_effect = lambda h, is_host_addition: events.append("add_pool") + + Cluster.on_down(cluster, host, is_host_addition=False) + Cluster.on_up(cluster, host) + + assert events == [] + + executor.run_next() + + assert events == [ + "profile_down", + "control_down", + "session_down", + "listener_down", + "reconnector", + "remove_pool", + "profile_up", + "control_up", + "add_pool", + ] + assert host.is_up + + def test_on_up_queues_when_down_handling_is_active(self): + cluster = self._make_cluster() + cluster._prepare_all_queries = Mock() + host = self._make_host() + host.set_down() + down_epoch = self._reserve_down_handling(cluster, host) + + Cluster.on_up(cluster, host) + + cluster._prepare_all_queries.assert_not_called() + state = self._state(cluster, host) + assert state.down_epoch == down_epoch + assert state.up_epoch is None + assert state.pending_up_epoch == state.epoch + + def test_on_up_during_down_handling_is_replayed_for_ignored_host(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert self._state(cluster, host).down_epoch is None + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_later_down_during_down_handling_invalidates_queued_up(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + def queue_up_then_down(h): + Cluster.on_up(cluster, h) + assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch + Cluster.on_down(cluster, h, is_host_addition=False) + + listener.on_down.side_effect = queue_up_then_down + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).down_epoch is None + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_remove_during_down_handling_invalidates_queued_up(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + def queue_up_then_remove(h): + Cluster.on_up(cluster, h) + assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch + Cluster.on_remove(cluster, h) + + listener.on_down.side_effect = queue_up_then_remove + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_stale_queued_up_replay_is_ignored_after_newer_down_event(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 1 + state.pending_up_epoch = 0 + + cluster._handle_pending_node_up(host, 0) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert state.up_epoch is None + + def test_stale_queued_up_replay_preserves_newer_pending_up(self): + cluster = self._make_cluster() + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.down_epoch = 2 + state.pending_up_epoch = 2 + + cluster._handle_pending_node_up(host, 1) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert state.pending_up_epoch == 2 + assert state.up_epoch is None + + def test_down_after_pending_up_pop_invalidates_replay(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.pending_up_epoch = 2 + + with host.lock: + pending_up_epoch = cluster._pop_pending_node_up_if_ready(host) + + assert pending_up_epoch == 2 + assert state.pending_up_epoch == 2 + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert state.epoch == 3 + assert state.pending_up_epoch is None + + cluster._handle_pending_node_up(host, pending_up_epoch) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + + def test_down_for_down_host_with_pending_up_only_invalidates_pending_up(self): + cluster = self._make_cluster() + cluster.executor = Mock() + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.down_epoch = 2 + state.pending_up_epoch = 2 + + Cluster.on_down(cluster, host, is_host_addition=False) + + cluster.executor.submit.assert_not_called() + assert state.down_epoch == 2 + assert state.pending_up_epoch is None + assert state.epoch == 3 + + def test_auth_failure_for_unknown_host_does_not_start_down_handling(self): + cluster = self._make_cluster() + host = self._make_host() + + is_down = cluster.signal_connection_failure( + host, AuthenticationFailed("bad credentials"), is_host_addition=False) + + assert is_down + assert host.is_up is None + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_wrapped_auth_failure_for_unknown_host_does_not_start_down_handling(self): + cluster = self._make_cluster() + host = self._make_host() + auth_exc = AuthenticationFailed("bad credentials") + conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc.__cause__ = auth_exc + + is_down = cluster.signal_connection_failure( + host, conn_exc, is_host_addition=False) + + assert is_down + assert host.is_up is None + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_real_down_for_unknown_host_marks_host_down(self): + cluster = self._make_cluster() + host = self._make_host() + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert host.is_up is False + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + + def test_expected_down_for_unknown_host_marks_host_down(self): + cluster = self._make_cluster() + host = self._make_host() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + assert host.is_up is False + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + + class SessionTest(unittest.TestCase): def setUp(self): if connection_class is None: From 72c4f4ccd8a2219663250cef758f207c02ea0a70 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sat, 2 May 2026 09:15:02 -0400 Subject: [PATCH 03/29] session: fence pool creation and stale endpoints --- cassandra/cluster.py | 514 +++++++++++++++++++++++------ cassandra/pool.py | 18 +- tests/unit/test_cluster.py | 645 ++++++++++++++++++++++++++++++++++++- 3 files changed, 1073 insertions(+), 104 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 526f971f8c..51651679a1 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -328,6 +328,25 @@ def pending_up_epoch(self, epoch): self._set_or_clear_event(self._PENDING_UP, epoch) +class _PoolCreationState(_EventFenceState): + _CREATE = "create" + + __slots__ = ("future", "endpoint") + + def __init__(self): + _EventFenceState.__init__(self) + self.future = None + self.endpoint = None + + @property + def creation_epoch(self): + return self.get_event(self._CREATE) + + @creation_epoch.setter + def creation_epoch(self, epoch): + self._set_or_clear_event(self._CREATE, epoch) + + class _IdentityWeakKeyDictionary(object): """ Weak mapping that uses object identity instead of ``__hash__``/``__eq__``. @@ -1995,13 +2014,27 @@ def _session_register_user_types(self, session): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) - def _cleanup_failed_on_up_handling(self, host): - self.profile_manager.on_down(host) - self.control_connection.on_down(host) + def _cleanup_failed_on_up_handling(self, host, start_reconnector=True, expected_endpoint=None): + endpoint_changed = False + if expected_endpoint is None: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + else: + with host.lock: + endpoint_changed = host.endpoint != expected_endpoint + if endpoint_changed: + log.debug("Not signalling down for stale up handling on node %s; endpoint changed from %s", + host, expected_endpoint) + else: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) for session in tuple(self.sessions): - session.remove_pool(host, expected_host=host) + session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) - self._start_reconnector(host, is_host_addition=False) + if start_reconnector and not endpoint_changed: + self._start_reconnector( + host, is_host_addition=False, expected_endpoint=expected_endpoint) def _get_host_liveness_state(self, host): try: @@ -2034,9 +2067,10 @@ def _clear_up_handling(self, host, up_epoch=None): state = self._get_host_liveness_state(host) return state.clear_event(_HostLivenessState._UP, up_epoch) - def _cleanup_superseded_up_handling(self, host): + def _cleanup_superseded_up_handling(self, host, expected_endpoint=None): for session in tuple(self.sessions): - session.remove_pool(host, expected_host=host) + session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) def _pop_pending_node_up_if_ready(self, host): state = self._get_host_liveness_state(host) @@ -2062,8 +2096,9 @@ def _clear_down_handling(self, host, down_epoch=None): state = self._get_host_liveness_state(host) return state.clear_event(_HostLivenessState._DOWN, down_epoch) - def _finish_superseded_up_handling(self, host, up_epoch): - self._cleanup_superseded_up_handling(host) + def _finish_superseded_up_handling(self, host, up_epoch, expected_endpoint=None): + self._cleanup_superseded_up_handling( + host, expected_endpoint=expected_endpoint) pending_up_epoch = None with host.lock: @@ -2072,13 +2107,15 @@ def _finish_superseded_up_handling(self, host, up_epoch): self._handle_pending_node_up(host, pending_up_epoch) - def _finish_up_if_superseded(self, host, up_epoch): + def _finish_up_if_superseded(self, host, up_epoch, expected_endpoint=None): if self._up_handling_is_superseded(host, up_epoch): - self._finish_superseded_up_handling(host, up_epoch) + self._finish_superseded_up_handling( + host, up_epoch, expected_endpoint=expected_endpoint) return True return False - def _on_up_future_completed(self, host, up_handling_revision, futures, results, lock, finished_future): + def _on_up_future_completed(self, host, up_handling_revision, up_handling_endpoint, + futures, results, lock, finished_future): with lock: futures.discard(finished_future) @@ -2092,17 +2129,29 @@ def _on_up_future_completed(self, host, up_handling_revision, futures, results, try: # all futures have completed at this point - if self._finish_up_if_superseded(host, up_handling_revision): + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): return for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) - self._cleanup_failed_on_up_handling(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return + self._cleanup_failed_on_up_handling( + host, expected_endpoint=up_handling_endpoint) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) - self._cleanup_failed_on_up_handling(host) + with host.lock: + start_reconnector = host.is_up is not None + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return + self._cleanup_failed_on_up_handling( + host, start_reconnector=start_reconnector, + expected_endpoint=up_handling_endpoint) return log.info("Connection pools established for node %s", host) @@ -2116,7 +2165,8 @@ def _on_up_future_completed(self, host, up_handling_revision, futures, results, host.set_up() self._clear_up_handling(host, up_handling_revision) if superseded: - self._finish_superseded_up_handling(host, up_handling_revision) + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) return for listener in self.listeners: listener.on_up(host) @@ -2172,6 +2222,7 @@ def _on_up(self, host, expected_epoch=None): state.pending_up_epoch = None up_handling_revision = state.epoch + up_handling_endpoint = host.endpoint state.up_epoch = up_handling_revision log.debug("Starting to handle up status of node %s", host) @@ -2183,7 +2234,8 @@ def _on_up(self, host, expected_epoch=None): reconnector, superseded = self._get_reconnector_for_current_up_handling( host, up_handling_revision) if superseded: - self._finish_superseded_up_handling(host, up_handling_revision) + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) return futures if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) @@ -2191,32 +2243,38 @@ def _on_up(self, host, expected_epoch=None): if self.profile_manager.distance(host) != HostDistance.IGNORED: self._prepare_all_queries(host) - if self._finish_up_if_superseded(host, up_handling_revision): + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): return futures log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): - session.remove_pool(host, expected_host=host) + session.remove_pool( + host, expected_host=host, expected_endpoint=up_handling_endpoint) - if self._finish_up_if_superseded(host, up_handling_revision): + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): return futures log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) - if self._finish_up_if_superseded(host, up_handling_revision): + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): return futures log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) - if self._finish_up_if_superseded(host, up_handling_revision): + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): return futures log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, up_handling_revision, + up_handling_endpoint, futures, futures_results, futures_lock) for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) @@ -2229,7 +2287,8 @@ def _on_up(self, host, expected_epoch=None): for future in futures: future.cancel() - self._cleanup_failed_on_up_handling(host) + self._cleanup_failed_on_up_handling( + host, expected_endpoint=up_handling_endpoint) pending_up_epoch = None with host.lock: @@ -2248,34 +2307,46 @@ def _on_up(self, host, expected_epoch=None): host.set_up() self._clear_up_handling(host, up_handling_revision) if superseded: - self._finish_superseded_up_handling(host, up_handling_revision) + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) # for testing purposes return futures - def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None): + def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None, + expected_endpoint=None): if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() - # in order to not hold references to this Cluster open and prevent - # proper shutdown when the program ends, we'll just make a closure - # of the current Cluster attributes to create new Connections with - conn_factory = self._make_connection_factory(host) + with host.lock: + if expected_endpoint is not None and host.endpoint != expected_endpoint: + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + return - reconnector = _HostReconnectionHandler( - host, conn_factory, is_host_addition, self.on_add, self.on_up, - self.scheduler, schedule, host.get_and_set_reconnection_handler, - new_handler=None) + # in order to not hold references to this Cluster open and prevent + # proper shutdown when the program ends, we'll just make a closure + # of the current Cluster attributes to create new Connections with + conn_factory = self._make_connection_factory(host) - with host.lock: if expected_down_epoch is not None: state = self._get_host_liveness_state(host) if state.down_epoch != expected_down_epoch: log.debug("Not starting reconnector for host %s; down handling is no longer current", host) return + if expected_endpoint is not None and host.endpoint != expected_endpoint: + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + return + + reconnector = _HostReconnectionHandler( + host, conn_factory, is_host_addition, self.on_add, self.on_up, + self.scheduler, schedule, host.get_and_set_reconnection_handler, + new_handler=None) + old_reconnector = host._reconnection_handler host._reconnection_handler = reconnector if old_reconnector: @@ -2288,45 +2359,88 @@ def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None): @run_in_executor def on_down_potentially_blocking( self, host: Host, is_host_addition: bool, - down_epoch: Optional[int] = None) -> Any: + down_epoch: Optional[int] = None, + expected_endpoint: Optional[EndPoint] = None) -> Any: pending_up_epoch = None - with host.lock: - state = self._get_host_liveness_state(host) - owns_reserved_down_handling = down_epoch is not None and state.down_epoch == down_epoch - if down_epoch is None: - if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + try: + with host.lock: + state = self._get_host_liveness_state(host) + owns_reserved_down_handling = down_epoch is not None and state.down_epoch == down_epoch + if down_epoch is None: + if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + log.debug("Ignoring stale down handling for host %s", host) + return + down_epoch = state.epoch + state.down_epoch = down_epoch + elif not owns_reserved_down_handling: log.debug("Ignoring stale down handling for host %s", host) return - down_epoch = state.epoch - state.down_epoch = down_epoch - elif not owns_reserved_down_handling: - log.debug("Ignoring stale down handling for host %s", host) - return + endpoint_matches = expected_endpoint is None or host.endpoint == expected_endpoint + + if endpoint_matches: + if expected_endpoint is None: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + else: + with host.lock: + if host.endpoint != expected_endpoint: + log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", + host, expected_endpoint) + endpoint_matches = False + else: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + else: + log.debug("Not signalling down for stale down handling on node %s; endpoint changed from %s", + host, expected_endpoint) - try: - self.profile_manager.on_down(host) - self.control_connection.on_down(host) for session in tuple(self.sessions): - session.on_down(host) + if expected_endpoint is None: + session.on_down(host) + else: + session.on_down(host, expected_endpoint=expected_endpoint) - for listener in self.listeners: - listener.on_down(host) + if endpoint_matches: + if expected_endpoint is None: + for listener in self.listeners: + listener.on_down(host) + else: + with host.lock: + if host.endpoint != expected_endpoint: + log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", + host, expected_endpoint) + endpoint_matches = False + else: + for listener in self.listeners: + listener.on_down(host) with host.lock: - start_reconnector = self._get_host_liveness_state(host).down_epoch == down_epoch + start_reconnector = (endpoint_matches and + self._get_host_liveness_state(host).down_epoch == down_epoch) + if (start_reconnector and expected_endpoint is not None and + host.endpoint != expected_endpoint): + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + start_reconnector = False if start_reconnector: - self._start_reconnector(host, is_host_addition, expected_down_epoch=down_epoch) + if expected_endpoint is None: + self._start_reconnector(host, is_host_addition, expected_down_epoch=down_epoch) + else: + self._start_reconnector( + host, is_host_addition, expected_down_epoch=down_epoch, + expected_endpoint=expected_endpoint) else: log.debug("Not starting reconnector for removed host %s", host) finally: pending_up_epoch = None with host.lock: - if self._clear_down_handling(host, down_epoch): + if down_epoch is not None and self._clear_down_handling(host, down_epoch): pending_up_epoch = self._pop_pending_node_up_if_ready(host) self._handle_pending_node_up(host, pending_up_epoch) - def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + def on_down(self, host, is_host_addition, expect_host_to_be_down=False, + expected_endpoint=None): """ Intended for internal use only. """ @@ -2334,6 +2448,11 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): return with host.lock: + if expected_endpoint is not None and host.endpoint != expected_endpoint: + log.debug("Ignoring stale down signal for host %s; endpoint changed from %s", + host, expected_endpoint) + return + was_up = host.is_up state = self._get_host_liveness_state(host) @@ -2363,14 +2482,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): down_epoch = state.epoch if state.down_epoch is not None: return - if (host.is_currently_reconnecting() and + if (not expect_host_to_be_down and + host.is_currently_reconnecting() and state.up_epoch is None): return state.down_epoch = down_epoch log.warning("Host %s has been marked down", host) future = self.on_down_potentially_blocking( - host, is_host_addition, down_epoch) + host, is_host_addition, down_epoch, expected_endpoint) if future is None: pending_up_epoch = None with host.lock: @@ -2475,12 +2595,22 @@ def _is_authentication_failure(connection_exc): return (isinstance(connection_exc, AuthenticationFailed) or isinstance(getattr(connection_exc, "__cause__", None), AuthenticationFailed)) - def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): - is_down = host.signal_connection_failure(connection_exc) + def signal_connection_failure(self, host, connection_exc, is_host_addition, + expect_host_to_be_down=False, expected_endpoint=None): + with host.lock: + if expected_endpoint is not None and host.endpoint != expected_endpoint: + log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", + host, expected_endpoint) + return False + + is_down = host.signal_connection_failure(connection_exc) + if is_down: + if host.is_up is None and self._is_authentication_failure(connection_exc): + return is_down if is_down: - if host.is_up is None and self._is_authentication_failure(connection_exc): - return is_down - self.on_down(host, is_host_addition, expect_host_to_be_down) + self.on_down( + host, is_host_addition, expect_host_to_be_down, + expected_endpoint=expected_endpoint) return is_down def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): @@ -2961,6 +3091,7 @@ def __init__(self, cluster, hosts, keyspace=None): self._lock = RLock() self._pools = {} + self._pool_creation_fences = _EventFenceMap(_PoolCreationState) self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics self._request_init_callbacks = [] @@ -3578,6 +3709,37 @@ def __del__(self): # when cluster.shutdown() is called explicitly. pass + def _get_pool_creation_state(self, host): + try: + fences = self._pool_creation_fences + except AttributeError: + fences = self._pool_creation_fences = _EventFenceMap(_PoolCreationState) + return fences.get_state(host) + + def _pool_creation_is_current(self, host, creation_epoch): + state = self._get_pool_creation_state(host) + return state.event_is_current(_PoolCreationState._CREATE, creation_epoch) + + def _clear_pool_creation(self, host, creation_epoch): + state = self._get_pool_creation_state(host) + if state.clear_event(_PoolCreationState._CREATE, creation_epoch): + state.future = None + state.endpoint = None + return True + return False + + def _invalidate_pool_creation(self, host, expected_endpoint=None): + state = self._get_pool_creation_state(host) + if state.creation_epoch is not None: + if expected_endpoint is not None and state.endpoint != expected_endpoint: + return False + state.advance() + state.creation_epoch = None + state.future = None + state.endpoint = None + return True + return False + def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. @@ -3586,26 +3748,57 @@ def add_or_renew_pool(self, host, is_host_addition): if distance == HostDistance.IGNORED: return None + creation_epoch = None + creation_endpoint = None + def run_add_or_renew_pool(): try: - new_pool = HostConnection(host, distance, self) + new_pool = HostConnection( + host, distance, self, endpoint=creation_endpoint) except AuthenticationFailed as auth_exc: - conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc = ConnectionException(str(auth_exc), endpoint=creation_endpoint) conn_exc.__cause__ = auth_exc - self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + with self._lock: + signal_failure = self._pool_creation_is_current(host, creation_epoch) + if signal_failure: + self._clear_pool_creation(host, creation_epoch) + if signal_failure: + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, + expected_endpoint=creation_endpoint) return False - except Exception as conn_exc: + except Exception as pool_exc: log.warning("Failed to create connection pool for new host %s:", - host, exc_info=conn_exc) + host, exc_info=pool_exc) + conn_exc = pool_exc + try: + conn_exc.endpoint = creation_endpoint + except AttributeError: + conn_exc = ConnectionException(str(pool_exc), endpoint=creation_endpoint) + conn_exc.__cause__ = pool_exc # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created - self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + with self._lock: + signal_failure = self._pool_creation_is_current(host, creation_epoch) + if signal_failure: + self._clear_pool_creation(host, creation_epoch) + if signal_failure: + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, expect_host_to_be_down=True, + expected_endpoint=creation_endpoint) return False + pool_endpoint = self._get_pool_endpoint(new_pool, creation_endpoint) + new_pool.endpoint = pool_endpoint - previous = self._pools.get(host) + previous = None + discard_pool = False + signal_down = False with self._lock: while new_pool._keyspace != self.keyspace: + if not self._pool_creation_is_current(host, creation_epoch): + discard_pool = True + break + self._lock.release() set_keyspace_event = Event() errors_returned = [] @@ -3616,14 +3809,34 @@ def callback(pool, errors): new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) + self._lock.acquire() + if not self._pool_creation_is_current(host, creation_epoch): + discard_pool = True + break if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) - self.cluster.on_down(host, is_host_addition) - new_pool.shutdown() - self._lock.acquire() - return False - self._lock.acquire() - self._pools[host] = new_pool + self._clear_pool_creation(host, creation_epoch) + signal_down = True + break + + if not discard_pool and not signal_down: + if self._pool_creation_is_current(host, creation_epoch): + previous = self._pools.get(host) + self._pools[host] = new_pool + self._clear_pool_creation(host, creation_epoch) + else: + discard_pool = True + + if signal_down: + self.cluster.on_down( + host, is_host_addition, expected_endpoint=creation_endpoint) + new_pool.shutdown() + return False + + if discard_pool: + log.debug("Discarding stale connection pool for host %s", host) + new_pool.shutdown() + return False log.debug("Added pool for host %s to session", host) if previous: @@ -3631,21 +3844,114 @@ def callback(pool, errors): return True - return self.submit(run_add_or_renew_pool) + with self._lock: + state = self._get_pool_creation_state(host) + if state.creation_epoch is not None: + return state.future - def remove_pool(self, host, expected_host=None): + creation_epoch = state.advance() + state.creation_epoch = creation_epoch + with host.lock: + creation_endpoint = host.endpoint + state.endpoint = creation_endpoint + future = self.submit(run_add_or_renew_pool) + if future is None: + self._clear_pool_creation(host, creation_epoch) + return None + if state.event_is_current(_PoolCreationState._CREATE, creation_epoch): + state.future = future + return future + + def remove_pool(self, host, expected_host=None, expected_endpoint=None): + removed_pools = [] with self._lock: - pool = self._pools.get(host) - if expected_host is not None and pool is not None and pool.host is not expected_host: + pool = self._get_pool_by_host_identity( + host, expected_host=expected_host, expected_endpoint=expected_endpoint) + remove_by_identity = pool is not None + if pool is None: + pool = self._pools.get(host) + if pool is not None and not self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint): + self._invalidate_pool_creation( + host, expected_endpoint=expected_endpoint) return None + self._invalidate_pool_creation( + host, expected_endpoint=expected_endpoint) if pool is not None: - self._pools.pop(host, None) - if pool: + if remove_by_identity: + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if (pool_host is host and self._pool_matches_expected( + host_pool, expected_host=expected_host, + expected_endpoint=expected_endpoint)): + removed_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + self._pools = retained_pools + else: + removed_pool = self._pools.pop(host, None) + if removed_pool is not None: + removed_pools.append(removed_pool) + if removed_pools: log.debug("Removed connection pool for %r", host) - return self.submit(pool.shutdown) + return self.submit(self._shutdown_removed_pools, removed_pools) else: return None + @staticmethod + def _shutdown_removed_pools(pools): + for pool in pools: + pool.shutdown() + + def _pool_matches_expected(self, pool, expected_host=None, expected_endpoint=None): + if expected_host is not None and pool.host is not expected_host: + return False + if expected_endpoint is not None: + pool_endpoint = getattr(pool, 'endpoint', None) + if pool_endpoint is None: + pool_endpoint = pool.host.endpoint + if pool_endpoint != expected_endpoint: + return False + return True + + @staticmethod + def _get_pool_endpoint(pool, default_endpoint): + endpoint = getattr(pool, 'endpoint', None) + connections = getattr(pool, '_connections', None) + if connections: + for connection in connections.values(): + if getattr(connection, 'original_endpoint', None) is None: + connection_endpoint = getattr(connection, 'endpoint', None) + if connection_endpoint is not None: + return connection_endpoint + elif endpoint is None: + endpoint = connection.original_endpoint + return endpoint if endpoint is not None else default_endpoint + + def _get_pool_by_host_identity(self, host, expected_host=None, expected_endpoint=None): + for pool_host, pool in self._pools.items(): + if (pool_host is host and self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): + return pool + return None + + def _reuse_or_invalidate_pool_creation(self, host, pool_creation_future): + with self._lock: + current_distance = self._profile_manager.distance(host) + pool_creation_state = self._get_pool_creation_state(host) + if (pool_creation_state.creation_epoch is not None and + pool_creation_state.future is pool_creation_future): + with host.lock: + endpoint_changed = host.endpoint != pool_creation_state.endpoint + if endpoint_changed: + self._invalidate_pool_creation( + host, expected_endpoint=pool_creation_state.endpoint) + if current_distance == HostDistance.IGNORED: + self._invalidate_pool_creation(host) + elif not endpoint_changed: + return pool_creation_future + return None + def update_created_pools(self): """ When the set of live nodes change, the loadbalancer will change its @@ -3661,14 +3967,27 @@ def update_created_pools(self): futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) - pool = self._pools.get(host) + with self._lock: + pool = self._pools.get(host) + pool_creation_state = self._get_pool_creation_state(host) + pool_creation_future = ( + pool_creation_state.future + if pool_creation_state.creation_epoch is not None else None) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. - if distance != HostDistance.IGNORED and host.is_up in (True, None): - future = self.add_or_renew_pool(host, False) + if distance != HostDistance.IGNORED: + if pool_creation_future is not None: + # on_up() keeps host.is_up False until this future succeeds. + future = self._reuse_or_invalidate_pool_creation( + host, pool_creation_future) + elif host.is_up in (True, None): + future = self.add_or_renew_pool(host, False) + elif pool_creation_future is not None: + future = self._reuse_or_invalidate_pool_creation( + host, pool_creation_future) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: @@ -3679,12 +3998,12 @@ def update_created_pools(self): futures.add(future) return futures - def on_down(self, host): + def on_down(self, host, expected_endpoint=None): """ Called by the parent Cluster instance when a node is marked down. Only intended for internal use. """ - future = self.remove_pool(host) + future = self.remove_pool(host, expected_endpoint=expected_endpoint) if future: future.add_done_callback(lambda f: self.update_created_pools()) @@ -4244,11 +4563,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) - - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) + with host.lock: + old_endpoint = host.endpoint + self._cluster.on_down( + host, is_host_addition=False, expect_host_to_be_down=True, + expected_endpoint=old_endpoint) + + with host.lock: + if host.endpoint != old_endpoint: + log.debug("[control connection] Not updating host ip from %s to %s for (%s); " + "endpoint changed to %s", + old_endpoint, endpoint, host_id, host.endpoint) + continue + host.endpoint = endpoint + self._cluster.metadata.update_host(host, old_endpoint) self._cluster.on_up(host) if host is None: diff --git a/cassandra/pool.py b/cassandra/pool.py index fe600a1ad7..bdc1d07ea1 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -377,6 +377,7 @@ class HostConnection(object): """ host = None + endpoint = None host_distance = None is_shutdown = False shutdown_on_error = False @@ -391,8 +392,9 @@ class HostConnection(object): tablets_routing_v1 = False - def __init__(self, host, host_distance, session): + def __init__(self, host, host_distance, session, endpoint=None): self.host = host + self.endpoint = endpoint if endpoint is not None else host.endpoint self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() @@ -426,7 +428,7 @@ def __init__(self, host, host_distance, session): return log.debug("Initializing connection for host %s", self.host) - first_connection = session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + first_connection = session.cluster.connection_factory(self.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.features.shard_id) self._connections[first_connection.features.shard_id] = first_connection self._keyspace = session.keyspace @@ -605,13 +607,13 @@ def _replace(self, connection): self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) else: - connection = self._session.cluster.connection_factory(self.host.endpoint, + connection = self._session.cluster.connection_factory(self.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) if self._keyspace: connection.set_keyspace_blocking(self._keyspace) self._connections[connection.features.shard_id] = connection except Exception: - log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) + log.warning("Failed reconnecting %s. Retrying." % (self.endpoint,)) self._session.submit(self._replace, connection) else: self._is_replacing = False @@ -681,10 +683,10 @@ def _get_shard_aware_endpoint(self): endpoint = None if self._session.cluster.ssl_options and self.host.sharding_info.shard_aware_port_ssl: - endpoint = copy.copy(self.host.endpoint) + endpoint = copy.copy(self.endpoint) endpoint._port = self.host.sharding_info.shard_aware_port_ssl elif self.host.sharding_info.shard_aware_port: - endpoint = copy.copy(self.host.endpoint) + endpoint = copy.copy(self.endpoint) endpoint._port = self.host.sharding_info.shard_aware_port return endpoint @@ -717,12 +719,12 @@ def _open_connection_to_missing_shard(self, shard_id): conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, shard_id=shard_id, total_shards=self.host.sharding_info.shards_count) - conn.original_endpoint = self.host.endpoint + conn.original_endpoint = self.endpoint except Exception as exc: log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) raise else: - conn = self._session.cluster.connection_factory(self.host.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) + conn = self._session.cluster.connection_factory(self.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) log.debug( "Received a connection %s for shard_id=%i on host %s", diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index f1ebcfe1d2..4965f8980e 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -17,7 +17,7 @@ import socket from concurrent.futures import Future -from threading import Lock +from threading import Lock, RLock from unittest.mock import patch, Mock, ANY import uuid @@ -58,8 +58,8 @@ def submit(self, fn, *args, **kwargs): self.submissions.append((future, fn, args, kwargs)) return future - def run_next(self): - future, fn, args, kwargs = self.submissions.pop(0) + def run_next(self, index=0): + future, fn, args, kwargs = self.submissions.pop(index) try: future.set_result(fn(*args, **kwargs)) except Exception as exc: @@ -280,6 +280,494 @@ def test_event_delay_timing(self, *_): sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" +class SessionPoolRaceTest(unittest.TestCase): + + @staticmethod + def _make_host(address): + return Host(address, SimpleConvictionPolicy, host_id=uuid.uuid4()) + + @staticmethod + def _make_pool(host, distance, session, endpoint=None): + pool = Mock() + pool.host = host + pool.endpoint = endpoint if endpoint is not None else host.endpoint + pool.host_distance = distance + pool._keyspace = session.keyspace + pool.is_shutdown = False + return pool + + @staticmethod + def _make_cluster_and_session(hosts): + executor = _QueuedExecutor() + + cluster = Cluster.__new__(Cluster) + cluster.is_shutdown = False + cluster.profile_manager = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.control_connection = Mock() + cluster.metadata = Mock() + cluster.metadata.all_hosts.return_value = hosts + cluster._listeners = set() + cluster._listener_lock = Lock() + cluster.executor = executor + cluster._prepare_all_queries = Mock() + + session = Session.__new__(Session) + session.cluster = cluster + session.is_shutdown = False + session.keyspace = None + session._lock = RLock() + session._pools = {} + session._profile_manager = cluster.profile_manager + cluster.sessions = set([session]) + + return cluster, session, executor + + def test_update_created_pools_reuses_in_flight_add_for_issue_317(self): + first_host = self._make_host("127.0.0.1") + second_host = self._make_host("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session( + [first_host, second_host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + Cluster.on_add(cluster, first_host) + Cluster.on_add(cluster, second_host) + + assert len(executor.submissions) == 2 + + executor.run_next(1) + + assert second_host.is_up + assert len(executor.submissions) == 1 + + executor.run_next() + + assert len(created_pools) == 2 + assert session._pools[first_host].host is first_host + assert session._pools[second_host].host is second_host + for pool in created_pools: + pool.shutdown.assert_not_called() + + def test_add_or_renew_pool_returns_in_flight_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + + with patch("cassandra.cluster.HostConnection", + side_effect=self._make_pool): + first_future = session.add_or_renew_pool(host, is_host_addition=True) + second_future = session.add_or_renew_pool(host, is_host_addition=False) + + assert second_future is first_future + assert len(executor.submissions) == 1 + + executor.run_next() + + assert session._pools[host].host is host + + def test_update_created_pools_keeps_in_flight_up_pool_for_down_host(self): + host = self._make_host("127.0.0.1") + host.set_down() + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert session.update_created_pools() == set([future]) + + executor.run_next() + + assert future.result() is True + assert session._pools[host].host is host + created_pools[0].shutdown.assert_not_called() + + def test_update_created_pools_rechecks_distance_before_invalidating_pool_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + distance_calls = [] + future_holder = {} + + def distance(changed_host): + distance_calls.append(changed_host) + if len(distance_calls) == 1: + future_holder["future"] = session.add_or_renew_pool( + host, is_host_addition=False) + return HostDistance.IGNORED + return HostDistance.LOCAL + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + cluster.profile_manager.distance.side_effect = distance + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + assert session.update_created_pools() == set([future_holder["future"]]) + + executor.run_next() + + assert future_holder["future"].result() is True + assert session._pools[host].host is host + created_pools[0].shutdown.assert_not_called() + + def test_update_created_pools_rechecks_distance_before_reusing_pool_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + cluster.profile_manager.distance.side_effect = [ + HostDistance.LOCAL, HostDistance.IGNORED] + + assert session.update_created_pools() == set() + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_update_created_pools_invalidates_creation_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + host.endpoint = DefaultEndPoint("127.0.0.2") + + assert session.update_created_pools() == set() + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_removed_in_flight_pool_is_not_published(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_remove_pool_expected_host_mismatch_invalidates_stale_creation(self): + stale_host = self._make_host("127.0.0.1") + replacement_host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + created_pools = [] + session._pools[replacement_host] = replacement_pool + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + assert session.remove_pool( + stale_host, expected_host=stale_host) is None + + executor.run_next() + + assert future.result() is False + assert session._pools[replacement_host] is replacement_pool + replacement_pool.shutdown.assert_not_called() + created_pools[0].shutdown.assert_called_once_with() + + def test_remove_pool_finds_pool_after_host_endpoint_changes(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = pool + + host.endpoint = DefaultEndPoint("127.0.0.2") + + future = session.remove_pool(host, expected_host=host) + executor.run_next() + + assert session._pools == {} + pool.shutdown.assert_called_once_with() + assert future.done() + + def test_remove_pool_prefers_identity_after_endpoint_rewrite(self): + stale_host = self._make_host("127.0.0.1") + replacement_host = self._make_host("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + stale_pool = self._make_pool(stale_host, HostDistance.LOCAL, session) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + session._pools[stale_host] = stale_pool + + stale_host.endpoint = replacement_host.endpoint + session._pools[replacement_host] = replacement_pool + + assert stale_host == replacement_host + + session.remove_pool(stale_host, expected_host=stale_host) + executor.run_next() + + assert session._pools[replacement_host] is replacement_pool + stale_pool.shutdown.assert_called_once_with() + replacement_pool.shutdown.assert_not_called() + + def test_remove_pool_expected_endpoint_preserves_replacement_pool(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = stale_pool + + host.endpoint = DefaultEndPoint("127.0.0.2") + replacement_pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = replacement_pool + + assert len(session._pools) == 2 + + session.remove_pool( + host, expected_host=host, expected_endpoint=old_endpoint) + executor.run_next() + + assert session._pools[host] is replacement_pool + stale_pool.shutdown.assert_called_once_with() + replacement_pool.shutdown.assert_not_called() + + def test_remove_pool_expected_endpoint_preserves_replacement_creation(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + host.endpoint = DefaultEndPoint("127.0.0.2") + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert session.remove_pool( + host, expected_host=host, + expected_endpoint=old_endpoint) is None + + executor.run_next() + + assert future.result() is True + assert session._pools[host] is created_pools[0] + assert created_pools[0].endpoint == host.endpoint + created_pools[0].shutdown.assert_not_called() + + def test_add_or_renew_pool_tags_pool_with_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool) as host_connection: + future = session.add_or_renew_pool( + host, is_host_addition=False) + + executor.run_next() + + assert future.result() is True + host_connection.assert_called_once_with( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + assert created_pools[0].endpoint == old_endpoint + created_pools[0].shutdown.assert_not_called() + + def test_add_or_renew_pool_auth_failure_reports_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + def fail_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + raise AuthenticationFailed("failed") + + with patch("cassandra.cluster.HostConnection", side_effect=fail_pool): + future = session.add_or_renew_pool(host, is_host_addition=False) + executor.run_next() + + assert future.result() is False + args, kwargs = cluster.signal_connection_failure.call_args + conn_exc = args[1] + assert isinstance(conn_exc, ConnectionException) + assert conn_exc.endpoint == old_endpoint + assert isinstance(conn_exc.__cause__, AuthenticationFailed) + assert args == (host, conn_exc, False) + assert kwargs == {"expected_endpoint": old_endpoint} + + def test_add_or_renew_pool_failure_reports_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + pool_error = RuntimeError("failed") + + def fail_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + raise pool_error + + with patch("cassandra.cluster.HostConnection", side_effect=fail_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + executor.run_next() + + assert future.result() is False + args, kwargs = cluster.signal_connection_failure.call_args + conn_exc = args[1] + assert conn_exc is pool_error + assert conn_exc.endpoint == old_endpoint + assert args == (host, conn_exc, True) + assert kwargs == { + "expect_host_to_be_down": True, + "expected_endpoint": old_endpoint + } + + def test_removed_in_flight_pool_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + with patch("cassandra.cluster.HostConnection", + side_effect=ConnectionException("failed")): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + cluster.signal_connection_failure.assert_not_called() + + def test_removed_in_flight_pool_auth_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + with patch("cassandra.cluster.HostConnection", + side_effect=AuthenticationFailed("failed")): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + cluster.signal_connection_failure.assert_not_called() + + def test_stale_keyspace_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.connect_timeout = 1 + cluster.on_down = Mock() + session.keyspace = "ks" + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + pool._keyspace = None + + def set_keyspace(keyspace, callback): + session.remove_pool(host) + callback(pool, [Exception("failed")]) + + pool._set_keyspace_for_all_conns.side_effect = set_keyspace + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + cluster.on_down.assert_not_called() + + def test_keyspace_failure_signals_down_for_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.connect_timeout = 1 + cluster.on_down = Mock() + session.keyspace = "ks" + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + pool._keyspace = None + created_pools.append(pool) + + def set_keyspace(keyspace, callback): + host.endpoint = DefaultEndPoint("127.0.0.2") + callback(pool, [Exception("failed")]) + + pool._set_keyspace_for_all_conns.side_effect = set_keyspace + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + cluster.on_down.assert_called_once_with( + host, True, expected_endpoint=old_endpoint) + created_pools[0].shutdown.assert_called_once_with() + + class HostStateRaceTest(unittest.TestCase): @staticmethod @@ -357,6 +845,29 @@ def test_unreserved_down_handling_is_ignored_during_host_up_handling(self): session.on_down.assert_not_called() cluster._start_reconnector.assert_not_called() + def test_reserved_down_handling_after_endpoint_swap_only_removes_stale_pool(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + down_epoch = self._reserve_down_handling(cluster, host) + + host.endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch, + expected_endpoint=old_endpoint) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + def test_noop_down_during_up_handling_does_not_supersede_up(self): pool_future = Future() session = Mock() @@ -448,6 +959,67 @@ def test_stale_failed_up_callback_does_not_cleanup_newer_down(self): assert not host.is_up assert self._state(cluster, host).up_epoch is None + def test_failed_up_callback_rechecks_before_cleanup(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + + def force_down_before_cleanup(message, *args, **kwargs): + if message.startswith("Connection pool could not be created"): + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True) + + with patch("cassandra.cluster.log.debug", side_effect=force_down_before_cleanup): + pool_future.set_result(False) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY) + assert session.remove_pool.call_count == 2 + listener.on_up.assert_not_called() + assert not host.is_up + assert state.up_epoch is None + assert state.down_epoch is None + + def test_failed_up_callback_after_endpoint_swap_does_not_signal_down(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + old_endpoint = host.endpoint + + host.endpoint = DefaultEndPoint("127.0.0.2") + pool_future.set_exception(RuntimeError("pool failed after endpoint swap")) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert session.remove_pool.call_count == 2 + session.remove_pool.assert_any_call( + host, expected_host=host, expected_endpoint=old_endpoint) + assert not host.is_up + assert state.up_epoch is None + def test_forced_down_during_up_handling_is_not_hidden_by_reconnector(self): session = Mock() listener = Mock() @@ -480,6 +1052,24 @@ def force_down_before_reconnector_is_cleared(h, up_epoch): assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).down_epoch is None + def test_forced_down_while_reconnecting_runs_new_down_handling(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_down() + host._reconnection_handler = Mock() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert self._state(cluster, host).down_epoch is None + def test_newer_down_before_up_side_effects_suppresses_stale_up(self): cluster = self._make_cluster() cluster.profile_manager.distance.return_value = HostDistance.IGNORED @@ -1105,6 +1695,55 @@ def test_wrapped_auth_failure_for_unknown_host_does_not_start_down_handling(self cluster.control_connection.on_down.assert_not_called() cluster._start_reconnector.assert_not_called() + def test_connection_failure_after_endpoint_swap_is_ignored(self): + cluster = self._make_cluster() + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + host.endpoint = DefaultEndPoint("127.0.0.2") + + is_down = cluster.signal_connection_failure( + host, ConnectionException("failed", endpoint=old_endpoint), + is_host_addition=False, expected_endpoint=old_endpoint) + + assert not is_down + assert host.is_up + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_auth_failed_up_pool_for_unknown_host_rolls_back_without_reconnector(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + auth_exc = AuthenticationFailed("bad credentials") + conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc.__cause__ = auth_exc + + Cluster.on_up(cluster, host) + is_down = cluster.signal_connection_failure( + host, conn_exc, is_host_addition=False) + + assert is_down + assert host.is_up is None + session.remove_pool.reset_mock() + + pool_future.set_result(False) + + assert host.is_up is None + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=host.endpoint) + listener.on_up.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).up_epoch is None + def test_real_down_for_unknown_host_marks_host_down(self): cluster = self._make_cluster() host = self._make_host() From a4b8a25ce421fcde4f09e85c5eeafed370b07a14 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sat, 2 May 2026 14:55:58 -0400 Subject: [PATCH 04/29] cluster: fence stale host events by endpoint Track the endpoint tied to host up/down handling, reconnection callbacks, and pool cleanup so stale work from a previous endpoint cannot mark or reconnect a replacement host. Preserve endpoint-specific identity for SNI and client-routes endpoints, scope non-retryable auth failures to the matching endpoint, and remove stale pools by host identity instead of endpoint equality. Add unit coverage for endpoint swaps, queued up/down races, stale reconnector success, and defunct connection handling after client-route port changes. --- cassandra/cluster.py | 654 ++++++++++++++++++------ cassandra/pool.py | 392 +++++++++----- tests/unit/test_cluster.py | 164 +++++- tests/unit/test_control_connection.py | 43 +- tests/unit/test_host_connection_pool.py | 67 ++- 5 files changed, 1032 insertions(+), 288 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 51651679a1..fafc741e39 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -25,6 +25,8 @@ from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy +from contextlib import contextmanager +from contextvars import ContextVar from functools import partial, reduce, wraps from itertools import groupby, count, chain import json @@ -194,6 +196,37 @@ def _connection_reduce_fn(val,import_fn): _GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0') _NOT_SET = object() +_NON_RETRYABLE_AUTH_FAILURE_ATTR = "_cassandra_non_retryable_auth_failure" +_POOL_CLEANUP_EPOCH = ContextVar("_POOL_CLEANUP_EPOCH", default=None) + + +def _make_connection_kwargs(endpoint, kwargs_dict, auth_provider_callable, port, + compression, sockopts, ssl_options, ssl_context, + cql_version, protocol_version, user_type_map, + allow_beta_protocol_version, no_compact, + application_info): + if auth_provider_callable: + kwargs_dict.setdefault('authenticator', auth_provider_callable(endpoint.address)) + + kwargs_dict.setdefault('port', port) + kwargs_dict.setdefault('compression', compression) + kwargs_dict.setdefault('sockopts', sockopts) + kwargs_dict.setdefault('ssl_options', ssl_options) + kwargs_dict.setdefault('ssl_context', ssl_context) + kwargs_dict.setdefault('cql_version', cql_version) + kwargs_dict.setdefault('protocol_version', protocol_version) + kwargs_dict.setdefault('user_type_map', user_type_map) + kwargs_dict.setdefault('allow_beta_protocol_version', allow_beta_protocol_version) + kwargs_dict.setdefault('no_compact', no_compact) + kwargs_dict.setdefault('application_info', application_info) + + return kwargs_dict + + +def _connection_factory_for_endpoint(endpoint, connect_timeout, connection_factory, + args, kwargs, connection_options): + kwargs = _make_connection_kwargs(endpoint, kwargs.copy(), *connection_options) + return connection_factory(endpoint, connect_timeout, *args, **kwargs) class NoHostAvailable(Exception): @@ -301,7 +334,12 @@ class _HostLivenessState(_EventFenceState): _DOWN = "down" _PENDING_UP = "pending_up" - __slots__ = () + __slots__ = ("up_endpoint", "pending_up_endpoint") + + def __init__(self): + _EventFenceState.__init__(self) + self.up_endpoint = None + self.pending_up_endpoint = None @property def up_epoch(self): @@ -1846,26 +1884,23 @@ def connection_factory(self, endpoint, host_conn = None, *args, **kwargs): return self.connection_class.factory(endpoint, self.connect_timeout, host_conn, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): - kwargs = self._make_connection_kwargs(host.endpoint, kwargs) - return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) + connection_options = ( + self._auth_provider_callable, self.port, self.compression, + self.sockopts, self.ssl_options, self.ssl_context, + self.cql_version, self.protocol_version, self._user_types, + self.allow_beta_protocol_version, self.no_compact, + self.application_info) + return partial(_connection_factory_for_endpoint, host.endpoint, + self.connect_timeout, self.connection_class.factory, + args, kwargs.copy(), connection_options) def _make_connection_kwargs(self, endpoint, kwargs_dict): - if self._auth_provider_callable: - kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) - - kwargs_dict.setdefault('port', self.port) - kwargs_dict.setdefault('compression', self.compression) - kwargs_dict.setdefault('sockopts', self.sockopts) - kwargs_dict.setdefault('ssl_options', self.ssl_options) - kwargs_dict.setdefault('ssl_context', self.ssl_context) - kwargs_dict.setdefault('cql_version', self.cql_version) - kwargs_dict.setdefault('protocol_version', self.protocol_version) - kwargs_dict.setdefault('user_type_map', self._user_types) - kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) - kwargs_dict.setdefault('no_compact', self.no_compact) - kwargs_dict.setdefault('application_info', self.application_info) - - return kwargs_dict + return _make_connection_kwargs( + endpoint, kwargs_dict, self._auth_provider_callable, self.port, + self.compression, self.sockopts, self.ssl_options, self.ssl_context, + self.cql_version, self.protocol_version, self._user_types, + self.allow_beta_protocol_version, self.no_compact, + self.application_info) def protocol_downgrade(self, host_endpoint, previous_version): if self._protocol_version_explicit: @@ -2014,14 +2049,21 @@ def _session_register_user_types(self, session): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) - def _cleanup_failed_on_up_handling(self, host, start_reconnector=True, expected_endpoint=None): + def _cleanup_failed_on_up_handling(self, host, start_reconnector=True, + expected_endpoint=None, expected_epoch=None): + if expected_epoch is not None: + with host.lock: + if self._get_host_liveness_state(host).epoch != expected_epoch: + log.debug("Ignoring stale failed up cleanup for node %s", host) + return + endpoint_changed = False if expected_endpoint is None: self.profile_manager.on_down(host) self.control_connection.on_down(host) else: with host.lock: - endpoint_changed = host.endpoint != expected_endpoint + endpoint_changed = not self._endpoints_match(host.endpoint, expected_endpoint) if endpoint_changed: log.debug("Not signalling down for stale up handling on node %s; endpoint changed from %s", host, expected_endpoint) @@ -2029,10 +2071,19 @@ def _cleanup_failed_on_up_handling(self, host, start_reconnector=True, expected_ self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): - session.remove_pool( - host, expected_host=host, expected_endpoint=expected_endpoint) + with self._pool_cleanup_epoch(host, expected_epoch): + future = session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) + if future: + future.add_done_callback( + lambda f, session=session: session.update_created_pools()) + + cleanup_is_current = True + if expected_epoch is not None: + with host.lock: + cleanup_is_current = self._get_host_liveness_state(host).epoch == expected_epoch - if start_reconnector and not endpoint_changed: + if start_reconnector and not endpoint_changed and cleanup_is_current: self._start_reconnector( host, is_host_addition=False, expected_endpoint=expected_endpoint) @@ -2043,9 +2094,27 @@ def _get_host_liveness_state(self, host): fences = self._host_liveness = _EventFenceMap(_HostLivenessState) return fences.get_state(host) + @contextmanager + def _pool_cleanup_epoch(self, host, expected_epoch=None, + allow_current_down=False): + if expected_epoch is None: + yield + return + + token = _POOL_CLEANUP_EPOCH.set( + (host, expected_epoch, allow_current_down)) + try: + yield + finally: + _POOL_CLEANUP_EPOCH.reset(token) + def _up_handling_was_superseded(self, host, up_epoch): state = self._get_host_liveness_state(host) - return not state.event_is_current(_HostLivenessState._UP, up_epoch) + if not state.event_is_current(_HostLivenessState._UP, up_epoch): + return True + + expected_endpoint = state.up_endpoint + return expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint) def _up_handling_is_superseded(self, host, up_epoch): with host.lock: @@ -2054,9 +2123,17 @@ def _up_handling_is_superseded(self, host, up_epoch): log.debug("Ignoring superseded up handling for node %s", host) return superseded - def _get_reconnector_for_current_up_handling(self, host, up_epoch): + def _get_reconnector_for_current_up_handling(self, host, up_epoch, + expected_reconnector=_NOT_SET): with host.lock: if self._up_handling_was_superseded(host, up_epoch): + reconnector = host._reconnection_handler + # Only cancel the handler observed when up handling started. + # A newer down event may already have installed its own reconnector. + if expected_reconnector is _NOT_SET or reconnector is expected_reconnector: + host._reconnection_handler = None + if reconnector: + reconnector.cancel() log.debug("Ignoring superseded up handling for node %s", host) return None, True reconnector = host._reconnection_handler @@ -2065,32 +2142,54 @@ def _get_reconnector_for_current_up_handling(self, host, up_epoch): def _clear_up_handling(self, host, up_epoch=None): state = self._get_host_liveness_state(host) - return state.clear_event(_HostLivenessState._UP, up_epoch) + if state.clear_event(_HostLivenessState._UP, up_epoch): + state.up_endpoint = None + return True + return False + + def _cleanup_superseded_up_handling(self, host, expected_endpoint=None, expected_epoch=None): + if expected_epoch is not None: + with host.lock: + if self._get_host_liveness_state(host).epoch != expected_epoch: + log.debug("Ignoring stale superseded up cleanup for node %s", host) + return - def _cleanup_superseded_up_handling(self, host, expected_endpoint=None): for session in tuple(self.sessions): - session.remove_pool( - host, expected_host=host, expected_endpoint=expected_endpoint) + with self._pool_cleanup_epoch(host, expected_epoch): + future = session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) + if future: + future.add_done_callback( + lambda f, session=session: session.update_created_pools()) def _pop_pending_node_up_if_ready(self, host): state = self._get_host_liveness_state(host) if state.pending_up_epoch is None: + state.pending_up_endpoint = None return None if host.is_up: state.pending_up_epoch = None + state.pending_up_endpoint = None return None if state.up_epoch is not None or state.down_epoch is not None: return None pending_up_epoch = state.pending_up_epoch + pending_up_endpoint = state.pending_up_endpoint # Leave the pending marker in place until on_up() reacquires host.lock so # a newer down signal can still invalidate this replay. - return pending_up_epoch + return pending_up_epoch, pending_up_endpoint - def _handle_pending_node_up(self, host, pending_up_epoch): - if pending_up_epoch is not None: + def _handle_pending_node_up(self, host, pending_up): + if pending_up is not None: + if isinstance(pending_up, tuple): + pending_up_epoch, pending_up_endpoint = pending_up + else: + pending_up_epoch = pending_up + pending_up_endpoint = None log.debug("Handling queued up status of node %s", host) - self._on_up(host, expected_epoch=pending_up_epoch) + self._on_up(host, expected_epoch=pending_up_epoch, + expected_endpoint=pending_up_endpoint) def _clear_down_handling(self, host, down_epoch=None): state = self._get_host_liveness_state(host) @@ -2098,7 +2197,7 @@ def _clear_down_handling(self, host, down_epoch=None): def _finish_superseded_up_handling(self, host, up_epoch, expected_endpoint=None): self._cleanup_superseded_up_handling( - host, expected_endpoint=expected_endpoint) + host, expected_endpoint=expected_endpoint, expected_epoch=up_epoch) pending_up_epoch = None with host.lock: @@ -2139,19 +2238,22 @@ def _on_up_future_completed(self, host, up_handling_revision, up_handling_endpoi host, up_handling_revision, expected_endpoint=up_handling_endpoint): return self._cleanup_failed_on_up_handling( - host, expected_endpoint=up_handling_endpoint) + host, expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) with host.lock: - start_reconnector = host.is_up is not None + # Only suppress retries for the auth-failure quarantine. + start_reconnector = not self._has_non_retryable_auth_failure(host) if self._finish_up_if_superseded( host, up_handling_revision, expected_endpoint=up_handling_endpoint): return self._cleanup_failed_on_up_handling( host, start_reconnector=start_reconnector, - expected_endpoint=up_handling_endpoint) + expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) return log.info("Connection pools established for node %s", host) @@ -2180,12 +2282,15 @@ def _on_up_future_completed(self, host, up_handling_revision, up_handling_endpoi # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() + self._set_non_retryable_auth_failure(host, False) return - def on_up(self, host): - return self._on_up(host) + def on_up(self, host, expected_endpoint=None, expected_reconnector=None): + return self._on_up(host, expected_endpoint=expected_endpoint, + expected_reconnector=expected_reconnector) - def _on_up(self, host, expected_epoch=None): + def _on_up(self, host, expected_epoch=None, expected_endpoint=None, + expected_reconnector=None): """ Intended for internal use only. """ @@ -2193,16 +2298,31 @@ def _on_up(self, host, expected_epoch=None): return log.debug("Waiting to acquire lock for handling up status of node %s", host) + + def _clear_stale_reconnector(): + # Stale callbacks can outlive the reconnector callback path. + if (expected_reconnector is not None and + host._reconnection_handler is expected_reconnector): + host._reconnection_handler = None + with host.lock: state = self._get_host_liveness_state(host) + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Ignoring stale queued up handling for node %s; endpoint changed from %s", + host, expected_endpoint) + _clear_stale_reconnector() + return + if (expected_epoch is not None and (state.epoch != expected_epoch or state.pending_up_epoch != expected_epoch)): log.debug("Ignoring stale queued up handling for node %s", host) + _clear_stale_reconnector() return if state.down_epoch is not None: log.debug("Down status is being handled for node %s; queueing up handling", host) state.pending_up_epoch = state.epoch + state.pending_up_endpoint = expected_endpoint return if state.up_epoch is not None: @@ -2210,7 +2330,12 @@ def _on_up(self, host, expected_epoch=None): if self._up_handling_was_superseded(host, up_handling_revision): log.debug("Superseded up handling is still finishing for node %s; " "queueing up handling", host) + # Endpoint swap without an epoch bump needs a fresh epoch so + # the replayed up handling is not cleared by the stale finish. + if state.epoch == up_handling_revision: + state.advance() state.pending_up_epoch = state.epoch + state.pending_up_endpoint = expected_endpoint else: log.debug("Another thread is already handling up status of node %s", host) return @@ -2218,12 +2343,16 @@ def _on_up(self, host, expected_epoch=None): if host.is_up: log.debug("Host %s was already marked up", host) state.pending_up_epoch = None + state.pending_up_endpoint = None return state.pending_up_epoch = None + state.pending_up_endpoint = None up_handling_revision = state.epoch - up_handling_endpoint = host.endpoint + up_handling_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint state.up_epoch = up_handling_revision + state.up_endpoint = up_handling_endpoint + up_handling_reconnector = host._reconnection_handler log.debug("Starting to handle up status of node %s", host) have_future = False @@ -2232,7 +2361,7 @@ def _on_up(self, host, expected_epoch=None): log.info("Host %s may be up; will prepare queries and open connection pool", host) reconnector, superseded = self._get_reconnector_for_current_up_handling( - host, up_handling_revision) + host, up_handling_revision, expected_reconnector=up_handling_reconnector) if superseded: self._finish_superseded_up_handling( host, up_handling_revision, expected_endpoint=up_handling_endpoint) @@ -2249,8 +2378,9 @@ def _on_up(self, host, expected_epoch=None): log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): - session.remove_pool( - host, expected_host=host, expected_endpoint=up_handling_endpoint) + with self._pool_cleanup_epoch(host, expected_epoch=up_handling_revision): + session.remove_pool( + host, expected_host=host, expected_endpoint=up_handling_endpoint) if self._finish_up_if_superseded( host, up_handling_revision, expected_endpoint=up_handling_endpoint): @@ -2277,7 +2407,8 @@ def _on_up(self, host, expected_epoch=None): up_handling_endpoint, futures, futures_results, futures_lock) for session in tuple(self.sessions): - future = session.add_or_renew_pool(host, is_host_addition=False) + future = session.add_or_renew_pool( + host, is_host_addition=False, allow_retry_after_auth_failure=True) if future is not None: have_future = True future.add_done_callback(callback) @@ -2288,7 +2419,8 @@ def _on_up(self, host, expected_epoch=None): future.cancel() self._cleanup_failed_on_up_handling( - host, expected_endpoint=up_handling_endpoint) + host, expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) pending_up_epoch = None with host.lock: @@ -2309,6 +2441,15 @@ def _on_up(self, host, expected_epoch=None): if superseded: self._finish_superseded_up_handling( host, up_handling_revision, expected_endpoint=up_handling_endpoint) + return futures + + for listener in self.listeners: + listener.on_up(host) + + for session in tuple(self.sessions): + session.update_created_pools() + + self._set_non_retryable_auth_failure(host, False) # for testing purposes return futures @@ -2321,7 +2462,7 @@ def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None, schedule = self.reconnection_policy.new_schedule() with host.lock: - if expected_endpoint is not None and host.endpoint != expected_endpoint: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Not starting reconnector for host %s; endpoint changed from %s", host, expected_endpoint) return @@ -2337,7 +2478,7 @@ def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None, log.debug("Not starting reconnector for host %s; down handling is no longer current", host) return - if expected_endpoint is not None and host.endpoint != expected_endpoint: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Not starting reconnector for host %s; endpoint changed from %s", host, expected_endpoint) return @@ -2363,6 +2504,7 @@ def on_down_potentially_blocking( expected_endpoint: Optional[EndPoint] = None) -> Any: pending_up_epoch = None try: + down_endpoint = None with host.lock: state = self._get_host_liveness_state(host) owns_reserved_down_handling = down_epoch is not None and state.down_epoch == down_epoch @@ -2375,30 +2517,31 @@ def on_down_potentially_blocking( elif not owns_reserved_down_handling: log.debug("Ignoring stale down handling for host %s", host) return - endpoint_matches = expected_endpoint is None or host.endpoint == expected_endpoint + down_endpoint = host.endpoint + endpoint_matches = expected_endpoint is None or self._endpoints_match( + down_endpoint, expected_endpoint) if endpoint_matches: - if expected_endpoint is None: - self.profile_manager.on_down(host) - self.control_connection.on_down(host) - else: - with host.lock: - if host.endpoint != expected_endpoint: - log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", - host, expected_endpoint) - endpoint_matches = False - else: - self.profile_manager.on_down(host) - self.control_connection.on_down(host) + with host.lock: + if not self._endpoints_match(host.endpoint, down_endpoint): + log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", + host, down_endpoint) + endpoint_matches = False + else: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) else: log.debug("Not signalling down for stale down handling on node %s; endpoint changed from %s", host, expected_endpoint) for session in tuple(self.sessions): - if expected_endpoint is None: - session.on_down(host) - else: - session.on_down(host, expected_endpoint=expected_endpoint) + with self._pool_cleanup_epoch( + host, expected_epoch=down_epoch, + allow_current_down=True): + if expected_endpoint is None: + session.on_down(host, expected_endpoint=down_endpoint) + else: + session.on_down(host, expected_endpoint=expected_endpoint) if endpoint_matches: if expected_endpoint is None: @@ -2406,7 +2549,7 @@ def on_down_potentially_blocking( listener.on_down(host) else: with host.lock: - if host.endpoint != expected_endpoint: + if not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", host, expected_endpoint) endpoint_matches = False @@ -2418,7 +2561,7 @@ def on_down_potentially_blocking( start_reconnector = (endpoint_matches and self._get_host_liveness_state(host).down_epoch == down_epoch) if (start_reconnector and expected_endpoint is not None and - host.endpoint != expected_endpoint): + not self._endpoints_match(host.endpoint, expected_endpoint)): log.debug("Not starting reconnector for host %s; endpoint changed from %s", host, expected_endpoint) start_reconnector = False @@ -2448,7 +2591,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, return with host.lock: - if expected_endpoint is not None and host.endpoint != expected_endpoint: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Ignoring stale down signal for host %s; endpoint changed from %s", host, expected_endpoint) return @@ -2458,13 +2601,19 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated - if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: + # endpoint-aware cleanup still needs to run + if (self._discount_down_events and expected_endpoint is None and + self.profile_manager.distance(host) != HostDistance.IGNORED): + host_endpoint = host.endpoint connected = False for session in tuple(self.sessions): - pool_states = session.get_pool_state() - pool_state = pool_states.get(host) - if pool_state: - connected |= pool_state['open_count'] > 0 + # Host equality is endpoint-based; scan by identity to avoid + # hiding the live pool behind a stale equal key. + pool = session._get_pool_by_host_identity( + host, expected_endpoint=host_endpoint) + if pool is not None and pool.open_count > 0: + connected = True + break if connected: return @@ -2473,11 +2622,13 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, if state.pending_up_epoch is not None: state.advance() state.pending_up_epoch = None + state.pending_up_endpoint = None host.set_down() return state.advance() state.pending_up_epoch = None + state.pending_up_endpoint = None host.set_down() down_epoch = state.epoch if state.down_epoch is not None: @@ -2568,6 +2719,9 @@ def _finalize_add(self, host, set_up=True): for session in tuple(self.sessions): session.update_created_pools() + if set_up: + self._set_non_retryable_auth_failure(host, False) + def on_remove(self, host): if self.is_shutdown: return @@ -2577,8 +2731,12 @@ def on_remove(self, host): state = self._get_host_liveness_state(host) state.advance() state.pending_up_epoch = None + state.pending_up_endpoint = None + state.up_epoch = None + state.up_endpoint = None state.down_epoch = None host.set_down() + self._set_non_retryable_auth_failure(host, False) self.profile_manager.on_remove(host) for session in tuple(self.sessions): session.on_remove(host) @@ -2595,22 +2753,91 @@ def _is_authentication_failure(connection_exc): return (isinstance(connection_exc, AuthenticationFailed) or isinstance(getattr(connection_exc, "__cause__", None), AuthenticationFailed)) + @staticmethod + def _endpoint_address_key(endpoint): + if endpoint is None: + return None + if isinstance(endpoint, tuple): + return endpoint[:2] + address = getattr(endpoint, "address", None) + port = getattr(endpoint, "port", None) + if address is not None or port is not None: + return address, port + return endpoint + + @classmethod + def _endpoint_key(cls, endpoint): + if endpoint is None: + return None + if isinstance(endpoint, tuple): + return endpoint[:2] + address_key = cls._endpoint_address_key(endpoint) + if not isinstance(address_key, tuple): + return endpoint + if endpoint.__class__ is DefaultEndPoint: + return address_key + # Non-default endpoint classes can carry connection identity that + # address/port alone does not preserve, e.g. SNI or client-routes IDs. + endpoint_key = (endpoint.__class__,) + address_key + server_name = getattr(endpoint, "_server_name", None) + if server_name is not None: + return endpoint_key + (server_name,) + host_id = getattr(endpoint, "host_id", None) + if host_id is not None: + return endpoint_key + (host_id,) + if isinstance(endpoint, EndPoint): + return endpoint_key + (endpoint,) + return endpoint + + @classmethod + def _endpoints_match(cls, endpoint, expected_endpoint): + if isinstance(endpoint, tuple) or isinstance(expected_endpoint, tuple): + return (cls._endpoint_address_key(endpoint) == + cls._endpoint_address_key(expected_endpoint)) + return cls._endpoint_key(endpoint) == cls._endpoint_key(expected_endpoint) + + @classmethod + def _set_non_retryable_auth_failure(cls, host, enabled): + if enabled: + setattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR, cls._endpoint_key(host.endpoint)) + elif hasattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR): + delattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR) + + @classmethod + def _has_non_retryable_auth_failure(cls, host): + failure_endpoint = getattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR, None) + return failure_endpoint is not None and failure_endpoint == cls._endpoint_key(host.endpoint) + def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False, expected_endpoint=None): with host.lock: - if expected_endpoint is not None and host.endpoint != expected_endpoint: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", host, expected_endpoint) return False + if expected_endpoint is not None: + if isinstance(expected_endpoint, tuple): + current_host = self.metadata.get_host(*expected_endpoint[:2]) + else: + current_host = self.metadata.get_host(expected_endpoint) + if current_host is not None and current_host is not host: + log.debug("Ignoring stale connection failure for host %s; endpoint reassigned to %s", + host, current_host) + return False + + is_auth_failure = self._is_authentication_failure(connection_exc) is_down = host.signal_connection_failure(connection_exc) + if host.is_up is None and is_auth_failure: + # Never-up auth failures are terminal for the endpoint even if conviction policy + # has not yet decided the host is down. Do not run normal down handling + # for a host that was never marked up. + self._set_non_retryable_auth_failure(host, True) + return is_down if is_down: - if host.is_up is None and self._is_authentication_failure(connection_exc): - return is_down - if is_down: - self.on_down( - host, is_host_addition, expect_host_to_be_down, - expected_endpoint=expected_endpoint) + self.on_down( + host, is_host_addition, expect_host_to_be_down, + expected_endpoint=expected_endpoint) return is_down def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): @@ -3716,6 +3943,18 @@ def _get_pool_creation_state(self, host): fences = self._pool_creation_fences = _EventFenceMap(_PoolCreationState) return fences.get_state(host) + def _endpoints_match(self, endpoint, expected_endpoint): + matches = self.cluster._endpoints_match(endpoint, expected_endpoint) + if isinstance(matches, bool): + return matches + return endpoint == expected_endpoint + + def _has_non_retryable_auth_failure(self, host): + return self.cluster._has_non_retryable_auth_failure(host) + + def _get_host_liveness_state(self, host): + return self.cluster._get_host_liveness_state(host) + def _pool_creation_is_current(self, host, creation_epoch): state = self._get_pool_creation_state(host) return state.event_is_current(_PoolCreationState._CREATE, creation_epoch) @@ -3731,7 +3970,7 @@ def _clear_pool_creation(self, host, creation_epoch): def _invalidate_pool_creation(self, host, expected_endpoint=None): state = self._get_pool_creation_state(host) if state.creation_epoch is not None: - if expected_endpoint is not None and state.endpoint != expected_endpoint: + if expected_endpoint is not None and not self._endpoints_match(state.endpoint, expected_endpoint): return False state.advance() state.creation_epoch = None @@ -3740,13 +3979,18 @@ def _invalidate_pool_creation(self, host, expected_endpoint=None): return True return False - def add_or_renew_pool(self, host, is_host_addition): + def add_or_renew_pool(self, host, is_host_addition, allow_retry_after_auth_failure=False): """ For internal use only. """ distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None + if (not is_host_addition and not allow_retry_after_auth_failure and + self._has_non_retryable_auth_failure(host)): + # Auth failure quarantine is endpoint-scoped; same endpoint survives later + # down events until a successful add/up clears it. + return None creation_epoch = None creation_endpoint = None @@ -3790,9 +4034,10 @@ def run_add_or_renew_pool(): pool_endpoint = self._get_pool_endpoint(new_pool, creation_endpoint) new_pool.endpoint = pool_endpoint - previous = None + previous_pools = [] discard_pool = False signal_down = False + reuse_existing_pool = False with self._lock: while new_pool._keyspace != self.keyspace: if not self._pool_creation_is_current(host, creation_epoch): @@ -3820,12 +4065,58 @@ def callback(pool, errors): break if not discard_pool and not signal_down: - if self._pool_creation_is_current(host, creation_epoch): - previous = self._pools.get(host) - self._pools[host] = new_pool - self._clear_pool_creation(host, creation_epoch) - else: + if not self._pool_creation_is_current(host, creation_epoch): discard_pool = True + else: + with host.lock: + endpoint_changed = not self._endpoints_match(host.endpoint, creation_endpoint) + if endpoint_changed: + log.debug( + "Discarding stale connection pool for host %s; endpoint changed from %s", + host, creation_endpoint) + self._invalidate_pool_creation(host, expected_endpoint=creation_endpoint) + discard_pool = True + else: + # Rebuild by identity so endpoint hash changes do not + # leave stale pool entries behind. + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if pool_host is host: + previous_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + + # Keep the current metadata host keyed by identity. + metadata_host = host + if isinstance(self.cluster.metadata, Metadata): + metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) + + target_host = metadata_host if metadata_host is not None else host + target_host_matches = False + for pool_host in tuple(retained_pools): + if pool_host is target_host: + target_host_matches = True + elif pool_host == target_host: + previous_pools.append(retained_pools.pop(pool_host)) + + if target_host_matches: + reuse_existing_pool = True + else: + source_host = new_pool.host + if (source_host is not target_host and + target_host.sharding_info is None): + target_host.sharding_info = source_host.sharding_info + new_pool.host = target_host + retained_pools[target_host] = new_pool + self._pools = retained_pools + self._clear_pool_creation(host, creation_epoch) + + if reuse_existing_pool: + log.debug("Reusing existing connection pool for host %s", host) + new_pool.shutdown() + for previous in previous_pools: + previous.shutdown() + return True if signal_down: self.cluster.on_down( @@ -3839,7 +4130,7 @@ def callback(pool, errors): return False log.debug("Added pool for host %s to session", host) - if previous: + for previous in previous_pools: previous.shutdown() return True @@ -3864,40 +4155,73 @@ def callback(pool, errors): def remove_pool(self, host, expected_host=None, expected_endpoint=None): removed_pools = [] + cleanup_context = _POOL_CLEANUP_EPOCH.get() with self._lock: - pool = self._get_pool_by_host_identity( - host, expected_host=expected_host, expected_endpoint=expected_endpoint) - remove_by_identity = pool is not None - if pool is None: - pool = self._pools.get(host) - if pool is not None and not self._pool_matches_expected( - pool, expected_host=expected_host, expected_endpoint=expected_endpoint): - self._invalidate_pool_creation( - host, expected_endpoint=expected_endpoint) - return None - self._invalidate_pool_creation( - host, expected_endpoint=expected_endpoint) - if pool is not None: - if remove_by_identity: - retained_pools = {} - for pool_host, host_pool in self._pools.items(): - if (pool_host is host and self._pool_matches_expected( - host_pool, expected_host=expected_host, - expected_endpoint=expected_endpoint)): - removed_pools.append(host_pool) - else: - retained_pools[pool_host] = host_pool - self._pools = retained_pools + with host.lock: + if cleanup_context is not None: + cleanup_host, cleanup_epoch = cleanup_context[:2] + allow_current_down = ( + len(cleanup_context) > 2 and cleanup_context[2]) + if cleanup_host is host: + # Stale cleanup can outlive an A->B->A flip; preserve the + # pool that belongs to the current epoch. + state = self._get_host_liveness_state(host) + cleanup_still_owns_down = ( + allow_current_down and + state.down_epoch == cleanup_epoch) + if (state.epoch != cleanup_epoch and + not cleanup_still_owns_down): + log.debug( + "Ignoring stale pool cleanup for host %s; epoch moved from %s to %s", + host, cleanup_epoch, state.epoch) + return None + + # Keep the endpoint snapshot stable until the removal pass + # finishes; migration can otherwise flip the branch mid-flight. + # Host hashes track endpoint, so endpoint swaps can hide stale + # entries behind normal dict lookups. Scan by identity instead. + remove_all = expected_endpoint is None + if expected_endpoint is not None: + remove_all = self._endpoints_match(host.endpoint, expected_endpoint) + + if remove_all: + self._invalidate_pool_creation(host) else: - removed_pool = self._pools.pop(host, None) - if removed_pool is not None: - removed_pools.append(removed_pool) + self._invalidate_pool_creation( + host, expected_endpoint=expected_endpoint) + + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if pool_host is not host: + retained_pools[pool_host] = host_pool + continue + + matches = self._pool_matches_expected( + host_pool, expected_host=expected_host, + expected_endpoint=None if remove_all else expected_endpoint) + if matches: + removed_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + self._pools = retained_pools if removed_pools: log.debug("Removed connection pool for %r", host) return self.submit(self._shutdown_removed_pools, removed_pools) else: return None + @contextmanager + def _pool_cleanup_epoch(self, host, expected_epoch=None): + if expected_epoch is None: + yield + return + + token = _POOL_CLEANUP_EPOCH.set((host, expected_epoch)) + try: + yield + finally: + _POOL_CLEANUP_EPOCH.reset(token) + @staticmethod def _shutdown_removed_pools(pools): for pool in pools: @@ -3910,7 +4234,7 @@ def _pool_matches_expected(self, pool, expected_host=None, expected_endpoint=Non pool_endpoint = getattr(pool, 'endpoint', None) if pool_endpoint is None: pool_endpoint = pool.host.endpoint - if pool_endpoint != expected_endpoint: + if not self._endpoints_match(pool_endpoint, expected_endpoint): return False return True @@ -3918,7 +4242,7 @@ def _pool_matches_expected(self, pool, expected_host=None, expected_endpoint=Non def _get_pool_endpoint(pool, default_endpoint): endpoint = getattr(pool, 'endpoint', None) connections = getattr(pool, '_connections', None) - if connections: + if isinstance(connections, Mapping): for connection in connections.values(): if getattr(connection, 'original_endpoint', None) is None: connection_endpoint = getattr(connection, 'endpoint', None) @@ -3929,10 +4253,16 @@ def _get_pool_endpoint(pool, default_endpoint): return endpoint if endpoint is not None else default_endpoint def _get_pool_by_host_identity(self, host, expected_host=None, expected_endpoint=None): - for pool_host, pool in self._pools.items(): - if (pool_host is host and self._pool_matches_expected( - pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): - return pool + pool = self._pools.get(host) + if (pool is not None and pool.host is host and self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): + return pool + + with self._lock: + for pool_host, pool in self._pools.items(): + if (pool_host is host and self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): + return pool return None def _reuse_or_invalidate_pool_creation(self, host, pool_creation_future): @@ -3942,7 +4272,7 @@ def _reuse_or_invalidate_pool_creation(self, host, pool_creation_future): if (pool_creation_state.creation_epoch is not None and pool_creation_state.future is pool_creation_future): with host.lock: - endpoint_changed = host.endpoint != pool_creation_state.endpoint + endpoint_changed = not self._endpoints_match(host.endpoint, pool_creation_state.endpoint) if endpoint_changed: self._invalidate_pool_creation( host, expected_endpoint=pool_creation_state.endpoint) @@ -3968,7 +4298,11 @@ def update_created_pools(self): for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) with self._lock: - pool = self._pools.get(host) + with host.lock: + host_endpoint = host.endpoint + pool = self._get_pool_by_host_identity( + host, expected_endpoint=host_endpoint) + any_pool = pool or self._get_pool_by_host_identity(host) pool_creation_state = self._get_pool_creation_state(host) pool_creation_future = ( pool_creation_state.future @@ -3988,6 +4322,8 @@ def update_created_pools(self): elif pool_creation_future is not None: future = self._reuse_or_invalidate_pool_creation( host, pool_creation_future) + elif any_pool is not None: + future = self.remove_pool(host) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: @@ -4698,13 +5034,15 @@ def _handle_status_change(self, event): change_type = event["change_type"] addr, port = event["address"] host = self._cluster.metadata.get_host(addr, port) + expected_endpoint = host.endpoint if host is not None else event["address"] if change_type == "UP": delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: - self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) + self._cluster.scheduler.schedule_unique( + delay, self._cluster.on_up, host, expected_endpoint=expected_endpoint) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. @@ -4712,7 +5050,7 @@ def _handle_status_change(self, event): # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: # this will be run by the scheduler - self._cluster.on_down(host, is_host_addition=False) + self._cluster.on_down(host, is_host_addition=False, expected_endpoint=expected_endpoint) def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: """ @@ -4969,6 +5307,7 @@ class _Scheduler(Thread): def __init__(self, executor): self._queue = queue.PriorityQueue() self._scheduled_tasks = set() + self._scheduled_tasks_lock = Lock() self._count = count() self._executor = executor @@ -4983,26 +5322,54 @@ def shutdown(self): # this can happen on interpreter shutdown pass self.is_shutdown = True - self._queue.put_nowait((0, 0, None)) + self._queue.put_nowait((0, 0, None, None)) self.join() def schedule(self, delay, fn, *args, **kwargs): - self._insert_task(delay, (fn, args, tuple(kwargs.items()))) + task = (fn, args, tuple(kwargs.items())) + self._insert_task(delay, self._task_key(fn, args, kwargs), task) def schedule_unique(self, delay, fn, *args, **kwargs): task = (fn, args, tuple(kwargs.items())) - if task not in self._scheduled_tasks: - self._insert_task(delay, task) - else: - log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + task_key = self._task_key(fn, args, kwargs) + self._insert_task(delay, task_key, task, unique=True) + + @staticmethod + def _freeze_task_arg(value): + if isinstance(value, Host): + return (Host, id(value)) + if isinstance(value, tuple): + return tuple(_Scheduler._freeze_task_arg(item) for item in value) + if isinstance(value, list): + return ('list', tuple(_Scheduler._freeze_task_arg(item) for item in value)) + if isinstance(value, Mapping): + return ( + 'dict', + tuple((key, _Scheduler._freeze_task_arg(val)) for key, val in value.items()) + ) + return value + + @classmethod + def _task_key(cls, fn, args, kwargs): + return ( + fn, + tuple(cls._freeze_task_arg(arg) for arg in args), + tuple((key, cls._freeze_task_arg(val)) for key, val in kwargs.items()), + ) + + def _insert_task(self, delay, task_key, task, unique=False): + with self._scheduled_tasks_lock: + if self.is_shutdown: + log.debug("Ignoring scheduled task after shutdown: %r", task) + return + + if unique and task_key in self._scheduled_tasks: + log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + return - def _insert_task(self, delay, task): - if not self.is_shutdown: run_at = time.time() + delay - self._scheduled_tasks.add(task) - self._queue.put_nowait((run_at, next(self._count), task)) - else: - log.debug("Ignoring scheduled task after shutdown: %r", task) + self._scheduled_tasks.add(task_key) + self._queue.put_nowait((run_at, next(self._count), task_key, task)) def run(self): while True: @@ -5011,19 +5378,20 @@ def run(self): try: while True: - run_at, i, task = self._queue.get(block=True, timeout=None) + run_at, i, task_key, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: if task: log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): - self._scheduled_tasks.discard(task) + with self._scheduled_tasks_lock: + self._scheduled_tasks.discard(task_key) fn, args, kwargs = task kwargs = dict(kwargs) future = self._executor.submit(fn, *args, **kwargs) future.add_done_callback(self._log_if_failed) else: - self._queue.put_nowait((run_at, i, task)) + self._queue.put_nowait((run_at, i, task_key, task)) break except queue.Empty: pass @@ -5204,7 +5572,7 @@ def _on_timeout(self, _attempts=0): # Capture connection stats before pool.return_connection() can alter state conn_in_flight = self._connection.in_flight - pool = self.session._pools.get(self._current_host) + pool = self.session._get_pool_by_host_identity(self._current_host) if pool and not pool.is_shutdown: # Do not return the stream ID to the pool yet. We cannot reuse it # because the node might still be processing the query and will @@ -5287,7 +5655,7 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message - pool = self.session._pools.get(host) + pool = self.session._get_pool_by_host_identity(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None diff --git a/cassandra/pool.py b/cassandra/pool.py index bdc1d07ea1..158a2f5445 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -16,7 +16,7 @@ Connection pooling and host management. """ from concurrent.futures import Future -from functools import total_ordering +from functools import partial, total_ordering import logging import socket import time @@ -45,6 +45,25 @@ class NoConnectionsAvailable(Exception): pass +def _current_host_endpoint(host): + host_lock = getattr(host, "lock", None) + if host_lock is not None: + try: + with host_lock: + return host.endpoint + except (AttributeError, TypeError): + pass + return host.endpoint + + +def _endpoints_match(cluster, endpoint, expected_endpoint): + matches = cluster._endpoints_match(endpoint, expected_endpoint) + if isinstance(matches, bool): + return matches + # Mocked clusters can return a bare Mock here; fall back to equality. + return endpoint == expected_endpoint + + @total_ordering class Host(object): """ @@ -221,13 +240,15 @@ def signal_connection_failure(self, connection_exc): def is_currently_reconnecting(self): return self._reconnection_handler is not None - def get_and_set_reconnection_handler(self, new_handler): + def get_and_set_reconnection_handler(self, new_handler, expected_handler=None): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ with self.lock: old = self._reconnection_handler + if expected_handler is not None and old is not expected_handler: + return None self._reconnection_handler = new_handler return old @@ -349,16 +370,24 @@ def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *a self.on_up = on_up self.host = host self.connection_factory = connection_factory + self.callback_kwargs.setdefault('expected_handler', self) def try_reconnect(self): - return self.connection_factory() + connection_factory = self.connection_factory + endpoint = _current_host_endpoint(self.host) + if isinstance(connection_factory, partial): + return connection_factory.func( + endpoint, *connection_factory.args[1:], + **(connection_factory.keywords or {})) + return connection_factory() def on_reconnection(self, connection): log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) if self.is_host_addition: self.on_add(self.host) else: - self.on_up(self.host) + self.on_up(self.host, expected_endpoint=getattr(connection, "endpoint", None), + expected_reconnector=self) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): @@ -551,18 +580,38 @@ def return_connection(self, connection, stream_was_orphaned=False): return is_down = False + stale_endpoint_failure = False if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.signal_connection_failure(connection.last_error) + with self.host.lock: + if not _endpoints_match(self._session.cluster, self.host.endpoint, self.endpoint): + log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", + self.host, self.endpoint) + stale_endpoint_failure = True + else: + is_down = self.host.signal_connection_failure(connection.last_error) connection.signaled_error = True + if stale_endpoint_failure: + # Drop only this stale pool; endpoint reuse may already belong to + # a replacement host instance. + future = self._session.remove_pool( + self.host, expected_host=self.host, expected_endpoint=self.endpoint) + if future: + future.add_done_callback(lambda f: self._session.update_created_pools()) + with self._lock: + self._connections.pop(connection.features.shard_id, None) + connection.close() + return + if self.shutdown_on_error and not is_down: is_down = True if is_down: self.shutdown() - self._session.cluster.on_down(self.host, is_host_addition=False) + self._session.cluster.on_down(self.host, is_host_addition=False, + expected_endpoint=self.endpoint) else: connection.close() with self._lock: @@ -594,7 +643,29 @@ def on_orphaned_stream_released(self): with self._stream_available_condition: self._stream_available_condition.notify() + def _remove_stale_pool(self, expected_endpoint): + future = self._session.remove_pool( + self.host, expected_host=self.host, expected_endpoint=expected_endpoint) + if future: + future.add_done_callback(lambda f: self._session.update_created_pools()) + def _replace(self, connection): + expected_endpoint = self.endpoint + current_endpoint = _current_host_endpoint(self.host) + if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + self._remove_stale_pool(expected_endpoint) + with self._lock: + self._is_replacing = False + with self._stream_available_condition: + self._stream_available_condition.notify() + return + + direct_reconnect = not (self.host.sharding_info and + not self._session.cluster.shard_aware_options.disable) + replacement_connection = None + keyspace = None with self._lock: if self.is_shutdown: return @@ -603,22 +674,48 @@ def _replace(self, connection): try: if connection.features.shard_id in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if direct_reconnect: + keyspace = self._keyspace + replacement_connection = self._session.cluster.connection_factory( + expected_endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + else: self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) - else: - connection = self._session.cluster.connection_factory(self.endpoint, - on_orphaned_stream_released=self.on_orphaned_stream_released) - if self._keyspace: - connection.set_keyspace_blocking(self._keyspace) - self._connections[connection.features.shard_id] = connection except Exception: - log.warning("Failed reconnecting %s. Retrying." % (self.endpoint,)) + log.warning("Failed reconnecting %s. Retrying." % (expected_endpoint,)) self._session.submit(self._replace, connection) - else: + return + + if not direct_reconnect: self._is_replacing = False with self._stream_available_condition: self._stream_available_condition.notify() + return + + current_endpoint = _current_host_endpoint(self.host) + if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + replacement_connection.close() + self._remove_stale_pool(expected_endpoint) + with self._lock: + self._is_replacing = False + with self._stream_available_condition: + self._stream_available_condition.notify() + return + + if keyspace: + replacement_connection.set_keyspace_blocking(keyspace) + + with self._lock: + if self.is_shutdown: + replacement_connection.close() + self._is_replacing = False + return + self._connections[replacement_connection.features.shard_id] = replacement_connection + self._is_replacing = False + with self._stream_available_condition: + self._stream_available_condition.notify() def shutdown(self): log.debug("Shutting down connections to %s", self.host) @@ -676,17 +773,17 @@ def disable_advanced_shard_aware(self, secs): log.warning("disabling advanced_shard_aware for %i seconds, could be that this client is behind NAT?", secs) self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until) - def _get_shard_aware_endpoint(self): + def _get_shard_aware_endpoint(self, endpoint=None): + endpoint = self.endpoint if endpoint is None else endpoint if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ self._session.cluster.shard_aware_options.disable_shardaware_port: return None - endpoint = None if self._session.cluster.ssl_options and self.host.sharding_info.shard_aware_port_ssl: - endpoint = copy.copy(self.endpoint) + endpoint = copy.copy(endpoint) endpoint._port = self.host.sharding_info.shard_aware_port_ssl elif self.host.sharding_info.shard_aware_port: - endpoint = copy.copy(self.endpoint) + endpoint = copy.copy(endpoint) endpoint._port = self.host.sharding_info.shard_aware_port return endpoint @@ -709,140 +806,163 @@ def _open_connection_to_missing_shard(self, shard_id): the smaller the chance that further connections will be assigned to that shard. """ - with self._lock: - if self.is_shutdown: + expected_endpoint = self.endpoint + try: + with self._lock: + if self.is_shutdown: + return + + current_endpoint = _current_host_endpoint(self.host) + if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() return - shard_aware_endpoint = self._get_shard_aware_endpoint() - log.debug("shard_aware_endpoint=%r", shard_aware_endpoint) - if shard_aware_endpoint: - try: - conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, - shard_id=shard_id, - total_shards=self.host.sharding_info.shards_count) - conn.original_endpoint = self.endpoint - except Exception as exc: - log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) - raise - else: - conn = self._session.cluster.connection_factory(self.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) - log.debug( - "Received a connection %s for shard_id=%i on host %s", - id(conn), - conn.features.shard_id if conn.features.shard_id is not None else -1, - self.host) - if self.is_shutdown: - log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn)) - conn.close() - return + shard_aware_endpoint = self._get_shard_aware_endpoint(expected_endpoint) + log.debug("shard_aware_endpoint=%r", shard_aware_endpoint) + if shard_aware_endpoint: + try: + conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, + shard_id=shard_id, + total_shards=self.host.sharding_info.shards_count) + conn.original_endpoint = expected_endpoint + except Exception as exc: + log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) + raise + else: + conn = self._session.cluster.connection_factory(expected_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) - if shard_aware_endpoint and shard_id != conn.features.shard_id: - # connection didn't land on expected shared - # assuming behind a NAT, disabling advanced shard aware for a while - self.disable_advanced_shard_aware(10 * 60) + current_endpoint = _current_host_endpoint(self.host) + if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + conn.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return - old_conn = self._connections.get(conn.features.shard_id) - if old_conn is None or old_conn.orphaned_threshold_reached: log.debug( - "New connection (%s) created to shard_id=%i on host %s", + "Received a connection %s for shard_id=%i on host %s", id(conn), - conn.features.shard_id, - self.host - ) - old_conn = None - with self._lock: - is_shutdown = self.is_shutdown - if not is_shutdown: - if conn.features.shard_id in self._connections: - # Move the current connection to the trash and use the new one from now on - old_conn = self._connections[conn.features.shard_id] + conn.features.shard_id if conn.features.shard_id is not None else -1, + self.host) + if self.is_shutdown: + log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn)) + conn.close() + return + + if shard_aware_endpoint and shard_id != conn.features.shard_id: + # connection didn't land on expected shared + # assuming behind a NAT, disabling advanced shard aware for a while + self.disable_advanced_shard_aware(10 * 60) + + old_conn = self._connections.get(conn.features.shard_id) + if old_conn is None or old_conn.orphaned_threshold_reached: + log.debug( + "New connection (%s) created to shard_id=%i on host %s", + id(conn), + conn.features.shard_id, + self.host + ) + old_conn = None + with self._lock: + is_shutdown = self.is_shutdown + if not is_shutdown: + if conn.features.shard_id in self._connections: + # Move the current connection to the trash and use the new one from now on + old_conn = self._connections[conn.features.shard_id] + log.debug( + "Replacing overloaded connection (%s) with (%s) for shard %i for host %s", + id(old_conn), + id(conn), + conn.features.shard_id, + self.host + ) + if self._keyspace: + conn.set_keyspace_blocking(self._keyspace) + self._connections[conn.features.shard_id] = conn + + if is_shutdown: + conn.close() + return + + if old_conn is not None: + remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) + if remaining == 0: log.debug( - "Replacing overloaded connection (%s) with (%s) for shard %i for host %s", + "Immediately closing the old connection (%s) for shard %i on host %s", id(old_conn), - id(conn), - conn.features.shard_id, + old_conn.features.shard_id, self.host ) - if self._keyspace: - conn.set_keyspace_blocking(self._keyspace) - self._connections[conn.features.shard_id] = conn - - if is_shutdown: - conn.close() - return - - if old_conn is not None: - remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) - if remaining == 0: - log.debug( - "Immediately closing the old connection (%s) for shard %i on host %s", - id(old_conn), - old_conn.features.shard_id, - self.host - ) - old_conn.close() - else: + old_conn.close() + else: + log.debug( + "Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining", + id(old_conn), + old_conn.features.shard_id, + self.host, + remaining, + ) + with self._lock: + is_shutdown = self.is_shutdown + if not is_shutdown: + self._trash.add(old_conn) + if is_shutdown: + conn.close() + num_missing_or_needing_replacement = self.num_missing_or_needing_replacement + log.debug( + "Connected to %s/%i shards on host %s (%i missing or needs replacement)", + len(self._connections), + self.host.sharding_info.shards_count, + self.host, + num_missing_or_needing_replacement + ) + if num_missing_or_needing_replacement == 0: log.debug( - "Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining", - id(old_conn), - old_conn.features.shard_id, + "All shards of host %s have at least one connection, closing %i excess connections", self.host, - remaining, + len(self._excess_connections) ) - with self._lock: - is_shutdown = self.is_shutdown - if not is_shutdown: - self._trash.add(old_conn) - if is_shutdown: - conn.close() - num_missing_or_needing_replacement = self.num_missing_or_needing_replacement - log.debug( - "Connected to %s/%i shards on host %s (%i missing or needs replacement)", - len(self._connections), - self.host.sharding_info.shards_count, - self.host, - num_missing_or_needing_replacement - ) - if num_missing_or_needing_replacement == 0: + self._close_excess_connections() + elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0: log.debug( - "All shards of host %s have at least one connection, closing %i excess connections", - self.host, - len(self._excess_connections) + "All shards are already covered, closing newly opened excess connection %s for host %s", + id(self), + self.host ) - self._close_excess_connections() - elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0: - log.debug( - "All shards are already covered, closing newly opened excess connection %s for host %s", - id(self), - self.host - ) - conn.close() - else: - if len(self._excess_connections) >= self._excess_connection_limit: + conn.close() + else: + if len(self._excess_connections) >= self._excess_connection_limit: + log.debug( + "After connection %s is created excess connection pool size limit (%i) reached for host %s, closing all %i of them", + id(conn), + self._excess_connection_limit, + self.host, + len(self._excess_connections) + ) + self._close_excess_connections() + log.debug( - "After connection %s is created excess connection pool size limit (%i) reached for host %s, closing all %i of them", + "Putting a connection %s to shard %i to the excess pool of host %s", id(conn), - self._excess_connection_limit, - self.host, - len(self._excess_connections) + conn.features.shard_id, + self.host ) - self._close_excess_connections() - - log.debug( - "Putting a connection %s to shard %i to the excess pool of host %s", - id(conn), - conn.features.shard_id, - self.host - ) - close_connection = False - with self._lock: - if self.is_shutdown: - close_connection = True - else: - self._excess_connections.add(conn) - if close_connection: - conn.close() - self._connecting.discard(shard_id) + close_connection = False + with self._lock: + if self.is_shutdown: + close_connection = True + else: + self._excess_connections.add(conn) + if close_connection: + conn.close() + finally: + self._connecting.discard(shard_id) def _open_connections_for_all_shards(self, skip_shard_id=None): """ @@ -867,7 +987,7 @@ def _open_connections_for_all_shards(self, skip_shard_id=None): self._trash = set() if trash_conns is not None: - for conn in self._trash: + for conn in trash_conns: conn.close() def _set_keyspace_for_all_conns(self, keyspace, callback): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 4965f8980e..f3d379eec6 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -25,8 +25,8 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.connection import ConnectionException, DefaultEndPoint -from cassandra.pool import Host +from cassandra.connection import ClientRoutesEndPoint, ConnectionException, DefaultEndPoint, SniEndPoint +from cassandra.pool import Host, _HostReconnectionHandler from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools @@ -263,6 +263,27 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_reconnector_connection_factory_recomputes_authenticator_after_endpoint_swap(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + auth_provider = Mock() + auth_provider.new_authenticator.side_effect = lambda address: 'auth-%s' % (address,) + + with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: + cluster = Cluster(auth_provider=auth_provider) + host = Host(old_endpoint, SimpleConvictionPolicy) + connection_factory = cluster._make_connection_factory(host) + handler = _HostReconnectionHandler( + host, connection_factory, False, Mock(), Mock(), Mock(), iter([0]), + host.get_and_set_reconnection_handler, new_handler=None) + + host.endpoint = new_endpoint + conn = handler.try_reconnect() + + assert conn == 'connection' + assert factory.call_args.args[0] == new_endpoint + assert factory.call_args.kwargs['authenticator'] == 'auth-127.0.0.2' + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -470,6 +491,32 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_update_created_pools_replaces_pool_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + session._pools[host] = stale_pool + created_pools = [] + + host.endpoint = DefaultEndPoint("127.0.0.2") + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + futures = session.update_created_pools() + + assert len(futures) == 1 + executor.run_next() + + assert created_pools[0].endpoint == host.endpoint + assert session._pools[host] is created_pools[0] + stale_pool.shutdown.assert_called_once_with() + def test_removed_in_flight_pool_is_not_published(self): host = self._make_host("127.0.0.1") cluster, session, executor = self._make_cluster_and_session([host]) @@ -770,6 +817,25 @@ def set_keyspace(keyspace, callback): class HostStateRaceTest(unittest.TestCase): + class _EndpointSwapOnFirstExitLock(object): + + def __init__(self, host, new_endpoint): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._exits = 0 + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self._host.endpoint = self._new_endpoint + + @staticmethod def _make_host(): return Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) @@ -868,6 +934,56 @@ def test_reserved_down_handling_after_endpoint_swap_only_removes_stale_pool(self cluster._start_reconnector.assert_not_called() assert self._state(cluster, host).down_epoch is None + def test_reserved_down_handling_after_endpoint_swap_removes_stale_pool(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host.endpoint = old_endpoint + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + host.lock = self._EndpointSwapOnFirstExitLock(host, new_endpoint) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_endpoint_match_preserves_endpoint_specific_identity(self): + proxy_endpoint = SniEndPoint("proxy.example.com", "node-a", port=9042) + other_proxy_endpoint = SniEndPoint("proxy.example.com", "node-b", port=9042) + assert not Cluster._endpoints_match(proxy_endpoint, other_proxy_endpoint) + + old_endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + def test_auth_failure_quarantine_preserves_endpoint_specific_identity(self): + host = self._make_host() + host.endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + + Cluster._set_non_retryable_auth_failure(host, True) + host.endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + + assert not Cluster._has_non_retryable_auth_failure(host) + def test_noop_down_during_up_handling_does_not_supersede_up(self): pool_future = Future() session = Mock() @@ -1243,6 +1359,24 @@ def test_old_up_callback_does_not_clear_replayed_up_handling(self): assert host.is_up assert state.up_epoch is None + def test_stale_reconnector_success_does_not_clear_newer_reconnector(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + cluster = self._make_cluster() + host = Host(old_endpoint, SimpleConvictionPolicy) + host.endpoint = new_endpoint + new_reconnector = Mock() + host._reconnection_handler = new_reconnector + connection = Mock(endpoint=old_endpoint) + handler = _HostReconnectionHandler( + host, Mock(return_value=connection), False, Mock(), cluster.on_up, Mock(), + iter([0]), host.get_and_set_reconnection_handler, new_handler=None) + + handler.run() + + assert host._reconnection_handler is new_reconnector + assert not host.is_up + def test_superseded_up_cleanup_preserves_replacement_host_pool(self): stale_host = self._make_host() replacement_host = self._make_host() @@ -1462,6 +1596,32 @@ def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): assert state.up_epoch is None assert state.pending_up_epoch is None + def test_later_down_before_worker_runs_does_not_skip_pool_cleanup(self): + executor = _QueuedExecutor() + host = self._make_host() + host.set_up() + pool = Mock() + pool.host = host + pool.endpoint = host.endpoint + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster.executor = executor + cluster.metadata = Mock() + cluster.metadata.all_hosts.return_value = [host] + session.cluster = cluster + session._profile_manager = cluster.profile_manager + + Cluster.on_down(cluster, host, is_host_addition=False) + Cluster.on_up(cluster, host) + Cluster.on_down(cluster, host, is_host_addition=False) + + future = executor.run_next() + + assert future.exception() is None + assert session._pools == {} + pool.shutdown.assert_called_once_with() + def test_up_signal_waits_until_submitted_down_handling_finishes(self): executor = _QueuedExecutor() events = [] diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..9b06eaadb0 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -19,9 +19,11 @@ from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import (Cluster, ControlConnection, _Scheduler, ProfileManager, + EXEC_PROFILE_DEFAULT, ExecutionProfile) from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import (EndPoint, DefaultEndPoint, + DefaultEndPointFactory, SniEndPoint) from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -111,6 +113,7 @@ def __init__(self): self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(RoundRobinPolicy()) self.endpoint_factory = DefaultEndPointFactory().configure(self) self.ssl_options = None + self.down_expected_endpoint = None def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None): host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack, host_id=host_id) @@ -121,11 +124,13 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, def remove_host(self, host): pass - def on_up(self, host): + def on_up(self, host, expected_endpoint=None): pass - def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + def on_down(self, host, is_host_addition, expect_host_to_be_down=False, + expected_endpoint=None): self.down_host = host + self.down_expected_endpoint = expected_endpoint def _node_meta_results(local_results, peer_results): @@ -495,7 +500,8 @@ def test_handle_status_change(self): self.cluster.scheduler.reset_mock() self.control_connection._handle_status_change(event) host = self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.0')) - self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host) + self.cluster.scheduler.schedule_unique.assert_called_once_with( + ANY, self.cluster.on_up, host, expected_endpoint=host.endpoint) self.cluster.scheduler.schedule.reset_mock() event = { @@ -513,6 +519,33 @@ def test_handle_status_change(self): self.control_connection._handle_status_change(event) host = self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.0')) assert host is self.cluster.down_host + assert host.endpoint == self.cluster.down_expected_endpoint + + def test_handle_status_change_preserves_endpoint_identity(self): + host = self.cluster.metadata.hosts['uuid1'] + old_endpoint = SniEndPoint("192.168.1.0", "node-a", 9042) + new_endpoint = SniEndPoint("192.168.1.0", "node-b", 9042) + host.endpoint = old_endpoint + + event = { + 'change_type': 'UP', + 'address': ('192.168.1.0', 9042) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_status_change(event) + + # A raw event address would accept the replacement endpoint here. + assert Cluster._endpoints_match(new_endpoint, event['address']) + assert not Cluster._endpoints_match(new_endpoint, old_endpoint) + self.cluster.scheduler.schedule_unique.assert_called_once_with( + ANY, self.cluster.on_up, host, expected_endpoint=old_endpoint) + + self.cluster.on_down = Mock() + event['change_type'] = 'DOWN' + self.control_connection._handle_status_change(event) + + self.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=old_endpoint) def test_handle_schema_change(self): diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index f92bb53785..10d0ce50f4 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -23,8 +23,8 @@ from threading import Thread, Event, Lock from unittest.mock import Mock, NonCallableMagicMock, MagicMock -from cassandra.cluster import Session, ShardAwareOptions -from cassandra.connection import Connection +from cassandra.cluster import Cluster, Session, ShardAwareOptions +from cassandra.connection import ClientRoutesEndPoint, Connection, DefaultEndPoint from cassandra.pool import HostConnection from cassandra.pool import Host, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy @@ -133,6 +133,7 @@ def test_spawn_when_at_max(self): def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False) @@ -153,6 +154,7 @@ def test_return_defunct_connection(self): def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False, @@ -174,6 +176,8 @@ def test_return_defunct_connection_on_down_host(self): if self.PoolImpl is HostConnection: # on shard aware implementation we use submit function regardless assert host.signal_connection_failure.call_args + session.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=pool.endpoint) assert session.submit.called else: assert not session.submit.called @@ -182,6 +186,7 @@ def test_return_defunct_connection_on_down_host(self): def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100, signaled_error=False, orphaned_threshold_reached=False) @@ -200,6 +205,64 @@ def test_return_closed_connection(self): assert session.submit.call_args assert not pool.is_shutdown + def test_return_defunct_connection_after_endpoint_swap_is_ignored(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Mock(spec=Host, address='ip1') + host.endpoint = old_endpoint + host.lock = Lock() + session = self.make_session() + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + host.endpoint = new_endpoint + conn.is_defunct = True + host.signal_connection_failure.return_value = True + pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + conn.close.assert_called_once_with() + assert not pool.is_shutdown + + def test_return_defunct_connection_after_client_route_endpoint_port_swap_is_ignored(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), '127.0.0.1', original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), '127.0.0.1', original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host = Mock(spec=Host, address='ip1') + host.endpoint = old_endpoint + host.lock = Lock() + session = self.make_session() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + host.endpoint = new_endpoint + conn.is_defunct = True + host.signal_connection_failure.return_value = True + pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + conn.close.assert_called_once_with() + assert not pool.is_shutdown + def test_host_instantiations(self): """ Ensure Host fails if not initialized properly From 551e587b3d692a7fb3e97a30278fc2bf131b0ccb Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 13:16:09 -0400 Subject: [PATCH 05/29] 1 --- cassandra/cluster.py | 38 +++++----- tests/unit/test_cluster.py | 138 ++++++++++++++++++++++++++----------- 2 files changed, 117 insertions(+), 59 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index fafc741e39..fd51d03cbd 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2590,6 +2590,26 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, if self.is_shutdown: return + if (self._discount_down_events and expected_endpoint is None and + self.profile_manager.distance(host) != HostDistance.IGNORED): + with host.lock: + host_endpoint = host.endpoint + connected = False + for session in tuple(self.sessions): + # Host equality is endpoint-based; scan by identity to avoid + # hiding the live pool behind a stale equal key. Do not hold + # host.lock while taking session._lock; update_created_pools() + # takes the locks in the opposite order. + pool = session._get_pool_by_host_identity( + host, expected_endpoint=host_endpoint) + if pool is not None and pool.open_count > 0: + connected = True + break + if connected: + with host.lock: + if self._endpoints_match(host.endpoint, host_endpoint): + return + with host.lock: if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Ignoring stale down signal for host %s; endpoint changed from %s", @@ -2599,24 +2619,6 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, was_up = host.is_up state = self._get_host_liveness_state(host) - # ignore down signals if we have open pools to the host - # this is to avoid closing pools when a control connection host became isolated - # endpoint-aware cleanup still needs to run - if (self._discount_down_events and expected_endpoint is None and - self.profile_manager.distance(host) != HostDistance.IGNORED): - host_endpoint = host.endpoint - connected = False - for session in tuple(self.sessions): - # Host equality is endpoint-based; scan by identity to avoid - # hiding the live pool behind a stale equal key. - pool = session._get_pool_by_host_identity( - host, expected_endpoint=host_endpoint) - if pool is not None and pool.open_count > 0: - connected = True - break - if connected: - return - if not expect_host_to_be_down: if was_up is False: if state.pending_up_epoch is not None: diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index f3d379eec6..d3a7b70867 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -17,7 +17,7 @@ import socket from concurrent.futures import Future -from threading import Lock, RLock +from threading import Event, Lock, RLock, Thread from unittest.mock import patch, Mock, ANY import uuid @@ -271,7 +271,7 @@ def test_reconnector_connection_factory_recomputes_authenticator_after_endpoint_ with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: cluster = Cluster(auth_provider=auth_provider) - host = Host(old_endpoint, SimpleConvictionPolicy) + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) connection_factory = cluster._make_connection_factory(host) handler = _HostReconnectionHandler( host, connection_factory, False, Mock(), Mock(), Mock(), iter([0]), @@ -670,11 +670,11 @@ def make_pool(host, distance, pool_session, endpoint=None): executor.run_next() - assert future.result() is True + assert future.result() is False host_connection.assert_called_once_with( host, HostDistance.LOCAL, session, endpoint=old_endpoint) assert created_pools[0].endpoint == old_endpoint - created_pools[0].shutdown.assert_not_called() + created_pools[0].shutdown.assert_called_once_with() def test_add_or_renew_pool_auth_failure_reports_creation_endpoint(self): host = self._make_host("127.0.0.1") @@ -862,6 +862,62 @@ def _make_session_with_pool(host, pool): session.submit = _ImmediateExecutor().submit return session + def test_discount_down_event_does_not_hold_host_lock_while_scanning_pools(self): + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = old_endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + # Mutating the endpoint after insertion forces an identity scan under + # session._lock; the down path must not hold host.lock during that scan. + host.endpoint = DefaultEndPoint("127.0.0.2") + entered_distance = Event() + release_distance = Event() + blocked_on_host = [] + thread_errors = [] + + def distance(_host): + entered_distance.set() + release_distance.wait(1) + return HostDistance.LOCAL + + def hold_session_then_try_host(): + with session._lock: + entered_distance.wait(1) + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + release_distance.set() + + def run_on_down(): + try: + Cluster.on_down(cluster, host, is_host_addition=False) + except Exception as exc: + thread_errors.append(exc) + + cluster.profile_manager.distance.side_effect = distance + worker = Thread(target=hold_session_then_try_host) + runner = Thread(target=run_on_down) + worker.start() + runner.start() + worker.join(2) + runner.join(2) + + assert not thread_errors + assert not runner.is_alive() + assert blocked_on_host == [False] + @staticmethod def _state(cluster, host): return cluster._get_host_liveness_state(host) @@ -1035,7 +1091,8 @@ def test_newer_forced_down_during_up_handling_is_preserved(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) assert state.epoch > first_up_epoch @@ -1045,7 +1102,7 @@ def test_newer_forced_down_during_up_handling_is_preserved(self): pool_future.set_result(True) listener.on_up.assert_not_called() - assert session.remove_pool.call_count == 2 + assert session.remove_pool.call_count == 1 assert not host.is_up assert state.up_epoch is None @@ -1068,7 +1125,8 @@ def test_stale_failed_up_callback_does_not_cleanup_newer_down(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) listener.on_up.assert_not_called() @@ -1100,11 +1158,12 @@ def force_down_before_cleanup(message, *args, **kwargs): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with( host, False, expected_down_epoch=ANY) - assert session.remove_pool.call_count == 2 + assert session.remove_pool.call_count == 1 listener.on_up.assert_not_called() assert not host.is_up assert state.up_epoch is None @@ -1146,10 +1205,10 @@ def test_forced_down_during_up_handling_is_not_hidden_by_reconnector(self): host._reconnection_handler = old_reconnector original_get_reconnector = Cluster._get_reconnector_for_current_up_handling - def force_down_before_reconnector_is_cleared(h, up_epoch): + def force_down_before_reconnector_is_cleared(h, up_epoch, **kwargs): Cluster.on_down( cluster, h, is_host_addition=False, expect_host_to_be_down=True) - return original_get_reconnector(cluster, h, up_epoch) + return original_get_reconnector(cluster, h, up_epoch, **kwargs) cluster._get_reconnector_for_current_up_handling = Mock( side_effect=force_down_before_reconnector_is_cleared) @@ -1158,12 +1217,13 @@ def force_down_before_reconnector_is_cleared(h, up_epoch): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) cluster.profile_manager.on_up.assert_not_called() cluster.control_connection.on_up.assert_not_called() - old_reconnector.cancel.assert_not_called() + old_reconnector.cancel.assert_called_once_with() assert not host.is_up assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).down_epoch is None @@ -1181,7 +1241,8 @@ def test_forced_down_while_reconnecting_runs_new_down_handling(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) assert self._state(cluster, host).down_epoch is None @@ -1252,11 +1313,10 @@ def test_up_during_down_superseding_in_flight_up_is_replayed(self): assert host.is_up assert state.up_epoch is None - def test_superseded_up_cleanup_precedes_replayed_up_pool_creation(self): + def test_stale_superseded_up_cleanup_does_not_run_after_newer_down(self): first_pool_future = Future() - second_pool_future = Future() session = Mock() - session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + session.add_or_renew_pool.return_value = first_pool_future cluster = self._make_cluster(session=session) cluster._prepare_all_queries = Mock() cluster.profile_manager.distance.return_value = HostDistance.LOCAL @@ -1269,28 +1329,19 @@ def test_superseded_up_cleanup_precedes_replayed_up_pool_creation(self): cleanup_calls = [] - def signal_up_during_first_cleanup(h, **kwargs): - if cleanup_calls: - return None + def signal_up_during_stale_cleanup(h, **kwargs): cleanup_calls.append(h) - Cluster.on_up(cluster, h) - assert session.add_or_renew_pool.call_count == 1 - assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch return None - session.remove_pool.side_effect = signal_up_during_first_cleanup + session.remove_pool.side_effect = signal_up_during_stale_cleanup first_pool_future.set_result(True) - assert cleanup_calls == [host] - assert session.add_or_renew_pool.call_count == 2 - assert self._state(cluster, host).up_epoch == self._state(cluster, host).epoch - assert self._state(cluster, host).pending_up_epoch is None - - second_pool_future.set_result(True) - - assert host.is_up + assert cleanup_calls == [] + assert session.add_or_renew_pool.call_count == 1 assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + assert not host.is_up def test_sync_up_failure_replays_queued_up(self): session = Mock() @@ -1321,7 +1372,8 @@ def queue_up_then_fail(h): assert cluster.profile_manager.on_up.call_count == 2 cluster.control_connection.on_up.assert_called_once_with(host) session.add_or_renew_pool.assert_called_once_with( - host, is_host_addition=False) + host, is_host_addition=False, + allow_retry_after_auth_failure=True) assert host.is_up assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).pending_up_epoch is None @@ -1363,7 +1415,7 @@ def test_stale_reconnector_success_does_not_clear_newer_reconnector(self): old_endpoint = DefaultEndPoint('127.0.0.1') new_endpoint = DefaultEndPoint('127.0.0.2') cluster = self._make_cluster() - host = Host(old_endpoint, SimpleConvictionPolicy) + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) host.endpoint = new_endpoint new_reconnector = Mock() host._reconnection_handler = new_reconnector @@ -1424,7 +1476,8 @@ def test_down_during_up_listener_is_handled(self): listener.on_up.assert_called_once_with(host) cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) assert not host.is_up @@ -1444,7 +1497,8 @@ def test_current_down_handling_still_removes_pools_and_reconnects(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) assert self._state(cluster, host).down_epoch is None @@ -1546,7 +1600,8 @@ def test_on_up_queues_after_down_is_submitted_before_worker_runs(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) cluster.profile_manager.on_up.assert_called_once_with(host) @@ -1586,7 +1641,8 @@ def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - session.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) cluster.profile_manager.on_up.assert_called_once_with(host) @@ -1635,13 +1691,13 @@ def test_up_signal_waits_until_submitted_down_handling_finishes(self): cluster.profile_manager.on_down.side_effect = lambda h: events.append("profile_down") cluster.control_connection.on_down.side_effect = lambda h: events.append("control_down") - session.on_down.side_effect = lambda h: events.append("session_down") + session.on_down.side_effect = lambda h, **kwargs: events.append("session_down") listener.on_down.side_effect = lambda h: events.append("listener_down") cluster._start_reconnector.side_effect = lambda h, is_host_addition, **kwargs: events.append("reconnector") session.remove_pool.side_effect = lambda h, **kwargs: events.append("remove_pool") cluster.profile_manager.on_up.side_effect = lambda h: events.append("profile_up") cluster.control_connection.on_up.side_effect = lambda h: events.append("control_up") - session.add_or_renew_pool.side_effect = lambda h, is_host_addition: events.append("add_pool") + session.add_or_renew_pool.side_effect = lambda h, is_host_addition, **kwargs: events.append("add_pool") Cluster.on_down(cluster, host, is_host_addition=False) Cluster.on_up(cluster, host) @@ -1795,7 +1851,7 @@ def test_down_after_pending_up_pop_invalidates_replay(self): with host.lock: pending_up_epoch = cluster._pop_pending_node_up_if_ready(host) - assert pending_up_epoch == 2 + assert pending_up_epoch == (2, None) assert state.pending_up_epoch == 2 Cluster.on_down(cluster, host, is_host_addition=False) From 307176c76961c4f2bac5d361728d3bc99b78d10e Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 13:56:19 -0400 Subject: [PATCH 06/29] cluster: fence client-route endpoint identity --- cassandra/cluster.py | 49 +++++++++--------- tests/unit/test_cluster.py | 20 ++++++++ tests/unit/test_control_connection.py | 73 ++++++++++++++++++++++++++- 3 files changed, 118 insertions(+), 24 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index fd51d03cbd..744a410769 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4890,32 +4890,32 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids.add(host_id) found_endpoints.add(endpoint) - host = self._cluster.metadata.get_host(endpoint) + host_by_endpoint = self._cluster.metadata.get_host(endpoint) + host_by_id = self._cluster.metadata.get_host_by_host_id(host_id) + host = host_by_id or host_by_endpoint datacenter = row.get("data_center") rack = row.get("rack") - if host is None: - host = self._cluster.metadata.get_host_by_host_id(host_id) - if host and host.endpoint != endpoint: - log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() - with host.lock: - old_endpoint = host.endpoint - self._cluster.on_down( - host, is_host_addition=False, expect_host_to_be_down=True, - expected_endpoint=old_endpoint) + if host is not None and not self._cluster._endpoints_match(host.endpoint, endpoint): + log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + with host.lock: + old_endpoint = host.endpoint + self._cluster.on_down( + host, is_host_addition=False, expect_host_to_be_down=True, + expected_endpoint=old_endpoint) - with host.lock: - if host.endpoint != old_endpoint: - log.debug("[control connection] Not updating host ip from %s to %s for (%s); " - "endpoint changed to %s", - old_endpoint, endpoint, host_id, host.endpoint) - continue - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) - self._cluster.on_up(host) + with host.lock: + if not self._cluster._endpoints_match(host.endpoint, old_endpoint): + log.debug("[control connection] Not updating host ip from %s to %s for (%s); " + "endpoint changed to %s", + old_endpoint, endpoint, host_id, host.endpoint) + continue + host.endpoint = endpoint + self._cluster.metadata.update_host(host, old_endpoint) + self._cluster.on_up(host) if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) @@ -5247,7 +5247,8 @@ def _signal_error(self): # that errors have already been reported, so we're fine if host: self._cluster.signal_connection_failure( - host, self._connection.last_error, is_host_addition=False) + host, self._connection.last_error, is_host_addition=False, + expected_endpoint=self._connection.endpoint) return # if the connection is not defunct or the host already left, reconnect @@ -5340,6 +5341,8 @@ def schedule_unique(self, delay, fn, *args, **kwargs): def _freeze_task_arg(value): if isinstance(value, Host): return (Host, id(value)) + if isinstance(value, EndPoint): + return (EndPoint, Cluster._endpoint_key(value)) if isinstance(value, tuple): return tuple(_Scheduler._freeze_task_arg(item) for item in value) if isinstance(value, list): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index d3a7b70867..3d9ddfc6b2 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -300,6 +300,26 @@ def test_event_delay_timing(self, *_): sched.schedule(0, lambda: None) sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" + @patch('cassandra.cluster._Scheduler.run') # don't actually run the thread + def test_schedule_unique_keeps_client_route_events_for_distinct_ports(self, *_): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host = Host(new_endpoint, SimpleConvictionPolicy, host_id=host_id) + scheduled_fn = Mock() + + sched = _Scheduler(Mock()) + sched.schedule_unique( + 30, scheduled_fn, host, expected_endpoint=old_endpoint) + sched.schedule_unique( + 30, scheduled_fn, host, expected_endpoint=new_endpoint) + + assert len(sched._scheduled_tasks) == 2 + class SessionPoolRaceTest(unittest.TestCase): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 9b06eaadb0..0f993c8d7e 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import uuid from concurrent.futures import ThreadPoolExecutor from unittest.mock import Mock, ANY, call @@ -21,8 +22,10 @@ from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import (Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile) +from cassandra.metadata import Metadata from cassandra.pool import Host -from cassandra.connection import (EndPoint, DefaultEndPoint, +from cassandra.connection import (ClientRoutesEndPoint, ConnectionException, + EndPoint, DefaultEndPoint, DefaultEndPointFactory, SniEndPoint) from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -124,6 +127,9 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, def remove_host(self, host): pass + def _endpoints_match(self, endpoint, expected_endpoint): + return Cluster._endpoints_match(endpoint, expected_endpoint) + def on_up(self, host, expected_endpoint=None): pass @@ -383,6 +389,71 @@ def test_change_ip(self): assert 3 == len(self.cluster.metadata.all_hosts()) + def test_change_client_route_endpoint_when_only_port_changes(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + self.cluster.endpoint_factory = Mock() + self.cluster.endpoint_factory.create.return_value = new_endpoint + self.cluster.on_down = Mock() + self.cluster.on_up = Mock() + self.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "rpc_port", "peer", "data_center", "rack", "host_id"], + [["127.0.0.1", 9142, "127.0.0.1", "dc1", "rack1", host_id]])) + + self.control_connection._refresh_node_list_and_token_map( + self.connection, preloaded_results=preloaded_results) + + assert Cluster._endpoints_match(host.endpoint, new_endpoint) + self.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expect_host_to_be_down=True, + expected_endpoint=old_endpoint) + self.cluster.on_up.assert_called_once_with(host) + + def test_stale_control_connection_failure_is_endpoint_fenced(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=host_id) + host.set_up() + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + host.endpoint = new_endpoint + self.cluster.metadata.update_host(host, old_endpoint) + assert self.cluster.metadata.get_host(old_endpoint) is host + + self.connection.endpoint = old_endpoint + self.connection.is_defunct = True + self.connection.last_error = ConnectionException( + "stale control connection failed", endpoint=old_endpoint) + self.cluster.signal_connection_failure = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, self.connection.last_error, is_host_addition=False, + expected_endpoint=old_endpoint) def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): """ From eca10497b356d795ee3bb7a3c4c9ec6c2c2c0ac8 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 14:39:50 -0400 Subject: [PATCH 07/29] cluster: fix stale up and replace races --- cassandra/cluster.py | 3 ++- cassandra/pool.py | 8 ++++++- tests/unit/test_cluster.py | 28 ++++++++++++++++++++++ tests/unit/test_host_connection_pool.py | 32 +++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 744a410769..a1d098a035 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2411,8 +2411,9 @@ def _clear_stale_reconnector(): host, is_host_addition=False, allow_retry_after_auth_failure=True) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + for future in tuple(futures): + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: diff --git a/cassandra/pool.py b/cassandra/pool.py index 158a2f5445..676f54d454 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -705,7 +705,13 @@ def _replace(self, connection): return if keyspace: - replacement_connection.set_keyspace_blocking(keyspace) + try: + replacement_connection.set_keyspace_blocking(keyspace) + except Exception: + log.warning("Failed reconnecting %s. Retrying." % (expected_endpoint,)) + replacement_connection.close() + self._session.submit(self._replace, connection) + return with self._lock: if self.is_shutdown: diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 3d9ddfc6b2..a30b4d84b0 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1090,6 +1090,34 @@ def test_noop_down_during_up_handling_does_not_supersede_up(self): assert host.is_up assert state.up_epoch is None + def test_on_up_waits_for_all_pool_futures_when_one_is_already_done(self): + completed_pool_future = Future() + completed_pool_future.set_result(True) + pending_pool_future = Future() + completed_session = Mock() + completed_session.add_or_renew_pool.return_value = completed_pool_future + pending_session = Mock() + pending_session.add_or_renew_pool.return_value = pending_pool_future + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.sessions = [completed_session, pending_session] + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + + assert not host.is_up + listener.on_up.assert_not_called() + + pending_pool_future.set_result(False) + + assert not host.is_up + listener.on_up.assert_not_called() + cluster._start_reconnector.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=host.endpoint) + def test_newer_forced_down_during_up_handling_is_preserved(self): pool_future = Future() session = Mock() diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 10d0ce50f4..3dbde237a5 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -350,3 +350,35 @@ def mock_connection_factory(self, *args, **kwargs): # Cleanup executor with proper wait session.cluster.executor.shutdown(wait=True) + + def test_replace_retries_when_replacement_keyspace_set_fails(self): + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + initial_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection.set_keyspace_blocking.side_effect = RuntimeError( + "keyspace failed") + session.cluster.connection_factory.side_effect = [ + initial_connection, replacement_connection] + + pool = HostConnection(host, HostDistance.LOCAL, session) + pool._is_replacing = True + + pool._replace(initial_connection) + + assert session.submit.call_count == 1 + submitted_fn, submitted_connection = session.submit.call_args.args + assert submitted_fn == pool._replace + assert submitted_connection is initial_connection From d774e8887ad6d9e49db25b96efee45f3a01763bb Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 15:30:47 -0400 Subject: [PATCH 08/29] response-future: return timeouts to original pool --- cassandra/cluster.py | 4 +- tests/unit/test_response_future.py | 83 +++++++++++++++++++----------- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a1d098a035..96a23b711a 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5482,6 +5482,7 @@ class ResponseFuture(object): _errbacks = None _current_host = None _connection = None + _connection_pool = None _query_retries = 0 _start_time = None _metrics = None @@ -5578,7 +5579,7 @@ def _on_timeout(self, _attempts=0): # Capture connection stats before pool.return_connection() can alter state conn_in_flight = self._connection.in_flight - pool = self.session._get_pool_by_host_identity(self._current_host) + pool = self._connection_pool if pool and not pool.is_shutdown: # Do not return the stream ID to the pool yet. We cannot reuse it # because the node might still be processing the query and will @@ -5679,6 +5680,7 @@ def _query(self, host, message=None, cb=None): else: connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection + self._connection_pool = pool result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..a5e16461d2 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import uuid from collections import deque from threading import RLock @@ -20,7 +21,7 @@ from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion -from cassandra.connection import Connection, ConnectionException +from cassandra.connection import Connection, ConnectionException, DefaultEndPoint from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, @@ -28,8 +29,8 @@ RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, ProtocolHandler) -from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy -from cassandra.pool import NoConnectionsAvailable +from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy, SimpleConvictionPolicy +from cassandra.pool import Host, NoConnectionsAvailable from cassandra.query import SimpleStatement from tests.util import assertEqual, assertIsInstance import pytest @@ -52,7 +53,7 @@ def make_pool(self): def make_session(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value = self.make_pool() + session._get_pool_by_host_identity.return_value = self.make_pool() return session def make_response_future(self, session): @@ -66,7 +67,7 @@ def make_mock_response(self, col_names, rows): def test_result_message(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value pool.is_shutdown = False connection = Mock(spec=Connection) @@ -75,7 +76,7 @@ def test_result_message(self): rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -87,7 +88,7 @@ def test_result_message(self): def test_unknown_result_class(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -151,7 +152,7 @@ def test_heartbeat_defunct_deadlock(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()] - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool query = SimpleStatement("SELECT * FROM foo") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) @@ -252,7 +253,7 @@ def test_retry_policy_says_ignore(self): def test_retry_policy_says_retry(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) @@ -266,7 +267,7 @@ def test_retry_policy_says_retry(self): rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -285,13 +286,13 @@ def test_retry_policy_says_retry(self): # it should try again with the same host since this was # an UnavailableException - rf.session._pools.get.assert_called_with(host) + rf.session._get_pool_by_host_identity.assert_called_with(host) pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -300,7 +301,7 @@ def test_retry_with_different_host(self): rf.message.consistency_level = ConsistencyLevel.QUORUM rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) assert ConsistencyLevel.QUORUM == rf.message.consistency_level @@ -319,7 +320,7 @@ def test_retry_with_different_host(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -328,13 +329,13 @@ def test_retry_with_different_host(self): def test_all_retries_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -346,7 +347,7 @@ def test_all_retries_fail(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') result = Mock(spec=IsBootstrappingErrorMessage, info={}) rf._set_result(host, None, None, result) @@ -360,7 +361,7 @@ def test_all_retries_fail(self): def test_exponential_retry_policy_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -368,7 +369,7 @@ def test_exponential_retry_policy_fail(self): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1, retry_policy=ExponentialBackoffRetryPolicy(2)) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -384,7 +385,7 @@ def test_exponential_retry_policy_fail(self): def test_all_pools_shutdown(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value.is_shutdown = True + session._get_pool_by_host_identity.return_value.is_shutdown = True rf = ResponseFuture(session, Mock(), Mock(), 1) rf.send_request() @@ -399,7 +400,7 @@ def test_first_pool_shutdown(self): pool_shutdown.is_shutdown = True pool_ok = self.make_pool() pool_ok.is_shutdown = True - session._pools.get.side_effect = [pool_shutdown, pool_ok] + session._get_pool_by_host_identity.side_effect = [pool_shutdown, pool_ok] rf = self.make_response_future(session) rf.send_request() @@ -424,7 +425,7 @@ def test_timeout_getting_connection_from_pool(self): connection = Mock(spec=Connection) second_pool.borrow_connection.return_value = (connection, 1) - session._pools.get.side_effect = [first_pool, second_pool] + session._get_pool_by_host_identity.side_effect = [first_pool, second_pool] rf = self.make_response_future(session) rf.send_request() @@ -459,7 +460,7 @@ def test_callback(self): def test_errback(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -508,7 +509,7 @@ def test_multiple_callbacks(self): def test_multiple_errbacks(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -581,7 +582,7 @@ def test_add_callbacks(self): def test_prepared_query_not_found(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -606,7 +607,7 @@ def test_prepared_query_not_found(self): def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -655,7 +656,7 @@ def test_timeout_does_not_release_stream_id(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')] pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), orphaned_request_ids=set(), orphaned_threshold=256, in_flight=3) pool.borrow_connection.return_value = (connection, 1) @@ -675,6 +676,30 @@ def test_timeout_does_not_release_stream_id(self): assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) + def test_timeout_returns_orphan_to_original_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + old_pool = self.make_pool() + replacement_pool = self.make_pool() + connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), + orphaned_request_ids=set(), orphaned_threshold=256, in_flight=1) + old_pool.borrow_connection.return_value = (connection, 1) + session._get_pool_by_host_identity.side_effect = [old_pool, replacement_pool] + + rf = self.make_response_future(session) + rf.send_request() + connection._requests[1] = (connection._handle_options_response, + ProtocolHandler.decode_message, []) + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf._on_timeout() + + replacement_pool.return_connection.assert_not_called() + old_pool.return_connection.assert_called_once_with( + connection, stream_was_orphaned=True) + def test_single_host_query_plan_exhausted_after_one_retry(self): """ Test that when a specific host is provided, the query plan is properly @@ -686,7 +711,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): """ session = self.make_basic_session() pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool # Create a specific host specific_host = Mock() @@ -702,7 +727,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): rf.send_request() # Verify initial request was sent - rf.session._pools.get.assert_called_once_with(specific_host) + rf.session._get_pool_by_host_identity.assert_called_once_with(specific_host) pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) From c74c7096fecf906e34faa9991586bd715d63e6c5 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 20:20:13 -0400 Subject: [PATCH 09/29] 1 --- cassandra/cluster.py | 13 ++++++++++-- tests/unit/test_cluster.py | 33 ++++++++++++++++++++++++++++-- tests/unit/test_response_future.py | 28 ++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 96a23b711a..de0e5196e4 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4303,6 +4303,7 @@ def update_created_pools(self): with self._lock: with host.lock: host_endpoint = host.endpoint + host_is_up = host.is_up pool = self._get_pool_by_host_identity( host, expected_endpoint=host_endpoint) any_pool = pool or self._get_pool_by_host_identity(host) @@ -4320,8 +4321,10 @@ def update_created_pools(self): # on_up() keeps host.is_up False until this future succeeds. future = self._reuse_or_invalidate_pool_creation( host, pool_creation_future) - elif host.is_up in (True, None): + elif host_is_up in (True, None): future = self.add_or_renew_pool(host, False) + elif any_pool is not None: + future = self.remove_pool(host) elif pool_creation_future is not None: future = self._reuse_or_invalidate_pool_creation( host, pool_creation_future) @@ -5662,7 +5665,13 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message - pool = self.session._get_pool_by_host_identity(host) + if isinstance(host, Host): + with host.lock: + expected_endpoint = host.endpoint + pool = self.session._get_pool_by_host_identity( + host, expected_endpoint=expected_endpoint) + else: + pool = self.session._get_pool_by_host_identity(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a30b4d84b0..663408f550 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -323,6 +323,17 @@ def test_schedule_unique_keeps_client_route_events_for_distinct_ports(self, *_): class SessionPoolRaceTest(unittest.TestCase): + class _DuplicatePoolEntries(object): + + def __init__(self, entries): + self._entries = entries + + def __len__(self): + return len(self._entries) + + def items(self): + return list(self._entries) + @staticmethod def _make_host(address): return Host(address, SimpleConvictionPolicy, host_id=uuid.uuid4()) @@ -511,6 +522,22 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_update_created_pools_removes_stale_pool_for_down_host_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + host.set_down() + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = stale_pool + + host.endpoint = DefaultEndPoint("127.0.0.2") + + futures = session.update_created_pools() + + assert len(futures) == 1 + executor.run_next() + assert session._pools == {} + stale_pool.shutdown.assert_called_once_with() + def test_update_created_pools_replaces_pool_after_endpoint_change(self): host = self._make_host("127.0.0.1") old_endpoint = host.endpoint @@ -628,11 +655,13 @@ def test_remove_pool_expected_endpoint_preserves_replacement_pool(self): old_endpoint = host.endpoint cluster, session, executor = self._make_cluster_and_session([host]) stale_pool = self._make_pool(host, HostDistance.LOCAL, session) - session._pools[host] = stale_pool host.endpoint = DefaultEndPoint("127.0.0.2") replacement_pool = self._make_pool(host, HostDistance.LOCAL, session) - session._pools[host] = replacement_pool + session._pools = self._DuplicatePoolEntries([ + (host, stale_pool), + (host, replacement_pool), + ]) assert len(session._pools) == 2 diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index a5e16461d2..07e8f3abdf 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -20,7 +20,7 @@ from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.cluster import Cluster, Session, ResponseFuture, NoHostAvailable, ProtocolVersion from cassandra.connection import Connection, ConnectionException, DefaultEndPoint from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, @@ -700,6 +700,32 @@ def test_timeout_returns_orphan_to_original_pool_after_endpoint_swap(self): old_pool.return_connection.assert_called_once_with( connection, stream_was_orphaned=True) + def test_query_does_not_borrow_stale_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + def test_single_host_query_plan_exhausted_after_one_retry(self): """ Test that when a specific host is provided, the query plan is properly From 5db5154f367e57696badf5aad5f6e9a78e6ba764 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 20:54:11 -0400 Subject: [PATCH 10/29] session: publish pools under endpoint lock --- cassandra/cluster.py | 78 +++++++++++++++++++------------------- tests/unit/test_cluster.py | 46 ++++++++++++++++++++++ 2 files changed, 85 insertions(+), 39 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index de0e5196e4..d9b5d21d73 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4073,46 +4073,46 @@ def callback(pool, errors): else: with host.lock: endpoint_changed = not self._endpoints_match(host.endpoint, creation_endpoint) - if endpoint_changed: - log.debug( - "Discarding stale connection pool for host %s; endpoint changed from %s", - host, creation_endpoint) - self._invalidate_pool_creation(host, expected_endpoint=creation_endpoint) - discard_pool = True - else: - # Rebuild by identity so endpoint hash changes do not - # leave stale pool entries behind. - retained_pools = {} - for pool_host, host_pool in self._pools.items(): - if pool_host is host: - previous_pools.append(host_pool) - else: - retained_pools[pool_host] = host_pool - - # Keep the current metadata host keyed by identity. - metadata_host = host - if isinstance(self.cluster.metadata, Metadata): - metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) - - target_host = metadata_host if metadata_host is not None else host - target_host_matches = False - for pool_host in tuple(retained_pools): - if pool_host is target_host: - target_host_matches = True - elif pool_host == target_host: - previous_pools.append(retained_pools.pop(pool_host)) - - if target_host_matches: - reuse_existing_pool = True + if endpoint_changed: + log.debug( + "Discarding stale connection pool for host %s; endpoint changed from %s", + host, creation_endpoint) + self._invalidate_pool_creation(host, expected_endpoint=creation_endpoint) + discard_pool = True else: - source_host = new_pool.host - if (source_host is not target_host and - target_host.sharding_info is None): - target_host.sharding_info = source_host.sharding_info - new_pool.host = target_host - retained_pools[target_host] = new_pool - self._pools = retained_pools - self._clear_pool_creation(host, creation_epoch) + # Rebuild by identity so endpoint hash changes do not + # leave stale pool entries behind. + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if pool_host is host: + previous_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + + # Keep the current metadata host keyed by identity. + metadata_host = host + if isinstance(self.cluster.metadata, Metadata): + metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) + + target_host = metadata_host if metadata_host is not None else host + target_host_matches = False + for pool_host in tuple(retained_pools): + if pool_host is target_host: + target_host_matches = True + elif pool_host == target_host: + previous_pools.append(retained_pools.pop(pool_host)) + + if target_host_matches: + reuse_existing_pool = True + else: + source_host = new_pool.host + if (source_host is not target_host and + target_host.sharding_info is None): + target_host.sharding_info = source_host.sharding_info + new_pool.host = target_host + retained_pools[target_host] = new_pool + self._pools = retained_pools + self._clear_pool_creation(host, creation_epoch) if reuse_existing_pool: log.debug("Reusing existing connection pool for host %s", host) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 663408f550..c08dc46e59 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -323,6 +323,28 @@ def test_schedule_unique_keeps_client_route_events_for_distinct_ports(self, *_): class SessionPoolRaceTest(unittest.TestCase): + class _EndpointSwapIfPoolUnpublishedOnFirstExitLock(object): + + def __init__(self, host, new_endpoint, pool_is_published): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._pool_is_published = pool_is_published + self._exits = 0 + self.pool_was_published_on_first_exit = None + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self.pool_was_published_on_first_exit = self._pool_is_published() + if not self.pool_was_published_on_first_exit: + self._host.endpoint = self._new_endpoint + class _DuplicatePoolEntries(object): def __init__(self, entries): @@ -522,6 +544,30 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_pool_creation_publishes_before_endpoint_lock_is_released(self): + host = self._make_host("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + host.lock = self._EndpointSwapIfPoolUnpublishedOnFirstExitLock( + host, new_endpoint, lambda: bool(session._pools)) + + executor.run_next() + + assert host.lock.pool_was_published_on_first_exit is True + assert future.result() is True + assert session._pools[host] is created_pools[0] + created_pools[0].shutdown.assert_not_called() + def test_update_created_pools_removes_stale_pool_for_down_host_after_endpoint_change(self): host = self._make_host("127.0.0.1") host.set_down() From 3d20c2404003eaf6d58058d1d0a7fa8cc02b7ec0 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 22:24:36 -0400 Subject: [PATCH 11/29] pool: guard stale cleanup by pool identity --- cassandra/cluster.py | 22 ++++++++++++++-------- cassandra/pool.py | 6 ++++-- tests/unit/test_cluster.py | 24 +++++++++++++++++++++++- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d9b5d21d73..017d266115 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4156,7 +4156,8 @@ def callback(pool, errors): state.future = future return future - def remove_pool(self, host, expected_host=None, expected_endpoint=None): + def remove_pool(self, host, expected_host=None, expected_endpoint=None, + expected_pool=None): removed_pools = [] cleanup_context = _POOL_CLEANUP_EPOCH.get() with self._lock: @@ -4187,11 +4188,12 @@ def remove_pool(self, host, expected_host=None, expected_endpoint=None): if expected_endpoint is not None: remove_all = self._endpoints_match(host.endpoint, expected_endpoint) - if remove_all: - self._invalidate_pool_creation(host) - else: - self._invalidate_pool_creation( - host, expected_endpoint=expected_endpoint) + if expected_pool is None: + if remove_all: + self._invalidate_pool_creation(host) + else: + self._invalidate_pool_creation( + host, expected_endpoint=expected_endpoint) retained_pools = {} for pool_host, host_pool in self._pools.items(): @@ -4201,7 +4203,8 @@ def remove_pool(self, host, expected_host=None, expected_endpoint=None): matches = self._pool_matches_expected( host_pool, expected_host=expected_host, - expected_endpoint=None if remove_all else expected_endpoint) + expected_endpoint=None if remove_all else expected_endpoint, + expected_pool=expected_pool) if matches: removed_pools.append(host_pool) else: @@ -4230,7 +4233,10 @@ def _shutdown_removed_pools(pools): for pool in pools: pool.shutdown() - def _pool_matches_expected(self, pool, expected_host=None, expected_endpoint=None): + def _pool_matches_expected(self, pool, expected_host=None, + expected_endpoint=None, expected_pool=None): + if expected_pool is not None and pool is not expected_pool: + return False if expected_host is not None and pool.host is not expected_host: return False if expected_endpoint is not None: diff --git a/cassandra/pool.py b/cassandra/pool.py index 676f54d454..c96016feba 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -597,7 +597,8 @@ def return_connection(self, connection, stream_was_orphaned=False): # Drop only this stale pool; endpoint reuse may already belong to # a replacement host instance. future = self._session.remove_pool( - self.host, expected_host=self.host, expected_endpoint=self.endpoint) + self.host, expected_host=self.host, + expected_endpoint=self.endpoint, expected_pool=self) if future: future.add_done_callback(lambda f: self._session.update_created_pools()) with self._lock: @@ -645,7 +646,8 @@ def on_orphaned_stream_released(self): def _remove_stale_pool(self, expected_endpoint): future = self._session.remove_pool( - self.host, expected_host=self.host, expected_endpoint=expected_endpoint) + self.host, expected_host=self.host, + expected_endpoint=expected_endpoint, expected_pool=self) if future: future.add_done_callback(lambda f: self._session.update_created_pools()) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index c08dc46e59..e07ad72edb 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -26,7 +26,7 @@ from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT from cassandra.connection import ClientRoutesEndPoint, ConnectionException, DefaultEndPoint, SniEndPoint -from cassandra.pool import Host, _HostReconnectionHandler +from cassandra.pool import Host, HostConnection, _HostReconnectionHandler from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools @@ -747,6 +747,28 @@ def make_pool(host, distance, pool_session, endpoint=None): assert created_pools[0].endpoint == host.endpoint created_pools[0].shutdown.assert_not_called() + def test_stale_host_connection_cleanup_after_endpoint_flip_back_preserves_current_pool(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + replacement_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + + host.endpoint = replacement_endpoint + host.endpoint = old_endpoint + current_pool = self._make_pool( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + session._pools[host] = current_pool + + stale_pool = HostConnection.__new__(HostConnection) + stale_pool.host = host + stale_pool.endpoint = old_endpoint + stale_pool._session = session + + stale_pool._remove_stale_pool(old_endpoint) + + assert session._get_pool_by_host_identity(host) is current_pool + current_pool.shutdown.assert_not_called() + def test_add_or_renew_pool_tags_pool_with_creation_endpoint(self): host = self._make_host("127.0.0.1") old_endpoint = host.endpoint From 3ec84db3c92d4a38714d5626c75130cb598a2cc2 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 23:18:13 -0400 Subject: [PATCH 12/29] session: refresh stale pool creation on endpoint change --- cassandra/cluster.py | 8 +++++++- tests/unit/test_cluster.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 017d266115..a879e23417 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4141,7 +4141,13 @@ def callback(pool, errors): with self._lock: state = self._get_pool_creation_state(host) if state.creation_epoch is not None: - return state.future + with host.lock: + endpoint_changed = not self._endpoints_match( + host.endpoint, state.endpoint) + if not endpoint_changed: + return state.future + self._invalidate_pool_creation( + host, expected_endpoint=state.endpoint) creation_epoch = state.advance() state.creation_epoch = creation_epoch diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index e07ad72edb..ed5934833c 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -544,6 +544,40 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_add_or_renew_pool_invalidates_creation_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + _, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + old_future = session.add_or_renew_pool( + host, is_host_addition=False) + host.endpoint = new_endpoint + + new_future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert new_future is not old_future + assert len(executor.submissions) == 2 + + executor.run_next() + executor.run_next() + + assert old_future.result() is False + assert new_future.result() is True + assert created_pools[0].endpoint == old_endpoint + assert created_pools[1].endpoint == new_endpoint + created_pools[0].shutdown.assert_called_once_with() + created_pools[1].shutdown.assert_not_called() + assert session._pools[host] is created_pools[1] + def test_pool_creation_publishes_before_endpoint_lock_is_released(self): host = self._make_host("127.0.0.1") new_endpoint = DefaultEndPoint("127.0.0.2") From 117d3450e6e2bbb0becfe1657fab9bc3f8c48da5 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 06:39:40 -0400 Subject: [PATCH 13/29] pool: ignore reassigned stale host failures --- cassandra/pool.py | 57 +++++++++++++++++++++---- tests/unit/test_host_connection_pool.py | 33 ++++++++++++++ 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/cassandra/pool.py b/cassandra/pool.py index c96016feba..e355d88d3d 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -64,6 +64,35 @@ def _endpoints_match(cluster, endpoint, expected_endpoint): return endpoint == expected_endpoint +def _host_is_current_for_endpoint(cluster, host, expected_endpoint): + metadata = getattr(cluster, "metadata", None) + hosts = getattr(metadata, "_hosts", None) + host_id_by_endpoint = getattr(metadata, "_host_id_by_endpoint", None) + if not isinstance(hosts, dict) or not isinstance(host_id_by_endpoint, dict): + return True + + def check_mapping(): + try: + mapped_host_id = host_id_by_endpoint.get(expected_endpoint) + except TypeError: + return True + if mapped_host_id is None: + return hosts.get(getattr(host, "host_id", None)) is host + return hosts.get(mapped_host_id) is host + + metadata_lock = getattr(metadata, "_hosts_lock", None) + if metadata_lock is not None: + with metadata_lock: + return check_mapping() + return check_mapping() + + +def _host_matches_expected_endpoint(cluster, host, expected_endpoint): + current_endpoint = _current_host_endpoint(host) + return (_endpoints_match(cluster, current_endpoint, expected_endpoint) and + _host_is_current_for_endpoint(cluster, host, expected_endpoint)) + + @total_ordering class Host(object): """ @@ -585,10 +614,20 @@ def return_connection(self, connection, stream_was_orphaned=False): log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) with self.host.lock: - if not _endpoints_match(self._session.cluster, self.host.endpoint, self.endpoint): + endpoint_matches = _endpoints_match( + self._session.cluster, self.host.endpoint, self.endpoint) + host_is_current = (endpoint_matches and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + self.endpoint)) + if not endpoint_matches: log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", self.host, self.endpoint) stale_endpoint_failure = True + elif not host_is_current: + log.debug("Ignoring stale connection failure for host %s; endpoint reassigned from %s", + self.host, self.endpoint) + stale_endpoint_failure = True else: is_down = self.host.signal_connection_failure(connection.last_error) connection.signaled_error = True @@ -653,8 +692,8 @@ def _remove_stale_pool(self, expected_endpoint): def _replace(self, connection): expected_endpoint = self.endpoint - current_endpoint = _current_host_endpoint(self.host) - if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", self.host, expected_endpoint) self._remove_stale_pool(expected_endpoint) @@ -694,8 +733,8 @@ def _replace(self, connection): self._stream_available_condition.notify() return - current_endpoint = _current_host_endpoint(self.host) - if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", self.host, expected_endpoint) replacement_connection.close() @@ -820,8 +859,8 @@ def _open_connection_to_missing_shard(self, shard_id): if self.is_shutdown: return - current_endpoint = _current_host_endpoint(self.host) - if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", self.host, expected_endpoint) self._remove_stale_pool(expected_endpoint) @@ -843,8 +882,8 @@ def _open_connection_to_missing_shard(self, shard_id): else: conn = self._session.cluster.connection_factory(expected_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) - current_endpoint = _current_host_endpoint(self.host) - if not _endpoints_match(self._session.cluster, current_endpoint, expected_endpoint): + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", self.host, expected_endpoint) conn.close() diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 3dbde237a5..bee25feb63 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -25,6 +25,7 @@ from cassandra.cluster import Cluster, Session, ShardAwareOptions from cassandra.connection import ClientRoutesEndPoint, Connection, DefaultEndPoint +from cassandra.metadata import Metadata from cassandra.pool import HostConnection from cassandra.pool import Host, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy @@ -263,6 +264,38 @@ def test_return_defunct_connection_after_client_route_endpoint_port_swap_is_igno conn.close.assert_called_once_with() assert not pool.is_shutdown + def test_return_defunct_connection_after_endpoint_reassignment_is_ignored(self): + endpoint = DefaultEndPoint('127.0.0.1') + stale_host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + replacement_host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + stale_host.signal_connection_failure = Mock(return_value=True) + + session = self.make_session() + session.remove_pool.return_value = None + session.cluster.metadata = Metadata() + session.cluster.metadata.add_or_return_host(replacement_host) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100, + signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(stale_host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + conn.is_defunct = True + pool.return_connection(conn) + + stale_host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + session.remove_pool.assert_called_once_with( + stale_host, expected_host=stale_host, + expected_endpoint=endpoint, expected_pool=pool) + conn.close.assert_called_once_with() + assert not pool.is_shutdown + def test_host_instantiations(self): """ Ensure Host fails if not initialized properly From 1e2a52f33555c12b30d56ff53a36b72b4e722907 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 07:21:03 -0400 Subject: [PATCH 14/29] cluster: avoid lock inversion on connection failure --- cassandra/cluster.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a879e23417..9031e742fb 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2813,6 +2813,7 @@ def _has_non_retryable_auth_failure(cls, host): def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False, expected_endpoint=None): + signal_down = False with host.lock: if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", @@ -2838,9 +2839,11 @@ def signal_connection_failure(self, host, connection_exc, is_host_addition, self._set_non_retryable_auth_failure(host, True) return is_down if is_down: - self.on_down( - host, is_host_addition, expect_host_to_be_down, - expected_endpoint=expected_endpoint) + signal_down = True + if signal_down: + self.on_down( + host, is_host_addition, expect_host_to_be_down, + expected_endpoint=expected_endpoint) return is_down def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): From 6d99ba2944f78404b412c685524180050ad2d2ad Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 08:39:36 -0400 Subject: [PATCH 15/29] cluster: discount current endpoint down events --- cassandra/cluster.py | 39 ++++++++++++++++++++++---------------- tests/unit/test_cluster.py | 22 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9031e742fb..5e75a13da2 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2591,25 +2591,32 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, if self.is_shutdown: return - if (self._discount_down_events and expected_endpoint is None and + if (self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED): with host.lock: host_endpoint = host.endpoint - connected = False - for session in tuple(self.sessions): - # Host equality is endpoint-based; scan by identity to avoid - # hiding the live pool behind a stale equal key. Do not hold - # host.lock while taking session._lock; update_created_pools() - # takes the locks in the opposite order. - pool = session._get_pool_by_host_identity( - host, expected_endpoint=host_endpoint) - if pool is not None and pool.open_count > 0: - connected = True - break - if connected: - with host.lock: - if self._endpoints_match(host.endpoint, host_endpoint): - return + discount_endpoint = host_endpoint + if expected_endpoint is not None: + if self._endpoints_match(host_endpoint, expected_endpoint): + discount_endpoint = expected_endpoint + else: + discount_endpoint = None + if discount_endpoint is not None: + connected = False + for session in tuple(self.sessions): + # Host equality is endpoint-based; scan by identity to avoid + # hiding the live pool behind a stale equal key. Do not hold + # host.lock while taking session._lock; update_created_pools() + # takes the locks in the opposite order. + pool = session._get_pool_by_host_identity( + host, expected_endpoint=discount_endpoint) + if pool is not None and pool.open_count > 0: + connected = True + break + if connected: + with host.lock: + if self._endpoints_match(host.endpoint, discount_endpoint): + return with host.lock: if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index ed5934833c..9bffd85d87 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1069,6 +1069,28 @@ def run_on_down(): assert not runner.is_alive() assert blocked_on_host == [False] + def test_discount_down_event_applies_to_current_expected_endpoint(self): + host = self._make_host() + host.set_up() + endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=endpoint) + + assert host.is_up + cluster.on_down_potentially_blocking.assert_not_called() + @staticmethod def _state(cluster, host): return cluster._get_host_liveness_state(host) From deaab93ba0a0fcab18526e5117c9b810130e607b Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 08:49:27 -0400 Subject: [PATCH 16/29] tests: cover connection failure lock ordering --- tests/unit/test_cluster.py | 59 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 9bffd85d87..7576725f2a 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1069,6 +1069,65 @@ def run_on_down(): assert not runner.is_alive() assert blocked_on_host == [False] + def test_signal_connection_failure_does_not_hold_host_lock_while_scanning_pools(self): + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = old_endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + # Mutating the endpoint after insertion forces an identity scan under + # session._lock. signal_connection_failure() must not keep host.lock + # held while it delegates to the down path that performs that scan. + host.endpoint = DefaultEndPoint("127.0.0.2") + entered_distance = Event() + release_distance = Event() + blocked_on_host = [] + thread_errors = [] + + def distance(_host): + entered_distance.set() + release_distance.wait(1) + return HostDistance.LOCAL + + def hold_session_then_try_host(): + with session._lock: + entered_distance.wait(1) + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + release_distance.set() + + def run_signal_connection_failure(): + try: + Cluster.signal_connection_failure( + cluster, host, ConnectionException("failed"), + is_host_addition=False) + except Exception as exc: + thread_errors.append(exc) + + cluster.profile_manager.distance.side_effect = distance + worker = Thread(target=hold_session_then_try_host) + runner = Thread(target=run_signal_connection_failure) + worker.start() + runner.start() + worker.join(2) + runner.join(2) + + assert not thread_errors + assert not runner.is_alive() + assert blocked_on_host == [False] + def test_discount_down_event_applies_to_current_expected_endpoint(self): host = self._make_host() host.set_up() From 2ac6f30f7347aa6e4cf32e4b96d25323f0f372ed Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 09:37:03 -0400 Subject: [PATCH 17/29] response-future: reject stale endpoint pool borrow --- cassandra/cluster.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e75a13da2..4ae5749e6f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5687,6 +5687,7 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message + expected_endpoint = None if isinstance(host, Host): with host.lock: expected_endpoint = host.endpoint @@ -5694,6 +5695,16 @@ def _query(self, host, message=None, cb=None): host, expected_endpoint=expected_endpoint) else: pool = self.session._get_pool_by_host_identity(host) + + if pool and expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None @@ -5710,8 +5721,23 @@ def _query(self, host, message=None, cb=None): connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table) else: connection, request_id = pool.borrow_connection(timeout=2.0) + + if expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + try: + pool.return_connection(connection) + finally: + connection = None + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + self._connection = connection self._connection_pool = pool + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: From 24cdc42ff108c1473d5ac7e7e24b6e9a5a9a98db Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 10:11:20 -0400 Subject: [PATCH 18/29] session: discard stale replacement-host pools --- cassandra/cluster.py | 49 +++++++++++++++++++++++++------------- tests/unit/test_cluster.py | 28 ++++++++++++++++++++++ 2 files changed, 60 insertions(+), 17 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 4ae5749e6f..5a86fd424a 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4105,24 +4105,39 @@ def callback(pool, errors): metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) target_host = metadata_host if metadata_host is not None else host - target_host_matches = False - for pool_host in tuple(retained_pools): - if pool_host is target_host: - target_host_matches = True - elif pool_host == target_host: - previous_pools.append(retained_pools.pop(pool_host)) - - if target_host_matches: - reuse_existing_pool = True + target_endpoint_changed = False + if target_host is not host: + with target_host.lock: + target_endpoint_changed = not self._endpoints_match( + target_host.endpoint, creation_endpoint) + + if target_endpoint_changed: + log.debug( + "Discarding stale connection pool for host %s; " + "metadata host endpoint changed from %s", + host, creation_endpoint) + self._invalidate_pool_creation( + host, expected_endpoint=creation_endpoint) + discard_pool = True else: - source_host = new_pool.host - if (source_host is not target_host and - target_host.sharding_info is None): - target_host.sharding_info = source_host.sharding_info - new_pool.host = target_host - retained_pools[target_host] = new_pool - self._pools = retained_pools - self._clear_pool_creation(host, creation_epoch) + target_host_matches = False + for pool_host in tuple(retained_pools): + if pool_host is target_host: + target_host_matches = True + elif pool_host == target_host: + previous_pools.append(retained_pools.pop(pool_host)) + + if target_host_matches: + reuse_existing_pool = True + else: + source_host = new_pool.host + if (source_host is not target_host and + target_host.sharding_info is None): + target_host.sharding_info = source_host.sharding_info + new_pool.host = target_host + retained_pools[target_host] = new_pool + self._pools = retained_pools + self._clear_pool_creation(host, creation_epoch) if reuse_existing_pool: log.debug("Reusing existing connection pool for host %s", host) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 7576725f2a..15f7e1f301 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -26,6 +26,7 @@ from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT from cassandra.connection import ClientRoutesEndPoint, ConnectionException, DefaultEndPoint, SniEndPoint +from cassandra.metadata import Metadata from cassandra.pool import Host, HostConnection, _HostReconnectionHandler from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -664,6 +665,33 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_stale_host_pool_creation_does_not_publish_to_replacement_host(self): + host_id = uuid.uuid4() + stale_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, + host_id=host_id) + replacement_host = Host(DefaultEndPoint("127.0.0.2"), + SimpleConvictionPolicy, host_id=host_id) + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + cluster.metadata = Metadata() + cluster.metadata.add_or_return_host(replacement_host) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + def test_remove_pool_expected_host_mismatch_invalidates_stale_creation(self): stale_host = self._make_host("127.0.0.1") replacement_host = self._make_host("127.0.0.1") From 13d02c4110ab3b552cf87ff11d73a79f4dbe8233 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 10:35:02 -0400 Subject: [PATCH 19/29] pool: discard stale replacement shard connections --- cassandra/pool.py | 38 ++++++++++- tests/unit/test_host_connection_pool.py | 88 ++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 3 deletions(-) diff --git a/cassandra/pool.py b/cassandra/pool.py index e355d88d3d..d3a71d2161 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -754,13 +754,30 @@ def _replace(self, connection): self._session.submit(self._replace, connection) return + stale_endpoint = False with self._lock: if self.is_shutdown: replacement_connection.close() self._is_replacing = False return - self._connections[replacement_connection.features.shard_id] = replacement_connection + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, expected_endpoint)) + if not stale_endpoint: + self._connections[replacement_connection.features.shard_id] = replacement_connection self._is_replacing = False + if stale_endpoint: + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + replacement_connection.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return with self._stream_available_condition: self._stream_available_condition.notify() @@ -931,11 +948,28 @@ def _open_connection_to_missing_shard(self, shard_id): ) if self._keyspace: conn.set_keyspace_blocking(self._keyspace) - self._connections[conn.features.shard_id] = conn + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + expected_endpoint)) + if not stale_endpoint: + self._connections[conn.features.shard_id] = conn if is_shutdown: conn.close() return + if stale_endpoint: + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + conn.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return if old_conn is not None: remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index bee25feb63..e40df1453d 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -20,7 +20,7 @@ from cassandra.shard_info import _ShardingInfo import unittest -from threading import Thread, Event, Lock +from threading import Thread, Event, Lock, Condition from unittest.mock import Mock, NonCallableMagicMock, MagicMock from cassandra.cluster import Cluster, Session, ShardAwareOptions @@ -415,3 +415,89 @@ def test_replace_retries_when_replacement_keyspace_set_fails(self): submitted_fn, submitted_connection = session.submit.call_args.args assert submitted_fn == pool._replace assert submitted_connection is initial_connection + + def test_replace_discards_replacement_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + initial_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.side_effect = [ + initial_connection, replacement_connection] + + pool = HostConnection(host, HostDistance.LOCAL, session) + pool._is_replacing = True + + pool._replace(initial_connection) + + replacement_connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert not pool._is_replacing + + def test_missing_shard_discards_connection_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.sharding_info = _ShardingInfo( + shard_id=0, shards_count=1, partitioner='', + sharding_algorithm='', sharding_ignore_msb=0, + shard_aware_port='', shard_aware_port_ssl='') + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster.ssl_options = None + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.return_value = connection + + pool = HostConnection.__new__(HostConnection) + pool.host = host + pool.endpoint = old_endpoint + pool.host_distance = HostDistance.LOCAL + pool.is_shutdown = False + pool._session = session + pool._lock = Lock() + pool._stream_available_condition = Condition(Lock()) + pool._connections = {} + pool._pending_connections = [] + pool._connecting = {0} + pool._excess_connections = set() + pool._trash = set() + pool._shard_connections_futures = [] + pool._keyspace = 'ks' + pool.advanced_shardaware_block_until = 0 + pool.tablets_routing_v1 = False + + pool._open_connection_to_missing_shard(0) + + connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert pool._connecting == set() From c417c58677ce7731b0b01f0b1d82818b84c6eca7 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 11:15:17 -0400 Subject: [PATCH 20/29] cluster: preserve replacement endpoint down handling --- cassandra/cluster.py | 89 +++++++++++++++++++++++++++--- tests/unit/test_cluster.py | 110 +++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 7 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5a86fd424a..8c6667730e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -333,13 +333,17 @@ class _HostLivenessState(_EventFenceState): _UP = "up" _DOWN = "down" _PENDING_UP = "pending_up" + _PENDING_DOWN = "pending_down" - __slots__ = ("up_endpoint", "pending_up_endpoint") + __slots__ = ("up_endpoint", "down_endpoint", + "pending_up_endpoint", "pending_down_endpoint") def __init__(self): _EventFenceState.__init__(self) self.up_endpoint = None + self.down_endpoint = None self.pending_up_endpoint = None + self.pending_down_endpoint = None @property def up_epoch(self): @@ -365,6 +369,14 @@ def pending_up_epoch(self): def pending_up_epoch(self, epoch): self._set_or_clear_event(self._PENDING_UP, epoch) + @property + def pending_down_epoch(self): + return self.get_event(self._PENDING_DOWN) + + @pending_down_epoch.setter + def pending_down_epoch(self, epoch): + self._set_or_clear_event(self._PENDING_DOWN, epoch) + class _PoolCreationState(_EventFenceState): _CREATE = "create" @@ -2193,7 +2205,42 @@ def _handle_pending_node_up(self, host, pending_up): def _clear_down_handling(self, host, down_epoch=None): state = self._get_host_liveness_state(host) - return state.clear_event(_HostLivenessState._DOWN, down_epoch) + if state.clear_event(_HostLivenessState._DOWN, down_epoch): + state.down_endpoint = None + return True + return False + + def _clear_pending_down(self, state): + state.pending_down_epoch = None + state.pending_down_endpoint = None + + def _pop_pending_node_down_if_ready(self, host): + state = self._get_host_liveness_state(host) + if state.pending_down_epoch is None: + state.pending_down_endpoint = None + return None + if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + return None + + pending_down_epoch = state.pending_down_epoch + pending_down_endpoint = state.pending_down_endpoint + if state.epoch != pending_down_epoch: + self._clear_pending_down(state) + return None + + self._clear_pending_down(state) + return pending_down_epoch, pending_down_endpoint + + def _handle_pending_node_down(self, host, pending_down, is_host_addition): + if pending_down is None: + return False + _pending_down_epoch, pending_down_endpoint = pending_down + log.debug("Handling queued down status of node %s", host) + self.on_down( + host, is_host_addition=is_host_addition, + expect_host_to_be_down=True, + expected_endpoint=pending_down_endpoint) + return True def _finish_superseded_up_handling(self, host, up_epoch, expected_endpoint=None): self._cleanup_superseded_up_handling( @@ -2515,6 +2562,7 @@ def on_down_potentially_blocking( return down_epoch = state.epoch state.down_epoch = down_epoch + state.down_endpoint = expected_endpoint elif not owns_reserved_down_handling: log.debug("Ignoring stale down handling for host %s", host) return @@ -2576,12 +2624,16 @@ def on_down_potentially_blocking( else: log.debug("Not starting reconnector for removed host %s", host) finally: + pending_down = None pending_up_epoch = None with host.lock: if down_epoch is not None and self._clear_down_handling(host, down_epoch): - pending_up_epoch = self._pop_pending_node_up_if_ready(host) + pending_down = self._pop_pending_node_down_if_ready(host) + if pending_down is None: + pending_up_epoch = self._pop_pending_node_up_if_ready(host) - self._handle_pending_node_up(host, pending_up_epoch) + if not self._handle_pending_node_down(host, pending_down, is_host_addition): + self._handle_pending_node_up(host, pending_up_epoch) def on_down(self, host, is_host_addition, expect_host_to_be_down=False, expected_endpoint=None): @@ -2591,7 +2643,8 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, if self.is_shutdown: return - if (self._discount_down_events and + if (not expect_host_to_be_down and + self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED): with host.lock: host_endpoint = host.endpoint @@ -2626,6 +2679,20 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, was_up = host.is_up state = self._get_host_liveness_state(host) + pending_down_endpoint = None + if state.down_endpoint is not None: + target_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint + if not self._endpoints_match(state.down_endpoint, target_endpoint): + pending_down_endpoint = target_endpoint + + if pending_down_endpoint is not None: + state.advance() + state.pending_up_epoch = None + state.pending_up_endpoint = None + state.pending_down_epoch = state.epoch + state.pending_down_endpoint = pending_down_endpoint + host.set_down() + return if not expect_host_to_be_down: if was_up is False: @@ -2639,6 +2706,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, state.advance() state.pending_up_epoch = None state.pending_up_endpoint = None + self._clear_pending_down(state) host.set_down() down_epoch = state.epoch if state.down_epoch is not None: @@ -2648,16 +2716,21 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, state.up_epoch is None): return state.down_epoch = down_epoch + state.down_endpoint = expected_endpoint log.warning("Host %s has been marked down", host) future = self.on_down_potentially_blocking( host, is_host_addition, down_epoch, expected_endpoint) if future is None: + pending_down = None pending_up_epoch = None with host.lock: if self._clear_down_handling(host, down_epoch): - pending_up_epoch = self._pop_pending_node_up_if_ready(host) - self._handle_pending_node_up(host, pending_up_epoch) + pending_down = self._pop_pending_node_down_if_ready(host) + if pending_down is None: + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + if not self._handle_pending_node_down(host, pending_down, is_host_addition): + self._handle_pending_node_up(host, pending_up_epoch) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2742,9 +2815,11 @@ def on_remove(self, host): state.advance() state.pending_up_epoch = None state.pending_up_endpoint = None + self._clear_pending_down(state) state.up_epoch = None state.up_endpoint = None state.down_epoch = None + state.down_endpoint = None host.set_down() self._set_non_retryable_auth_failure(host, False) self.profile_manager.on_remove(host) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 15f7e1f301..34d32a6317 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1178,6 +1178,29 @@ def test_discount_down_event_applies_to_current_expected_endpoint(self): assert host.is_up cluster.on_down_potentially_blocking.assert_not_called() + def test_forced_down_is_not_discounted_by_connected_pool(self): + host = self._make_host() + host.set_up() + endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True, expected_endpoint=endpoint) + + assert not host.is_up + cluster.on_down_potentially_blocking.assert_called_once_with( + host, False, ANY, endpoint) + @staticmethod def _state(cluster, host): return cluster._get_host_liveness_state(host) @@ -1940,6 +1963,93 @@ def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): assert state.up_epoch is None assert state.pending_up_epoch is None + def test_down_for_replacement_endpoint_during_pending_old_down_is_handled(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=old_endpoint) + state = self._state(cluster, host) + assert state.down_epoch == state.epoch + + host.endpoint = new_endpoint + Cluster.on_up(cluster, host) + assert state.pending_up_epoch == state.epoch + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=new_endpoint) + + executor.run_next() + + assert len(executor.submissions) == 1 + executor.run_next() + + session.on_down.assert_any_call( + host, expected_endpoint=old_endpoint) + session.on_down.assert_any_call( + host, expected_endpoint=new_endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=new_endpoint) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_forced_down_for_replacement_endpoint_during_old_down_is_handled(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=old_endpoint) + state = self._state(cluster, host) + assert state.down_epoch == state.epoch + + host.endpoint = new_endpoint + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True, expected_endpoint=new_endpoint) + + executor.run_next() + + assert len(executor.submissions) == 1 + executor.run_next() + + session.on_down.assert_any_call( + host, expected_endpoint=old_endpoint) + session.on_down.assert_any_call( + host, expected_endpoint=new_endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=new_endpoint) + assert not host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + def test_later_down_before_worker_runs_does_not_skip_pool_cleanup(self): executor = _QueuedExecutor() host = self._make_host() From f9f991628ca45e307a15141d0825dc99d1e6edaa Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 11:52:54 -0400 Subject: [PATCH 21/29] response-future: release stream id on stale borrow --- cassandra/cluster.py | 2 + tests/unit/test_response_future.py | 89 ++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8c6667730e..08a285798f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5818,6 +5818,8 @@ def _query(self, host, message=None, cb=None): host.endpoint, expected_endpoint) if endpoint_changed: try: + with connection.lock: + connection.request_ids.append(request_id) pool.return_connection(connection) finally: connection = None diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 07e8f3abdf..8aca91567a 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -38,6 +38,24 @@ class ResponseFutureTests(unittest.TestCase): + class _EndpointSwapOnExitLock(object): + + def __init__(self, host, new_endpoint): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._exits = 0 + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self._host.endpoint = self._new_endpoint + def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] @@ -726,6 +744,77 @@ def test_query_does_not_borrow_stale_pool_after_endpoint_swap(self): stale_pool.borrow_connection.assert_not_called() assert isinstance(rf._errors[host], ConnectionException) + def test_query_rechecks_endpoint_after_pool_lookup_race(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.lock = self._EndpointSwapOnExitLock( + host, DefaultEndPoint('127.0.0.2')) + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + + def test_query_releases_request_id_after_post_borrow_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint('127.0.0.2') + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = old_endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection, lock=RLock(), request_ids=deque([1]), + in_flight=0) + + def borrow_connection(**kwargs): + with connection.lock: + connection.in_flight += 1 + request_id = connection.request_ids.popleft() + host.endpoint = new_endpoint + return connection, request_id + + def return_connection(returned_connection): + with returned_connection.lock: + returned_connection.in_flight -= 1 + + stale_pool.borrow_connection.side_effect = borrow_connection + stale_pool.return_connection.side_effect = return_connection + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.return_connection.assert_called_once_with(connection) + connection.send_msg.assert_not_called() + assert list(connection.request_ids) == [1] + assert connection.in_flight == 0 + assert isinstance(rf._errors[host], ConnectionException) + def test_single_host_query_plan_exhausted_after_one_retry(self): """ Test that when a specific host is provided, the query plan is properly From 9c441bc8219c68b3bd065c53914de6802569279a Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 12:54:13 -0400 Subject: [PATCH 22/29] control-connection: reconnect after stale failure --- cassandra/cluster.py | 6 +++--- tests/unit/test_control_connection.py | 28 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 08a285798f..68832728c5 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5362,10 +5362,10 @@ def _signal_error(self): # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + if self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False, - expected_endpoint=self._connection.endpoint) - return + expected_endpoint=self._connection.endpoint): + return # if the connection is not defunct or the host already left, reconnect # manually diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 0f993c8d7e..3bba2d0a42 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -455,6 +455,34 @@ def test_stale_control_connection_failure_is_endpoint_fenced(self): host, self.connection.last_error, is_host_addition=False, expected_endpoint=old_endpoint) + def test_stale_control_connection_failure_reconnects_when_cluster_ignores_signal(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=host_id) + host.set_up() + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + host.endpoint = new_endpoint + self.cluster.metadata.update_host(host, old_endpoint) + + self.connection.endpoint = old_endpoint + self.connection.is_defunct = True + self.connection.last_error = ConnectionException( + "stale control connection failed", endpoint=old_endpoint) + self.cluster.signal_connection_failure = Mock(return_value=False) + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, self.connection.last_error, is_host_addition=False, + expected_endpoint=old_endpoint) + self.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): """ refresh_nodes_and_tokens uses preloaded results if given for shared table queries From 9f957e3faef0eb7f40ff93653e0b68c30652ed3c Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 14:14:47 -0400 Subject: [PATCH 23/29] cluster: release host lock before down listeners --- cassandra/cluster.py | 10 ++++++---- tests/unit/test_cluster.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 68832728c5..46a5a43f8c 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2592,10 +2592,10 @@ def on_down_potentially_blocking( else: session.on_down(host, expected_endpoint=expected_endpoint) + notify_listeners = False if endpoint_matches: if expected_endpoint is None: - for listener in self.listeners: - listener.on_down(host) + notify_listeners = True else: with host.lock: if not self._endpoints_match(host.endpoint, expected_endpoint): @@ -2603,8 +2603,10 @@ def on_down_potentially_blocking( host, expected_endpoint) endpoint_matches = False else: - for listener in self.listeners: - listener.on_down(host) + notify_listeners = True + if notify_listeners: + for listener in self.listeners: + listener.on_down(host) with host.lock: start_reconnector = (endpoint_matches and diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 34d32a6317..b0f99f593a 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1301,6 +1301,36 @@ def test_reserved_down_handling_after_endpoint_swap_removes_stale_pool(self): cluster._start_reconnector.assert_not_called() assert self._state(cluster, host).down_epoch is None + def test_expected_endpoint_down_listener_is_not_called_under_host_lock(self): + session = Mock() + cluster = self._make_cluster(session=session) + host = self._make_host() + host.set_up() + expected_endpoint = host.endpoint + down_epoch = self._reserve_down_handling(cluster, host) + blocked_on_host = [] + + class Listener(object): + + def on_down(self, _host): + def try_host_lock(): + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + + worker = Thread(target=try_host_lock) + worker.start() + worker.join(1) + + cluster._listeners = set([Listener()]) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch, + expected_endpoint=expected_endpoint) + + self.assertEqual(blocked_on_host, [False]) + def test_endpoint_match_preserves_endpoint_specific_identity(self): proxy_endpoint = SniEndPoint("proxy.example.com", "node-a", port=9042) other_proxy_endpoint = SniEndPoint("proxy.example.com", "node-b", port=9042) From a2a26ca5a83d69e040e2976f85196f542d5f26bc Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 14:55:55 -0400 Subject: [PATCH 24/29] control-connection: preserve policy state on endpoint swap --- cassandra/cluster.py | 17 ++++-- tests/unit/test_control_connection.py | 86 +++++++++++++++++++++++++-- 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 46a5a43f8c..1dad288efa 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2549,7 +2549,8 @@ def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None, def on_down_potentially_blocking( self, host: Host, is_host_addition: bool, down_epoch: Optional[int] = None, - expected_endpoint: Optional[EndPoint] = None) -> Any: + expected_endpoint: Optional[EndPoint] = None, + profile_manager_already_notified: bool = False) -> Any: pending_up_epoch = None try: down_endpoint = None @@ -2577,7 +2578,8 @@ def on_down_potentially_blocking( host, down_endpoint) endpoint_matches = False else: - self.profile_manager.on_down(host) + if not profile_manager_already_notified: + self.profile_manager.on_down(host) self.control_connection.on_down(host) else: log.debug("Not signalling down for stale down handling on node %s; endpoint changed from %s", @@ -2638,7 +2640,7 @@ def on_down_potentially_blocking( self._handle_pending_node_up(host, pending_up_epoch) def on_down(self, host, is_host_addition, expect_host_to_be_down=False, - expected_endpoint=None): + expected_endpoint=None, profile_manager_already_notified=False): """ Intended for internal use only. """ @@ -2722,7 +2724,8 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, log.warning("Host %s has been marked down", host) future = self.on_down_potentially_blocking( - host, is_host_addition, down_epoch, expected_endpoint) + host, is_host_addition, down_epoch, expected_endpoint, + profile_manager_already_notified) if future is None: pending_down = None pending_up_epoch = None @@ -5023,7 +5026,8 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, old_endpoint = host.endpoint self._cluster.on_down( host, is_host_addition=False, expect_host_to_be_down=True, - expected_endpoint=old_endpoint) + expected_endpoint=old_endpoint, + profile_manager_already_notified=True) with host.lock: if not self._cluster._endpoints_match(host.endpoint, old_endpoint): @@ -5031,9 +5035,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, "endpoint changed to %s", old_endpoint, endpoint, host_id, host.endpoint) continue + self._cluster.profile_manager.on_down(host) host.endpoint = endpoint self._cluster.metadata.update_host(host, old_endpoint) - self._cluster.on_up(host) + self._cluster.on_up(host, expected_endpoint=endpoint) if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 3bba2d0a42..f048201b66 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,7 +15,7 @@ import unittest import uuid -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType @@ -134,7 +134,7 @@ def on_up(self, host, expected_endpoint=None): pass def on_down(self, host, is_host_addition, expect_host_to_be_down=False, - expected_endpoint=None): + expected_endpoint=None, profile_manager_already_notified=False): self.down_host = host self.down_expected_endpoint = expected_endpoint @@ -154,6 +154,45 @@ def _node_meta_results(local_results, peer_results): return peer_response, local_response +class RunOnResultFuture(Future): + + def __init__(self, fn, args, kwargs): + Future.__init__(self) + self._fn = fn + self._args = args + self._kwargs = kwargs + + def _run_once(self): + if self.done(): + return + try: + self.set_result(self._fn(*self._args, **self._kwargs)) + except Exception as exc: + self.set_exception(exc) + + def result(self, timeout=None): + self._run_once() + return Future.result(self, timeout) + + +class RunOnResultExecutor(object): + + def __init__(self): + self.futures = [] + + def submit(self, fn, *args, **kwargs): + future = RunOnResultFuture(fn, args, kwargs) + self.futures.append(future) + return future + + def run_all(self): + for future in tuple(self.futures): + future._run_once() + + def shutdown(self): + self.run_all() + + class MockConnection(object): is_defunct = False @@ -423,8 +462,47 @@ def test_change_client_route_endpoint_when_only_port_changes(self): assert Cluster._endpoints_match(host.endpoint, new_endpoint) self.cluster.on_down.assert_called_once_with( host, is_host_addition=False, expect_host_to_be_down=True, - expected_endpoint=old_endpoint) - self.cluster.on_up.assert_called_once_with(host) + expected_endpoint=old_endpoint, + profile_manager_already_notified=True) + self.cluster.on_up.assert_called_once_with( + host, expected_endpoint=new_endpoint) + + def test_endpoint_change_preserves_live_policy_hosts_when_down_handler_runs_late(self): + old_endpoint = DefaultEndPoint("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + host_id = uuid.uuid4() + policy = RoundRobinPolicy() + cluster = Cluster(contact_points=[], load_balancing_policy=policy) + original_executor = cluster.executor + cluster.executor = RunOnResultExecutor() + cluster._start_reconnector = Mock() + + try: + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.endpoint_factory = Mock() + cluster.endpoint_factory.create.return_value = new_endpoint + cluster.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "peer", "data_center", "rack", "host_id"], + [["127.0.0.2", "127.0.0.2", "dc1", "rack1", host_id]])) + + cluster.control_connection._refresh_node_list_and_token_map( + Mock(), preloaded_results=preloaded_results) + cluster.executor.run_all() + + assert list(policy._live_hosts) == [host] + assert host in policy._live_hosts + finally: + cluster.executor = original_executor + cluster.shutdown() def test_stale_control_connection_failure_is_endpoint_fenced(self): host_id = uuid.uuid4() From 37a81461db5a1ea1691ccd435ed2e1f683769000 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 15:38:17 -0400 Subject: [PATCH 25/29] pool: ignore failures from replaced pools --- cassandra/pool.py | 65 ++++++++++++++++++------- tests/unit/test_host_connection_pool.py | 39 +++++++++++++++ 2 files changed, 87 insertions(+), 17 deletions(-) diff --git a/cassandra/pool.py b/cassandra/pool.py index d3a71d2161..35cb107ef5 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -613,23 +613,8 @@ def return_connection(self, connection, stream_was_orphaned=False): if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - with self.host.lock: - endpoint_matches = _endpoints_match( - self._session.cluster, self.host.endpoint, self.endpoint) - host_is_current = (endpoint_matches and - _host_is_current_for_endpoint( - self._session.cluster, self.host, - self.endpoint)) - if not endpoint_matches: - log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", - self.host, self.endpoint) - stale_endpoint_failure = True - elif not host_is_current: - log.debug("Ignoring stale connection failure for host %s; endpoint reassigned from %s", - self.host, self.endpoint) - stale_endpoint_failure = True - else: - is_down = self.host.signal_connection_failure(connection.last_error) + is_down, stale_endpoint_failure = \ + self._signal_connection_failure_if_current(connection) connection.signaled_error = True if stale_endpoint_failure: @@ -683,6 +668,52 @@ def on_orphaned_stream_released(self): with self._stream_available_condition: self._stream_available_condition.notify() + def _is_current_pool_for_endpoint_locked(self, expected_endpoint): + pools = getattr(self._session, "_pools", None) + if not isinstance(pools, dict): + return True + + for pool_host, host_pool in pools.items(): + if pool_host is not self.host or host_pool is not self: + continue + pool_endpoint = getattr(host_pool, 'endpoint', None) + if pool_endpoint is None: + pool_endpoint = host_pool.host.endpoint + return _endpoints_match( + self._session.cluster, pool_endpoint, expected_endpoint) + return False + + def _signal_connection_failure_if_current(self, connection): + session_lock = getattr(self._session, "_lock", None) + if session_lock is None: + with self.host.lock: + return self._signal_connection_failure_locked(connection) + + with session_lock: + with self.host.lock: + if not self._is_current_pool_for_endpoint_locked(self.endpoint): + log.debug("Ignoring stale connection failure for host %s; pool was replaced for %s", + self.host, self.endpoint) + return False, True + return self._signal_connection_failure_locked(connection) + + def _signal_connection_failure_locked(self, connection): + endpoint_matches = _endpoints_match( + self._session.cluster, self.host.endpoint, self.endpoint) + host_is_current = (endpoint_matches and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + self.endpoint)) + if not endpoint_matches: + log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", + self.host, self.endpoint) + return False, True + if not host_is_current: + log.debug("Ignoring stale connection failure for host %s; endpoint reassigned from %s", + self.host, self.endpoint) + return False, True + return self.host.signal_connection_failure(connection.last_error), False + def _remove_stale_pool(self, expected_endpoint): future = self._session.remove_pool( self.host, expected_host=self.host, diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index e40df1453d..0a7c6d07eb 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -296,6 +296,45 @@ def test_return_defunct_connection_after_endpoint_reassignment_is_ignored(self): conn.close.assert_called_once_with() assert not pool.is_shutdown + def test_return_defunct_connection_from_removed_pool_after_endpoint_flip_back_is_ignored(self): + endpoint = DefaultEndPoint('127.0.0.1') + host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.signal_connection_failure = Mock(return_value=True) + + session = self.make_session() + session._lock = Lock() + session._pools = {} + session.remove_pool.return_value = None + session.cluster.metadata = Metadata() + session.cluster.metadata.add_or_return_host(host) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100, + signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + stale_pool = self.PoolImpl(host, HostDistance.LOCAL, session) + stale_pool.borrow_connection(timeout=0.01) + + # The old pool was already removed, then the host endpoint flipped back + # to the same endpoint and a replacement pool became current. + current_pool = Mock() + current_pool.host = host + current_pool.endpoint = endpoint + session._pools[host] = current_pool + + conn.is_defunct = True + stale_pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=endpoint, + expected_pool=stale_pool) + conn.close.assert_called_once_with() + assert not stale_pool.is_shutdown + def test_host_instantiations(self): """ Ensure Host fails if not initialized properly From e8be8e5558f9345779af3509e5719bf02469677f Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 16:18:18 -0400 Subject: [PATCH 26/29] control-connection: reconnect after late endpoint swap --- cassandra/cluster.py | 18 +++++++---- tests/unit/test_cluster.py | 2 +- tests/unit/test_control_connection.py | 44 +++++++++++++++++++++++++-- 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 1dad288efa..4b8f9933d1 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2550,7 +2550,8 @@ def on_down_potentially_blocking( self, host: Host, is_host_addition: bool, down_epoch: Optional[int] = None, expected_endpoint: Optional[EndPoint] = None, - profile_manager_already_notified: bool = False) -> Any: + profile_manager_already_notified: bool = False, + control_connection_already_notified: bool = False) -> Any: pending_up_epoch = None try: down_endpoint = None @@ -2580,7 +2581,8 @@ def on_down_potentially_blocking( else: if not profile_manager_already_notified: self.profile_manager.on_down(host) - self.control_connection.on_down(host) + if not control_connection_already_notified: + self.control_connection.on_down(host) else: log.debug("Not signalling down for stale down handling on node %s; endpoint changed from %s", host, expected_endpoint) @@ -2640,7 +2642,8 @@ def on_down_potentially_blocking( self._handle_pending_node_up(host, pending_up_epoch) def on_down(self, host, is_host_addition, expect_host_to_be_down=False, - expected_endpoint=None, profile_manager_already_notified=False): + expected_endpoint=None, profile_manager_already_notified=False, + control_connection_already_notified=False): """ Intended for internal use only. """ @@ -2725,7 +2728,8 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, future = self.on_down_potentially_blocking( host, is_host_addition, down_epoch, expected_endpoint, - profile_manager_already_notified) + profile_manager_already_notified, + control_connection_already_notified) if future is None: pending_down = None pending_up_epoch = None @@ -5024,10 +5028,13 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, reconnector.cancel() with host.lock: old_endpoint = host.endpoint + self._cluster.profile_manager.on_down(host) + self.on_down(host) self._cluster.on_down( host, is_host_addition=False, expect_host_to_be_down=True, expected_endpoint=old_endpoint, - profile_manager_already_notified=True) + profile_manager_already_notified=True, + control_connection_already_notified=True) with host.lock: if not self._cluster._endpoints_match(host.endpoint, old_endpoint): @@ -5035,7 +5042,6 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, "endpoint changed to %s", old_endpoint, endpoint, host_id, host.endpoint) continue - self._cluster.profile_manager.on_down(host) host.endpoint = endpoint self._cluster.metadata.update_host(host, old_endpoint) self._cluster.on_up(host, expected_endpoint=endpoint) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index b0f99f593a..94ff764e0b 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1199,7 +1199,7 @@ def test_forced_down_is_not_discounted_by_connected_pool(self): assert not host.is_up cluster.on_down_potentially_blocking.assert_called_once_with( - host, False, ANY, endpoint) + host, False, ANY, endpoint, False, False) @staticmethod def _state(cluster, host): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index f048201b66..560d5290bf 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -134,7 +134,8 @@ def on_up(self, host, expected_endpoint=None): pass def on_down(self, host, is_host_addition, expect_host_to_be_down=False, - expected_endpoint=None, profile_manager_already_notified=False): + expected_endpoint=None, profile_manager_already_notified=False, + control_connection_already_notified=False): self.down_host = host self.down_expected_endpoint = expected_endpoint @@ -463,7 +464,8 @@ def test_change_client_route_endpoint_when_only_port_changes(self): self.cluster.on_down.assert_called_once_with( host, is_host_addition=False, expect_host_to_be_down=True, expected_endpoint=old_endpoint, - profile_manager_already_notified=True) + profile_manager_already_notified=True, + control_connection_already_notified=True) self.cluster.on_up.assert_called_once_with( host, expected_endpoint=new_endpoint) @@ -504,6 +506,44 @@ def test_endpoint_change_preserves_live_policy_hosts_when_down_handler_runs_late cluster.executor = original_executor cluster.shutdown() + def test_endpoint_change_reconnects_control_connection_when_down_handler_runs_late(self): + old_endpoint = DefaultEndPoint("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + host_id = uuid.uuid4() + cluster = Cluster(contact_points=[]) + original_executor = cluster.executor + cluster.executor = RunOnResultExecutor() + cluster._start_reconnector = Mock() + cluster.control_connection._connection = Mock(endpoint=old_endpoint) + cluster.control_connection.reconnect = Mock() + + try: + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.endpoint_factory = Mock() + cluster.endpoint_factory.create.return_value = new_endpoint + cluster.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "peer", "data_center", "rack", "host_id"], + [["127.0.0.2", "127.0.0.2", "dc1", "rack1", host_id]])) + + cluster.control_connection._refresh_node_list_and_token_map( + Mock(), preloaded_results=preloaded_results) + cluster.executor.run_all() + + cluster.control_connection.reconnect.assert_called_once_with() + finally: + cluster.control_connection._connection = None + cluster.executor = original_executor + cluster.shutdown() + def test_stale_control_connection_failure_is_endpoint_fenced(self): host_id = uuid.uuid4() old_endpoint = ClientRoutesEndPoint( From 788e878ec031028e08d60aa392b420e50f79c6db Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 17:15:49 -0400 Subject: [PATCH 27/29] cluster: preserve endpoint for queued down handling --- cassandra/cluster.py | 10 ++--- tests/unit/test_cluster.py | 75 ++++++++++++++++++++++++++++++-------- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 4b8f9933d1..9da2835222 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2686,11 +2686,11 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, was_up = host.is_up state = self._get_host_liveness_state(host) + down_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint pending_down_endpoint = None if state.down_endpoint is not None: - target_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint - if not self._endpoints_match(state.down_endpoint, target_endpoint): - pending_down_endpoint = target_endpoint + if not self._endpoints_match(state.down_endpoint, down_endpoint): + pending_down_endpoint = down_endpoint if pending_down_endpoint is not None: state.advance() @@ -2723,11 +2723,11 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False, state.up_epoch is None): return state.down_epoch = down_epoch - state.down_endpoint = expected_endpoint + state.down_endpoint = down_endpoint log.warning("Host %s has been marked down", host) future = self.on_down_potentially_blocking( - host, is_host_addition, down_epoch, expected_endpoint, + host, is_host_addition, down_epoch, down_endpoint, profile_manager_already_notified, control_connection_already_notified) if future is None: diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 94ff764e0b..6495442b6e 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1237,6 +1237,31 @@ def test_stale_down_handling_is_ignored_after_host_is_up(self): listener.on_down.assert_not_called() cluster._start_reconnector.assert_not_called() + def test_stale_generic_down_handling_uses_original_endpoint_after_endpoint_swap(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down(cluster, host, is_host_addition=False) + host.endpoint = new_endpoint + + executor.run_next() + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + def test_unreserved_down_handling_is_ignored_during_host_up_handling(self): session = Mock() cluster = self._make_cluster(session=session) @@ -1435,7 +1460,9 @@ def test_newer_forced_down_during_up_handling_is_preserved(self): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) assert state.epoch > first_up_epoch assert state.up_epoch == first_up_epoch assert not host.is_up @@ -1469,7 +1496,9 @@ def test_stale_failed_up_callback_does_not_cleanup_newer_down(self): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) listener.on_up.assert_not_called() assert not host.is_up assert self._state(cluster, host).up_epoch is None @@ -1503,7 +1532,8 @@ def force_down_before_cleanup(message, *args, **kwargs): host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) cluster._start_reconnector.assert_called_once_with( - host, False, expected_down_epoch=ANY) + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) assert session.remove_pool.call_count == 1 listener.on_up.assert_not_called() assert not host.is_up @@ -1561,7 +1591,9 @@ def force_down_before_reconnector_is_cleared(h, up_epoch, **kwargs): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) cluster.profile_manager.on_up.assert_not_called() cluster.control_connection.on_up.assert_not_called() old_reconnector.cancel.assert_called_once_with() @@ -1585,7 +1617,9 @@ def test_forced_down_while_reconnecting_runs_new_down_handling(self): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) assert self._state(cluster, host).down_epoch is None def test_newer_down_before_up_side_effects_suppresses_stale_up(self): @@ -1612,7 +1646,9 @@ def force_down_before_first_superseded_check(h, up_epoch): cluster.control_connection.on_down.assert_called_once_with(host) cluster.profile_manager.on_up.assert_not_called() cluster.control_connection.on_up.assert_not_called() - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) assert not host.is_up assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).down_epoch is None @@ -1820,7 +1856,9 @@ def test_down_during_up_listener_is_handled(self): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) assert not host.is_up assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).down_epoch is None @@ -1944,7 +1982,9 @@ def test_on_up_queues_after_down_is_submitted_before_worker_runs(self): session.on_down.assert_called_once_with( host, expected_endpoint=host.endpoint) listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) cluster.profile_manager.on_up.assert_called_once_with(host) cluster.control_connection.on_up.assert_called_once_with(host) assert host.is_up @@ -1966,6 +2006,7 @@ def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): Cluster.on_down( cluster, host, is_host_addition=False, expect_host_to_be_down=True) state = self._state(cluster, host) + old_endpoint = host.endpoint host.endpoint = DefaultEndPoint("127.0.0.2") @@ -1980,12 +2021,12 @@ def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): executor.run_next() - cluster.profile_manager.on_down.assert_called_once_with(host) - cluster.control_connection.on_down.assert_called_once_with(host) + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() session.on_down.assert_called_once_with( - host, expected_endpoint=host.endpoint) - listener.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() cluster.profile_manager.on_up.assert_called_once_with(host) cluster.control_connection.on_up.assert_called_once_with(host) assert host.is_up @@ -2397,7 +2438,9 @@ def test_real_down_for_unknown_host_marks_host_down(self): assert host.is_up is False cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) def test_expected_down_for_unknown_host_marks_host_down(self): cluster = self._make_cluster() @@ -2409,7 +2452,9 @@ def test_expected_down_for_unknown_host_marks_host_down(self): assert host.is_up is False cluster.profile_manager.on_down.assert_called_once_with(host) cluster.control_connection.on_down.assert_called_once_with(host) - cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) class SessionTest(unittest.TestCase): From c2275c9e4bbb863ed1d6992f34446e154ae8372b Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 18:05:08 -0400 Subject: [PATCH 28/29] cluster: fence up pool creation by endpoint --- cassandra/cluster.py | 22 +++++++++++++++------- tests/unit/test_cluster.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9da2835222..b6bda038ee 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2455,7 +2455,9 @@ def _clear_stale_reconnector(): futures, futures_results, futures_lock) for session in tuple(self.sessions): future = session.add_or_renew_pool( - host, is_host_addition=False, allow_retry_after_auth_failure=True) + host, is_host_addition=False, + allow_retry_after_auth_failure=True, + expected_endpoint=up_handling_endpoint) if future is not None: have_future = True futures.add(future) @@ -4076,7 +4078,9 @@ def _invalidate_pool_creation(self, host, expected_endpoint=None): return True return False - def add_or_renew_pool(self, host, is_host_addition, allow_retry_after_auth_failure=False): + def add_or_renew_pool(self, host, is_host_addition, + allow_retry_after_auth_failure=False, + expected_endpoint=None): """ For internal use only. """ @@ -4248,11 +4252,17 @@ def callback(pool, errors): return True with self._lock: + with host.lock: + creation_endpoint = host.endpoint + if (expected_endpoint is not None and + not self._endpoints_match( + creation_endpoint, expected_endpoint)): + return None + state = self._get_pool_creation_state(host) if state.creation_epoch is not None: - with host.lock: - endpoint_changed = not self._endpoints_match( - host.endpoint, state.endpoint) + endpoint_changed = not self._endpoints_match( + creation_endpoint, state.endpoint) if not endpoint_changed: return state.future self._invalidate_pool_creation( @@ -4260,8 +4270,6 @@ def callback(pool, errors): creation_epoch = state.advance() state.creation_epoch = creation_epoch - with host.lock: - creation_endpoint = host.endpoint state.endpoint = creation_endpoint future = self.submit(run_add_or_renew_pool) if future is None: diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 6495442b6e..284bb37603 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -579,6 +579,40 @@ def make_pool(host, distance, pool_session, endpoint=None): created_pools[1].shutdown.assert_not_called() assert session._pools[host] is created_pools[1] + def test_on_up_does_not_publish_replacement_endpoint_pool_after_endpoint_swap(self): + host = self._make_host("127.0.0.1") + host.set_down() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + original_add_or_renew_pool = Session.add_or_renew_pool.__get__( + session, Session) + + def add_after_endpoint_swap(host, *args, **kwargs): + host.endpoint = new_endpoint + return original_add_or_renew_pool(host, *args, **kwargs) + + session.add_or_renew_pool = Mock(side_effect=add_after_endpoint_swap) + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + Cluster.on_up(cluster, host) + while executor.submissions: + executor.run_next() + + assert created_pools == [] + assert session._pools == {} + session.add_or_renew_pool.assert_called_once_with( + host, is_host_addition=False, + allow_retry_after_auth_failure=True, + expected_endpoint=old_endpoint) + def test_pool_creation_publishes_before_endpoint_lock_is_released(self): host = self._make_host("127.0.0.1") new_endpoint = DefaultEndPoint("127.0.0.2") @@ -1750,7 +1784,8 @@ def queue_up_then_fail(h): cluster.control_connection.on_up.assert_called_once_with(host) session.add_or_renew_pool.assert_called_once_with( host, is_host_addition=False, - allow_retry_after_auth_failure=True) + allow_retry_after_auth_failure=True, + expected_endpoint=host.endpoint) assert host.is_up assert self._state(cluster, host).up_epoch is None assert self._state(cluster, host).pending_up_epoch is None From 77d69edae88d65e641b450b505cfd483d4d5fb2d Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 18:49:22 -0400 Subject: [PATCH 29/29] session: reject stale same-endpoint pool creation --- cassandra/cluster.py | 13 +++++++++++-- tests/unit/test_cluster.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index b6bda038ee..d937676724 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4191,10 +4191,18 @@ def callback(pool, errors): metadata_host = host if isinstance(self.cluster.metadata, Metadata): metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) + if metadata_host is None: + log.debug( + "Discarding stale connection pool for host %s; " + "host id is no longer present in metadata", + host) + self._invalidate_pool_creation( + host, expected_endpoint=creation_endpoint) + discard_pool = True target_host = metadata_host if metadata_host is not None else host target_endpoint_changed = False - if target_host is not host: + if not discard_pool and target_host is not host: with target_host.lock: target_endpoint_changed = not self._endpoints_match( target_host.endpoint, creation_endpoint) @@ -4207,7 +4215,8 @@ def callback(pool, errors): self._invalidate_pool_creation( host, expected_endpoint=creation_endpoint) discard_pool = True - else: + + if not discard_pool: target_host_matches = False for pool_host in tuple(retained_pools): if pool_host is target_host: diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 284bb37603..70fdf81920 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -726,6 +726,37 @@ def make_pool(host, distance, pool_session, endpoint=None): assert session._pools == {} created_pools[0].shutdown.assert_called_once_with() + def test_stale_host_pool_creation_does_not_replace_same_endpoint_host(self): + endpoint = DefaultEndPoint("127.0.0.1") + stale_host = Host(endpoint, SimpleConvictionPolicy, + host_id=uuid.uuid4()) + replacement_host = Host(endpoint, SimpleConvictionPolicy, + host_id=uuid.uuid4()) + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + cluster.metadata = Metadata() + cluster.metadata.add_or_return_host(replacement_host) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + session._pools[replacement_host] = replacement_pool + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + executor.run_next() + + assert future.result() is False + assert session._pools[replacement_host] is replacement_pool + replacement_pool.shutdown.assert_not_called() + created_pools[0].shutdown.assert_called_once_with() + def test_remove_pool_expected_host_mismatch_invalidates_stale_creation(self): stale_host = self._make_host("127.0.0.1") replacement_host = self._make_host("127.0.0.1") @@ -1072,6 +1103,8 @@ def _make_session_with_pool(host, pool): session = Session.__new__(Session) session._lock = Lock() session._pools = {host: pool} + session.cluster = Mock() + session.cluster.metadata.all_hosts.return_value = [] session.submit = _ImmediateExecutor().submit return session