diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..d937676724 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -25,6 +25,8 @@ from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy +from contextlib import contextmanager +from contextvars import ContextVar from functools import partial, reduce, wraps from itertools import groupby, count, chain import json @@ -194,6 +196,37 @@ def _connection_reduce_fn(val,import_fn): _GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0') _NOT_SET = object() +_NON_RETRYABLE_AUTH_FAILURE_ATTR = "_cassandra_non_retryable_auth_failure" +_POOL_CLEANUP_EPOCH = ContextVar("_POOL_CLEANUP_EPOCH", default=None) + + +def _make_connection_kwargs(endpoint, kwargs_dict, auth_provider_callable, port, + compression, sockopts, ssl_options, ssl_context, + cql_version, protocol_version, user_type_map, + allow_beta_protocol_version, no_compact, + application_info): + if auth_provider_callable: + kwargs_dict.setdefault('authenticator', auth_provider_callable(endpoint.address)) + + kwargs_dict.setdefault('port', port) + kwargs_dict.setdefault('compression', compression) + kwargs_dict.setdefault('sockopts', sockopts) + kwargs_dict.setdefault('ssl_options', ssl_options) + kwargs_dict.setdefault('ssl_context', ssl_context) + kwargs_dict.setdefault('cql_version', cql_version) + kwargs_dict.setdefault('protocol_version', protocol_version) + kwargs_dict.setdefault('user_type_map', user_type_map) + kwargs_dict.setdefault('allow_beta_protocol_version', allow_beta_protocol_version) + kwargs_dict.setdefault('no_compact', no_compact) + kwargs_dict.setdefault('application_info', application_info) + + return kwargs_dict + + +def _connection_factory_for_endpoint(endpoint, connect_timeout, connection_factory, + args, kwargs, connection_options): + kwargs = _make_connection_kwargs(endpoint, kwargs.copy(), *connection_options) + return connection_factory(endpoint, connect_timeout, *args, **kwargs) class NoHostAvailable(Exception): @@ -234,6 +267,7 @@ def new_f(self, *args, **kwargs): try: future = self.executor.submit(f, self, *args, **kwargs) future.add_done_callback(_future_completed) + return future except Exception: log.exception("Failed to submit task to executor") @@ -251,6 +285,173 @@ def _discard_cluster_shutdown(cluster): _clusters_for_shutdown.discard(cluster) +class _EventFenceState(object): + """ + Tracks a monotonically increasing epoch and named in-flight events. + + The caller owns any domain-specific locking around this state. The epoch + represents the latest observed event for a key, while named events record + which epoch currently owns a deferred side effect. + """ + + __slots__ = ("epoch", "_events") + + def __init__(self): + self.epoch = 0 + self._events = {} + + def advance(self): + self.epoch += 1 + return self.epoch + + def get_event(self, event): + return self._events.get(event) + + def set_event(self, event, epoch=None): + if epoch is None: + epoch = self.epoch + self._events[event] = epoch + return epoch + + def clear_event(self, event, epoch=None): + if epoch is not None and self._events.get(event) != epoch: + return False + self._events.pop(event, None) + return True + + def event_is_current(self, event, epoch): + return self._events.get(event) == epoch and self.epoch == epoch + + def _set_or_clear_event(self, event, epoch): + if epoch is None: + self.clear_event(event) + else: + self.set_event(event, epoch) + + +class _HostLivenessState(_EventFenceState): + _UP = "up" + _DOWN = "down" + _PENDING_UP = "pending_up" + _PENDING_DOWN = "pending_down" + + __slots__ = ("up_endpoint", "down_endpoint", + "pending_up_endpoint", "pending_down_endpoint") + + def __init__(self): + _EventFenceState.__init__(self) + self.up_endpoint = None + self.down_endpoint = None + self.pending_up_endpoint = None + self.pending_down_endpoint = None + + @property + def up_epoch(self): + return self.get_event(self._UP) + + @up_epoch.setter + def up_epoch(self, epoch): + self._set_or_clear_event(self._UP, epoch) + + @property + def down_epoch(self): + return self.get_event(self._DOWN) + + @down_epoch.setter + def down_epoch(self, epoch): + self._set_or_clear_event(self._DOWN, epoch) + + @property + def pending_up_epoch(self): + return self.get_event(self._PENDING_UP) + + @pending_up_epoch.setter + def pending_up_epoch(self, epoch): + self._set_or_clear_event(self._PENDING_UP, epoch) + + @property + def pending_down_epoch(self): + return self.get_event(self._PENDING_DOWN) + + @pending_down_epoch.setter + def pending_down_epoch(self, epoch): + self._set_or_clear_event(self._PENDING_DOWN, epoch) + + +class _PoolCreationState(_EventFenceState): + _CREATE = "create" + + __slots__ = ("future", "endpoint") + + def __init__(self): + _EventFenceState.__init__(self) + self.future = None + self.endpoint = None + + @property + def creation_epoch(self): + return self.get_event(self._CREATE) + + @creation_epoch.setter + def creation_epoch(self, epoch): + self._set_or_clear_event(self._CREATE, epoch) + + +class _IdentityWeakKeyDictionary(object): + """ + Weak mapping that uses object identity instead of ``__hash__``/``__eq__``. + Host hashes are endpoint-based, and endpoints can change during ring refresh. + """ + + def __init__(self): + self._items = {} + + def get(self, key, default=None): + key_id = id(key) + item = self._items.get(key_id) + if item is None: + return default + + key_ref, value = item + if key_ref() is key: + return value + + self._items.pop(key_id, None) + return default + + def __setitem__(self, key, value): + key_id = id(key) + self_ref = weakref.ref(self) + + def remove(ref): + self = self_ref() + if self is not None: + item = self._items.get(key_id) + if item is not None and item[0] is ref: + self._items.pop(key_id, None) + + self._items[key_id] = (weakref.ref(key, remove), value) + + +class _EventFenceMap(object): + """ + Identity-keyed store for per-object event fence state. + """ + + def __init__(self, state_factory=_EventFenceState): + self._lock = Lock() + self._states = _IdentityWeakKeyDictionary() + self._state_factory = state_factory + + def get_state(self, key): + with self._lock: + state = self._states.get(key) + if state is None: + state = self._state_factory() + self._states[key] = state + return state + + def _shutdown_clusters(): clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" for cluster in clusters: @@ -1489,6 +1690,7 @@ def __init__(self, self.sessions = WeakSet() self.metadata = Metadata() self.control_connection = None + self._host_liveness = _EventFenceMap(_HostLivenessState) self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() @@ -1694,26 +1896,23 @@ def connection_factory(self, endpoint, host_conn = None, *args, **kwargs): return self.connection_class.factory(endpoint, self.connect_timeout, host_conn, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): - kwargs = self._make_connection_kwargs(host.endpoint, kwargs) - return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) + connection_options = ( + self._auth_provider_callable, self.port, self.compression, + self.sockopts, self.ssl_options, self.ssl_context, + self.cql_version, self.protocol_version, self._user_types, + self.allow_beta_protocol_version, self.no_compact, + self.application_info) + return partial(_connection_factory_for_endpoint, host.endpoint, + self.connect_timeout, self.connection_class.factory, + args, kwargs.copy(), connection_options) def _make_connection_kwargs(self, endpoint, kwargs_dict): - if self._auth_provider_callable: - kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) - - kwargs_dict.setdefault('port', self.port) - kwargs_dict.setdefault('compression', self.compression) - kwargs_dict.setdefault('sockopts', self.sockopts) - kwargs_dict.setdefault('ssl_options', self.ssl_options) - kwargs_dict.setdefault('ssl_context', self.ssl_context) - kwargs_dict.setdefault('cql_version', self.cql_version) - kwargs_dict.setdefault('protocol_version', self.protocol_version) - kwargs_dict.setdefault('user_type_map', self._user_types) - kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) - kwargs_dict.setdefault('no_compact', self.no_compact) - kwargs_dict.setdefault('application_info', self.application_info) - - return kwargs_dict + return _make_connection_kwargs( + endpoint, kwargs_dict, self._auth_provider_callable, self.port, + self.compression, self.sockopts, self.ssl_options, self.ssl_context, + self.cql_version, self.protocol_version, self._user_types, + self.allow_beta_protocol_version, self.no_compact, + self.application_info) def protocol_downgrade(self, host_endpoint, previous_version): if self._protocol_version_explicit: @@ -1862,15 +2061,207 @@ def _session_register_user_types(self, session): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) - def _cleanup_failed_on_up_handling(self, host): - self.profile_manager.on_down(host) - self.control_connection.on_down(host) + def _cleanup_failed_on_up_handling(self, host, start_reconnector=True, + expected_endpoint=None, expected_epoch=None): + if expected_epoch is not None: + with host.lock: + if self._get_host_liveness_state(host).epoch != expected_epoch: + log.debug("Ignoring stale failed up cleanup for node %s", host) + return + + endpoint_changed = False + if expected_endpoint is None: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + else: + with host.lock: + endpoint_changed = not self._endpoints_match(host.endpoint, expected_endpoint) + if endpoint_changed: + log.debug("Not signalling down for stale up handling on node %s; endpoint changed from %s", + host, expected_endpoint) + else: + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + for session in tuple(self.sessions): + with self._pool_cleanup_epoch(host, expected_epoch): + future = session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) + if future: + future.add_done_callback( + lambda f, session=session: session.update_created_pools()) + + cleanup_is_current = True + if expected_epoch is not None: + with host.lock: + cleanup_is_current = self._get_host_liveness_state(host).epoch == expected_epoch + + if start_reconnector and not endpoint_changed and cleanup_is_current: + self._start_reconnector( + host, is_host_addition=False, expected_endpoint=expected_endpoint) + + def _get_host_liveness_state(self, host): + try: + fences = self._host_liveness + except AttributeError: + fences = self._host_liveness = _EventFenceMap(_HostLivenessState) + return fences.get_state(host) + + @contextmanager + def _pool_cleanup_epoch(self, host, expected_epoch=None, + allow_current_down=False): + if expected_epoch is None: + yield + return + + token = _POOL_CLEANUP_EPOCH.set( + (host, expected_epoch, allow_current_down)) + try: + yield + finally: + _POOL_CLEANUP_EPOCH.reset(token) + + def _up_handling_was_superseded(self, host, up_epoch): + state = self._get_host_liveness_state(host) + if not state.event_is_current(_HostLivenessState._UP, up_epoch): + return True + + expected_endpoint = state.up_endpoint + return expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint) + + def _up_handling_is_superseded(self, host, up_epoch): + with host.lock: + superseded = self._up_handling_was_superseded(host, up_epoch) + if superseded: + log.debug("Ignoring superseded up handling for node %s", host) + return superseded + + def _get_reconnector_for_current_up_handling(self, host, up_epoch, + expected_reconnector=_NOT_SET): + with host.lock: + if self._up_handling_was_superseded(host, up_epoch): + reconnector = host._reconnection_handler + # Only cancel the handler observed when up handling started. + # A newer down event may already have installed its own reconnector. + if expected_reconnector is _NOT_SET or reconnector is expected_reconnector: + host._reconnection_handler = None + if reconnector: + reconnector.cancel() + log.debug("Ignoring superseded up handling for node %s", host) + return None, True + reconnector = host._reconnection_handler + host._reconnection_handler = None + return reconnector, False + + def _clear_up_handling(self, host, up_epoch=None): + state = self._get_host_liveness_state(host) + if state.clear_event(_HostLivenessState._UP, up_epoch): + state.up_endpoint = None + return True + return False + + def _cleanup_superseded_up_handling(self, host, expected_endpoint=None, expected_epoch=None): + if expected_epoch is not None: + with host.lock: + if self._get_host_liveness_state(host).epoch != expected_epoch: + log.debug("Ignoring stale superseded up cleanup for node %s", host) + return + for session in tuple(self.sessions): - session.remove_pool(host) + with self._pool_cleanup_epoch(host, expected_epoch): + future = session.remove_pool( + host, expected_host=host, expected_endpoint=expected_endpoint) + if future: + future.add_done_callback( + lambda f, session=session: session.update_created_pools()) + + def _pop_pending_node_up_if_ready(self, host): + state = self._get_host_liveness_state(host) + if state.pending_up_epoch is None: + state.pending_up_endpoint = None + return None + if host.is_up: + state.pending_up_epoch = None + state.pending_up_endpoint = None + return None + if state.up_epoch is not None or state.down_epoch is not None: + return None + + pending_up_epoch = state.pending_up_epoch + pending_up_endpoint = state.pending_up_endpoint + # Leave the pending marker in place until on_up() reacquires host.lock so + # a newer down signal can still invalidate this replay. + return pending_up_epoch, pending_up_endpoint + + def _handle_pending_node_up(self, host, pending_up): + if pending_up is not None: + if isinstance(pending_up, tuple): + pending_up_epoch, pending_up_endpoint = pending_up + else: + pending_up_epoch = pending_up + pending_up_endpoint = None + log.debug("Handling queued up status of node %s", host) + self._on_up(host, expected_epoch=pending_up_epoch, + expected_endpoint=pending_up_endpoint) + + def _clear_down_handling(self, host, down_epoch=None): + state = self._get_host_liveness_state(host) + if state.clear_event(_HostLivenessState._DOWN, down_epoch): + state.down_endpoint = None + return True + return False + + def _clear_pending_down(self, state): + state.pending_down_epoch = None + state.pending_down_endpoint = None + + def _pop_pending_node_down_if_ready(self, host): + state = self._get_host_liveness_state(host) + if state.pending_down_epoch is None: + state.pending_down_endpoint = None + return None + if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + return None - self._start_reconnector(host, is_host_addition=False) + pending_down_epoch = state.pending_down_epoch + pending_down_endpoint = state.pending_down_endpoint + if state.epoch != pending_down_epoch: + self._clear_pending_down(state) + return None + + self._clear_pending_down(state) + return pending_down_epoch, pending_down_endpoint + + def _handle_pending_node_down(self, host, pending_down, is_host_addition): + if pending_down is None: + return False + _pending_down_epoch, pending_down_endpoint = pending_down + log.debug("Handling queued down status of node %s", host) + self.on_down( + host, is_host_addition=is_host_addition, + expect_host_to_be_down=True, + expected_endpoint=pending_down_endpoint) + return True - def _on_up_future_completed(self, host, futures, results, lock, finished_future): + def _finish_superseded_up_handling(self, host, up_epoch, expected_endpoint=None): + self._cleanup_superseded_up_handling( + host, expected_endpoint=expected_endpoint, expected_epoch=up_epoch) + + pending_up_epoch = None + with host.lock: + if self._clear_up_handling(host, up_epoch): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + + self._handle_pending_node_up(host, pending_up_epoch) + + def _finish_up_if_superseded(self, host, up_epoch, expected_endpoint=None): + if self._up_handling_is_superseded(host, up_epoch): + self._finish_superseded_up_handling( + host, up_epoch, expected_endpoint=expected_endpoint) + return True + return False + + def _on_up_future_completed(self, host, up_handling_revision, up_handling_endpoint, + futures, results, lock, finished_future): with lock: futures.discard(finished_future) @@ -1884,30 +2275,69 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) try: # all futures have completed at this point + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return + for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) - self._cleanup_failed_on_up_handling(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return + self._cleanup_failed_on_up_handling( + host, expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) - self._cleanup_failed_on_up_handling(host) + with host.lock: + # Only suppress retries for the auth-failure quarantine. + start_reconnector = not self._has_non_retryable_auth_failure(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return + self._cleanup_failed_on_up_handling( + host, start_reconnector=start_reconnector, + expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) return log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners - host.set_up() + superseded = False + with host.lock: + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Ignoring superseded up handling for node %s", host) + superseded = True + else: + host.set_up() + self._clear_up_handling(host, up_handling_revision) + if superseded: + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) + return for listener in self.listeners: listener.on_up(host) finally: + pending_up_epoch = None with host.lock: - host._currently_handling_node_up = False + if self._clear_up_handling(host, up_handling_revision): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + self._handle_pending_node_up(host, pending_up_epoch) # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() + self._set_non_retryable_auth_failure(host, False) + return - def on_up(self, host): + def on_up(self, host, expected_endpoint=None, expected_reconnector=None): + return self._on_up(host, expected_endpoint=expected_endpoint, + expected_reconnector=expected_reconnector) + + def _on_up(self, host, expected_epoch=None, expected_endpoint=None, + expected_reconnector=None): """ Intended for internal use only. """ @@ -1915,16 +2345,61 @@ def on_up(self, host): return log.debug("Waiting to acquire lock for handling up status of node %s", host) + + def _clear_stale_reconnector(): + # Stale callbacks can outlive the reconnector callback path. + if (expected_reconnector is not None and + host._reconnection_handler is expected_reconnector): + host._reconnection_handler = None + with host.lock: - if host._currently_handling_node_up: - log.debug("Another thread is already handling up status of node %s", host) + state = self._get_host_liveness_state(host) + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Ignoring stale queued up handling for node %s; endpoint changed from %s", + host, expected_endpoint) + _clear_stale_reconnector() + return + + if (expected_epoch is not None and + (state.epoch != expected_epoch or state.pending_up_epoch != expected_epoch)): + log.debug("Ignoring stale queued up handling for node %s", host) + _clear_stale_reconnector() + return + + if state.down_epoch is not None: + log.debug("Down status is being handled for node %s; queueing up handling", host) + state.pending_up_epoch = state.epoch + state.pending_up_endpoint = expected_endpoint + return + + if state.up_epoch is not None: + up_handling_revision = state.up_epoch + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Superseded up handling is still finishing for node %s; " + "queueing up handling", host) + # Endpoint swap without an epoch bump needs a fresh epoch so + # the replayed up handling is not cleared by the stale finish. + if state.epoch == up_handling_revision: + state.advance() + state.pending_up_epoch = state.epoch + state.pending_up_endpoint = expected_endpoint + else: + log.debug("Another thread is already handling up status of node %s", host) return if host.is_up: log.debug("Host %s was already marked up", host) + state.pending_up_epoch = None + state.pending_up_endpoint = None return - host._currently_handling_node_up = True + state.pending_up_epoch = None + state.pending_up_endpoint = None + up_handling_revision = state.epoch + up_handling_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint + state.up_epoch = up_handling_revision + state.up_endpoint = up_handling_endpoint + up_handling_reconnector = host._reconnection_handler log.debug("Starting to handle up status of node %s", host) have_future = False @@ -1932,70 +2407,139 @@ def on_up(self, host): try: log.info("Host %s may be up; will prepare queries and open connection pool", host) - reconnector = host.get_and_set_reconnection_handler(None) + reconnector, superseded = self._get_reconnector_for_current_up_handling( + host, up_handling_revision, expected_reconnector=up_handling_reconnector) + if superseded: + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) + return futures if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() if self.profile_manager.distance(host) != HostDistance.IGNORED: self._prepare_all_queries(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return futures log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): - session.remove_pool(host) + with self._pool_cleanup_epoch(host, expected_epoch=up_handling_revision): + session.remove_pool( + host, expected_host=host, expected_endpoint=up_handling_endpoint) + + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return futures log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return futures + log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) + if self._finish_up_if_superseded( + host, up_handling_revision, expected_endpoint=up_handling_endpoint): + return futures + log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] - callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback = partial(self._on_up_future_completed, host, up_handling_revision, + up_handling_endpoint, + futures, futures_results, futures_lock) for session in tuple(self.sessions): - future = session.add_or_renew_pool(host, is_host_addition=False) + future = session.add_or_renew_pool( + host, is_host_addition=False, + allow_retry_after_auth_failure=True, + expected_endpoint=up_handling_endpoint) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + for future in tuple(futures): + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: future.cancel() - self._cleanup_failed_on_up_handling(host) + self._cleanup_failed_on_up_handling( + host, expected_endpoint=up_handling_endpoint, + expected_epoch=up_handling_revision) + pending_up_epoch = None with host.lock: - host._currently_handling_node_up = False + if self._clear_up_handling(host, up_handling_revision): + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + self._handle_pending_node_up(host, pending_up_epoch) raise else: if not have_future: + superseded = False with host.lock: - host.set_up() - host._currently_handling_node_up = False + if self._up_handling_was_superseded(host, up_handling_revision): + log.debug("Ignoring superseded up handling for node %s", host) + superseded = True + else: + host.set_up() + self._clear_up_handling(host, up_handling_revision) + if superseded: + self._finish_superseded_up_handling( + host, up_handling_revision, expected_endpoint=up_handling_endpoint) + return futures + + for listener in self.listeners: + listener.on_up(host) + + for session in tuple(self.sessions): + session.update_created_pools() + + self._set_non_retryable_auth_failure(host, False) # for testing purposes return futures - def _start_reconnector(self, host, is_host_addition): + def _start_reconnector(self, host, is_host_addition, expected_down_epoch=None, + expected_endpoint=None): if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() - # in order to not hold references to this Cluster open and prevent - # proper shutdown when the program ends, we'll just make a closure - # of the current Cluster attributes to create new Connections with - conn_factory = self._make_connection_factory(host) + with host.lock: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + return + + # in order to not hold references to this Cluster open and prevent + # proper shutdown when the program ends, we'll just make a closure + # of the current Cluster attributes to create new Connections with + conn_factory = self._make_connection_factory(host) + + if expected_down_epoch is not None: + state = self._get_host_liveness_state(host) + if state.down_epoch != expected_down_epoch: + log.debug("Not starting reconnector for host %s; down handling is no longer current", host) + return + + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + return - reconnector = _HostReconnectionHandler( - host, conn_factory, is_host_addition, self.on_add, self.on_up, - self.scheduler, schedule, host.get_and_set_reconnection_handler, - new_handler=None) + reconnector = _HostReconnectionHandler( + host, conn_factory, is_host_addition, self.on_add, self.on_up, + self.scheduler, schedule, host.get_and_set_reconnection_handler, + new_handler=None) - old_reconnector = host.get_and_set_reconnection_handler(reconnector) + old_reconnector = host._reconnection_handler + host._reconnection_handler = reconnector if old_reconnector: log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() @@ -2004,45 +2548,200 @@ def _start_reconnector(self, host, is_host_addition): reconnector.start() @run_in_executor - def on_down_potentially_blocking(self, host, is_host_addition): - self.profile_manager.on_down(host) - self.control_connection.on_down(host) - for session in tuple(self.sessions): - session.on_down(host) + def on_down_potentially_blocking( + self, host: Host, is_host_addition: bool, + down_epoch: Optional[int] = None, + expected_endpoint: Optional[EndPoint] = None, + profile_manager_already_notified: bool = False, + control_connection_already_notified: bool = False) -> Any: + pending_up_epoch = None + try: + down_endpoint = None + with host.lock: + state = self._get_host_liveness_state(host) + owns_reserved_down_handling = down_epoch is not None and state.down_epoch == down_epoch + if down_epoch is None: + if host.is_up or state.up_epoch is not None or state.down_epoch is not None: + log.debug("Ignoring stale down handling for host %s", host) + return + down_epoch = state.epoch + state.down_epoch = down_epoch + state.down_endpoint = expected_endpoint + elif not owns_reserved_down_handling: + log.debug("Ignoring stale down handling for host %s", host) + return + down_endpoint = host.endpoint + endpoint_matches = expected_endpoint is None or self._endpoints_match( + down_endpoint, expected_endpoint) - for listener in self.listeners: - listener.on_down(host) + if endpoint_matches: + with host.lock: + if not self._endpoints_match(host.endpoint, down_endpoint): + log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", + host, down_endpoint) + endpoint_matches = False + else: + if not profile_manager_already_notified: + self.profile_manager.on_down(host) + if not control_connection_already_notified: + self.control_connection.on_down(host) + else: + log.debug("Not signalling down for stale down handling on node %s; endpoint changed from %s", + host, expected_endpoint) + + for session in tuple(self.sessions): + with self._pool_cleanup_epoch( + host, expected_epoch=down_epoch, + allow_current_down=True): + if expected_endpoint is None: + session.on_down(host, expected_endpoint=down_endpoint) + else: + session.on_down(host, expected_endpoint=expected_endpoint) + + notify_listeners = False + if endpoint_matches: + if expected_endpoint is None: + notify_listeners = True + else: + with host.lock: + if not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Ignoring stale down handling for host %s; endpoint changed from %s", + host, expected_endpoint) + endpoint_matches = False + else: + notify_listeners = True + if notify_listeners: + for listener in self.listeners: + listener.on_down(host) + + with host.lock: + start_reconnector = (endpoint_matches and + self._get_host_liveness_state(host).down_epoch == down_epoch) + if (start_reconnector and expected_endpoint is not None and + not self._endpoints_match(host.endpoint, expected_endpoint)): + log.debug("Not starting reconnector for host %s; endpoint changed from %s", + host, expected_endpoint) + start_reconnector = False + if start_reconnector: + if expected_endpoint is None: + self._start_reconnector(host, is_host_addition, expected_down_epoch=down_epoch) + else: + self._start_reconnector( + host, is_host_addition, expected_down_epoch=down_epoch, + expected_endpoint=expected_endpoint) + else: + log.debug("Not starting reconnector for removed host %s", host) + finally: + pending_down = None + pending_up_epoch = None + with host.lock: + if down_epoch is not None and self._clear_down_handling(host, down_epoch): + pending_down = self._pop_pending_node_down_if_ready(host) + if pending_down is None: + pending_up_epoch = self._pop_pending_node_up_if_ready(host) - self._start_reconnector(host, is_host_addition) + if not self._handle_pending_node_down(host, pending_down, is_host_addition): + self._handle_pending_node_up(host, pending_up_epoch) - def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + def on_down(self, host, is_host_addition, expect_host_to_be_down=False, + expected_endpoint=None, profile_manager_already_notified=False, + control_connection_already_notified=False): """ Intended for internal use only. """ if self.is_shutdown: return - with host.lock: - was_up = host.is_up - - # ignore down signals if we have open pools to the host - # this is to avoid closing pools when a control connection host became isolated - if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: + if (not expect_host_to_be_down and + self._discount_down_events and + self.profile_manager.distance(host) != HostDistance.IGNORED): + with host.lock: + host_endpoint = host.endpoint + discount_endpoint = host_endpoint + if expected_endpoint is not None: + if self._endpoints_match(host_endpoint, expected_endpoint): + discount_endpoint = expected_endpoint + else: + discount_endpoint = None + if discount_endpoint is not None: connected = False for session in tuple(self.sessions): - pool_states = session.get_pool_state() - pool_state = pool_states.get(host) - if pool_state: - connected |= pool_state['open_count'] > 0 + # Host equality is endpoint-based; scan by identity to avoid + # hiding the live pool behind a stale equal key. Do not hold + # host.lock while taking session._lock; update_created_pools() + # takes the locks in the opposite order. + pool = session._get_pool_by_host_identity( + host, expected_endpoint=discount_endpoint) + if pool is not None and pool.open_count > 0: + connected = True + break if connected: + with host.lock: + if self._endpoints_match(host.endpoint, discount_endpoint): + return + + with host.lock: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Ignoring stale down signal for host %s; endpoint changed from %s", + host, expected_endpoint) + return + + was_up = host.is_up + state = self._get_host_liveness_state(host) + down_endpoint = expected_endpoint if expected_endpoint is not None else host.endpoint + pending_down_endpoint = None + if state.down_endpoint is not None: + if not self._endpoints_match(state.down_endpoint, down_endpoint): + pending_down_endpoint = down_endpoint + + if pending_down_endpoint is not None: + state.advance() + state.pending_up_epoch = None + state.pending_up_endpoint = None + state.pending_down_epoch = state.epoch + state.pending_down_endpoint = pending_down_endpoint + host.set_down() + return + + if not expect_host_to_be_down: + if was_up is False: + if state.pending_up_epoch is not None: + state.advance() + state.pending_up_epoch = None + state.pending_up_endpoint = None + host.set_down() return + state.advance() + state.pending_up_epoch = None + state.pending_up_endpoint = None + self._clear_pending_down(state) host.set_down() - if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + down_epoch = state.epoch + if state.down_epoch is not None: + return + if (not expect_host_to_be_down and + host.is_currently_reconnecting() and + state.up_epoch is None): return + state.down_epoch = down_epoch + state.down_endpoint = down_endpoint log.warning("Host %s has been marked down", host) - self.on_down_potentially_blocking(host, is_host_addition) + future = self.on_down_potentially_blocking( + host, is_host_addition, down_epoch, down_endpoint, + profile_manager_already_notified, + control_connection_already_notified) + if future is None: + pending_down = None + pending_up_epoch = None + with host.lock: + if self._clear_down_handling(host, down_epoch): + pending_down = self._pop_pending_node_down_if_ready(host) + if pending_down is None: + pending_up_epoch = self._pop_pending_node_up_if_ready(host) + if not self._handle_pending_node_down(host, pending_down, is_host_addition): + self._handle_pending_node_up(host, pending_up_epoch) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2114,12 +2813,26 @@ def _finalize_add(self, host, set_up=True): for session in tuple(self.sessions): session.update_created_pools() + if set_up: + self._set_non_retryable_auth_failure(host, False) + def on_remove(self, host): if self.is_shutdown: return log.debug("[cluster] Removing host %s", host) - host.set_down() + with host.lock: + state = self._get_host_liveness_state(host) + state.advance() + state.pending_up_epoch = None + state.pending_up_endpoint = None + self._clear_pending_down(state) + state.up_epoch = None + state.up_endpoint = None + state.down_epoch = None + state.down_endpoint = None + host.set_down() + self._set_non_retryable_auth_failure(host, False) self.profile_manager.on_remove(host) for session in tuple(self.sessions): session.on_remove(host) @@ -2131,10 +2844,99 @@ def on_remove(self, host): if reconnection_handler: reconnection_handler.cancel() - def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): - is_down = host.signal_connection_failure(connection_exc) - if is_down: - self.on_down(host, is_host_addition, expect_host_to_be_down) + @staticmethod + def _is_authentication_failure(connection_exc): + return (isinstance(connection_exc, AuthenticationFailed) or + isinstance(getattr(connection_exc, "__cause__", None), AuthenticationFailed)) + + @staticmethod + def _endpoint_address_key(endpoint): + if endpoint is None: + return None + if isinstance(endpoint, tuple): + return endpoint[:2] + address = getattr(endpoint, "address", None) + port = getattr(endpoint, "port", None) + if address is not None or port is not None: + return address, port + return endpoint + + @classmethod + def _endpoint_key(cls, endpoint): + if endpoint is None: + return None + if isinstance(endpoint, tuple): + return endpoint[:2] + address_key = cls._endpoint_address_key(endpoint) + if not isinstance(address_key, tuple): + return endpoint + if endpoint.__class__ is DefaultEndPoint: + return address_key + # Non-default endpoint classes can carry connection identity that + # address/port alone does not preserve, e.g. SNI or client-routes IDs. + endpoint_key = (endpoint.__class__,) + address_key + server_name = getattr(endpoint, "_server_name", None) + if server_name is not None: + return endpoint_key + (server_name,) + host_id = getattr(endpoint, "host_id", None) + if host_id is not None: + return endpoint_key + (host_id,) + if isinstance(endpoint, EndPoint): + return endpoint_key + (endpoint,) + return endpoint + + @classmethod + def _endpoints_match(cls, endpoint, expected_endpoint): + if isinstance(endpoint, tuple) or isinstance(expected_endpoint, tuple): + return (cls._endpoint_address_key(endpoint) == + cls._endpoint_address_key(expected_endpoint)) + return cls._endpoint_key(endpoint) == cls._endpoint_key(expected_endpoint) + + @classmethod + def _set_non_retryable_auth_failure(cls, host, enabled): + if enabled: + setattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR, cls._endpoint_key(host.endpoint)) + elif hasattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR): + delattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR) + + @classmethod + def _has_non_retryable_auth_failure(cls, host): + failure_endpoint = getattr(host, _NON_RETRYABLE_AUTH_FAILURE_ATTR, None) + return failure_endpoint is not None and failure_endpoint == cls._endpoint_key(host.endpoint) + + def signal_connection_failure(self, host, connection_exc, is_host_addition, + expect_host_to_be_down=False, expected_endpoint=None): + signal_down = False + with host.lock: + if expected_endpoint is not None and not self._endpoints_match(host.endpoint, expected_endpoint): + log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", + host, expected_endpoint) + return False + + if expected_endpoint is not None: + if isinstance(expected_endpoint, tuple): + current_host = self.metadata.get_host(*expected_endpoint[:2]) + else: + current_host = self.metadata.get_host(expected_endpoint) + if current_host is not None and current_host is not host: + log.debug("Ignoring stale connection failure for host %s; endpoint reassigned to %s", + host, current_host) + return False + + is_auth_failure = self._is_authentication_failure(connection_exc) + is_down = host.signal_connection_failure(connection_exc) + if host.is_up is None and is_auth_failure: + # Never-up auth failures are terminal for the endpoint even if conviction policy + # has not yet decided the host is down. Do not run normal down handling + # for a host that was never marked up. + self._set_non_retryable_auth_failure(host, True) + return is_down + if is_down: + signal_down = True + if signal_down: + self.on_down( + host, is_host_addition, expect_host_to_be_down, + expected_endpoint=expected_endpoint) return is_down def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): @@ -2615,6 +3417,7 @@ def __init__(self, cluster, hosts, keyspace=None): self._lock = RLock() self._pools = {} + self._pool_creation_fences = _EventFenceMap(_PoolCreationState) self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics self._request_init_callbacks = [] @@ -3232,33 +4035,116 @@ def __del__(self): # when cluster.shutdown() is called explicitly. pass - def add_or_renew_pool(self, host, is_host_addition): + def _get_pool_creation_state(self, host): + try: + fences = self._pool_creation_fences + except AttributeError: + fences = self._pool_creation_fences = _EventFenceMap(_PoolCreationState) + return fences.get_state(host) + + def _endpoints_match(self, endpoint, expected_endpoint): + matches = self.cluster._endpoints_match(endpoint, expected_endpoint) + if isinstance(matches, bool): + return matches + return endpoint == expected_endpoint + + def _has_non_retryable_auth_failure(self, host): + return self.cluster._has_non_retryable_auth_failure(host) + + def _get_host_liveness_state(self, host): + return self.cluster._get_host_liveness_state(host) + + def _pool_creation_is_current(self, host, creation_epoch): + state = self._get_pool_creation_state(host) + return state.event_is_current(_PoolCreationState._CREATE, creation_epoch) + + def _clear_pool_creation(self, host, creation_epoch): + state = self._get_pool_creation_state(host) + if state.clear_event(_PoolCreationState._CREATE, creation_epoch): + state.future = None + state.endpoint = None + return True + return False + + def _invalidate_pool_creation(self, host, expected_endpoint=None): + state = self._get_pool_creation_state(host) + if state.creation_epoch is not None: + if expected_endpoint is not None and not self._endpoints_match(state.endpoint, expected_endpoint): + return False + state.advance() + state.creation_epoch = None + state.future = None + state.endpoint = None + return True + return False + + def add_or_renew_pool(self, host, is_host_addition, + allow_retry_after_auth_failure=False, + expected_endpoint=None): """ For internal use only. """ distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None + if (not is_host_addition and not allow_retry_after_auth_failure and + self._has_non_retryable_auth_failure(host)): + # Auth failure quarantine is endpoint-scoped; same endpoint survives later + # down events until a successful add/up clears it. + return None + + creation_epoch = None + creation_endpoint = None def run_add_or_renew_pool(): try: - new_pool = HostConnection(host, distance, self) + new_pool = HostConnection( + host, distance, self, endpoint=creation_endpoint) except AuthenticationFailed as auth_exc: - conn_exc = ConnectionException(str(auth_exc), endpoint=host) - self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + conn_exc = ConnectionException(str(auth_exc), endpoint=creation_endpoint) + conn_exc.__cause__ = auth_exc + with self._lock: + signal_failure = self._pool_creation_is_current(host, creation_epoch) + if signal_failure: + self._clear_pool_creation(host, creation_epoch) + if signal_failure: + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, + expected_endpoint=creation_endpoint) return False - except Exception as conn_exc: + except Exception as pool_exc: log.warning("Failed to create connection pool for new host %s:", - host, exc_info=conn_exc) + host, exc_info=pool_exc) + conn_exc = pool_exc + try: + conn_exc.endpoint = creation_endpoint + except AttributeError: + conn_exc = ConnectionException(str(pool_exc), endpoint=creation_endpoint) + conn_exc.__cause__ = pool_exc # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created - self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + with self._lock: + signal_failure = self._pool_creation_is_current(host, creation_epoch) + if signal_failure: + self._clear_pool_creation(host, creation_epoch) + if signal_failure: + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, expect_host_to_be_down=True, + expected_endpoint=creation_endpoint) return False + pool_endpoint = self._get_pool_endpoint(new_pool, creation_endpoint) + new_pool.endpoint = pool_endpoint - previous = self._pools.get(host) + previous_pools = [] + discard_pool = False + signal_down = False + reuse_existing_pool = False with self._lock: while new_pool._keyspace != self.keyspace: + if not self._pool_creation_is_current(host, creation_epoch): + discard_pool = True + break + self._lock.release() set_keyspace_event = Event() errors_returned = [] @@ -3269,31 +4155,274 @@ def callback(pool, errors): new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) + self._lock.acquire() + if not self._pool_creation_is_current(host, creation_epoch): + discard_pool = True + break if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) - self.cluster.on_down(host, is_host_addition) - new_pool.shutdown() - self._lock.acquire() - return False - self._lock.acquire() - self._pools[host] = new_pool + self._clear_pool_creation(host, creation_epoch) + signal_down = True + break + + if not discard_pool and not signal_down: + if not self._pool_creation_is_current(host, creation_epoch): + discard_pool = True + else: + with host.lock: + endpoint_changed = not self._endpoints_match(host.endpoint, creation_endpoint) + if endpoint_changed: + log.debug( + "Discarding stale connection pool for host %s; endpoint changed from %s", + host, creation_endpoint) + self._invalidate_pool_creation(host, expected_endpoint=creation_endpoint) + discard_pool = True + else: + # Rebuild by identity so endpoint hash changes do not + # leave stale pool entries behind. + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if pool_host is host: + previous_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + + # Keep the current metadata host keyed by identity. + metadata_host = host + if isinstance(self.cluster.metadata, Metadata): + metadata_host = self.cluster.metadata.get_host_by_host_id(host.host_id) + if metadata_host is None: + log.debug( + "Discarding stale connection pool for host %s; " + "host id is no longer present in metadata", + host) + self._invalidate_pool_creation( + host, expected_endpoint=creation_endpoint) + discard_pool = True + + target_host = metadata_host if metadata_host is not None else host + target_endpoint_changed = False + if not discard_pool and target_host is not host: + with target_host.lock: + target_endpoint_changed = not self._endpoints_match( + target_host.endpoint, creation_endpoint) + + if target_endpoint_changed: + log.debug( + "Discarding stale connection pool for host %s; " + "metadata host endpoint changed from %s", + host, creation_endpoint) + self._invalidate_pool_creation( + host, expected_endpoint=creation_endpoint) + discard_pool = True + + if not discard_pool: + target_host_matches = False + for pool_host in tuple(retained_pools): + if pool_host is target_host: + target_host_matches = True + elif pool_host == target_host: + previous_pools.append(retained_pools.pop(pool_host)) + + if target_host_matches: + reuse_existing_pool = True + else: + source_host = new_pool.host + if (source_host is not target_host and + target_host.sharding_info is None): + target_host.sharding_info = source_host.sharding_info + new_pool.host = target_host + retained_pools[target_host] = new_pool + self._pools = retained_pools + self._clear_pool_creation(host, creation_epoch) + + if reuse_existing_pool: + log.debug("Reusing existing connection pool for host %s", host) + new_pool.shutdown() + for previous in previous_pools: + previous.shutdown() + return True + + if signal_down: + self.cluster.on_down( + host, is_host_addition, expected_endpoint=creation_endpoint) + new_pool.shutdown() + return False + + if discard_pool: + log.debug("Discarding stale connection pool for host %s", host) + new_pool.shutdown() + return False log.debug("Added pool for host %s to session", host) - if previous: + for previous in previous_pools: previous.shutdown() return True - return self.submit(run_add_or_renew_pool) - - def remove_pool(self, host): - pool = self._pools.pop(host, None) - if pool: + with self._lock: + with host.lock: + creation_endpoint = host.endpoint + if (expected_endpoint is not None and + not self._endpoints_match( + creation_endpoint, expected_endpoint)): + return None + + state = self._get_pool_creation_state(host) + if state.creation_epoch is not None: + endpoint_changed = not self._endpoints_match( + creation_endpoint, state.endpoint) + if not endpoint_changed: + return state.future + self._invalidate_pool_creation( + host, expected_endpoint=state.endpoint) + + creation_epoch = state.advance() + state.creation_epoch = creation_epoch + state.endpoint = creation_endpoint + future = self.submit(run_add_or_renew_pool) + if future is None: + self._clear_pool_creation(host, creation_epoch) + return None + if state.event_is_current(_PoolCreationState._CREATE, creation_epoch): + state.future = future + return future + + def remove_pool(self, host, expected_host=None, expected_endpoint=None, + expected_pool=None): + removed_pools = [] + cleanup_context = _POOL_CLEANUP_EPOCH.get() + with self._lock: + with host.lock: + if cleanup_context is not None: + cleanup_host, cleanup_epoch = cleanup_context[:2] + allow_current_down = ( + len(cleanup_context) > 2 and cleanup_context[2]) + if cleanup_host is host: + # Stale cleanup can outlive an A->B->A flip; preserve the + # pool that belongs to the current epoch. + state = self._get_host_liveness_state(host) + cleanup_still_owns_down = ( + allow_current_down and + state.down_epoch == cleanup_epoch) + if (state.epoch != cleanup_epoch and + not cleanup_still_owns_down): + log.debug( + "Ignoring stale pool cleanup for host %s; epoch moved from %s to %s", + host, cleanup_epoch, state.epoch) + return None + + # Keep the endpoint snapshot stable until the removal pass + # finishes; migration can otherwise flip the branch mid-flight. + # Host hashes track endpoint, so endpoint swaps can hide stale + # entries behind normal dict lookups. Scan by identity instead. + remove_all = expected_endpoint is None + if expected_endpoint is not None: + remove_all = self._endpoints_match(host.endpoint, expected_endpoint) + + if expected_pool is None: + if remove_all: + self._invalidate_pool_creation(host) + else: + self._invalidate_pool_creation( + host, expected_endpoint=expected_endpoint) + + retained_pools = {} + for pool_host, host_pool in self._pools.items(): + if pool_host is not host: + retained_pools[pool_host] = host_pool + continue + + matches = self._pool_matches_expected( + host_pool, expected_host=expected_host, + expected_endpoint=None if remove_all else expected_endpoint, + expected_pool=expected_pool) + if matches: + removed_pools.append(host_pool) + else: + retained_pools[pool_host] = host_pool + self._pools = retained_pools + if removed_pools: log.debug("Removed connection pool for %r", host) - return self.submit(pool.shutdown) + return self.submit(self._shutdown_removed_pools, removed_pools) else: return None + @contextmanager + def _pool_cleanup_epoch(self, host, expected_epoch=None): + if expected_epoch is None: + yield + return + + token = _POOL_CLEANUP_EPOCH.set((host, expected_epoch)) + try: + yield + finally: + _POOL_CLEANUP_EPOCH.reset(token) + + @staticmethod + def _shutdown_removed_pools(pools): + for pool in pools: + pool.shutdown() + + def _pool_matches_expected(self, pool, expected_host=None, + expected_endpoint=None, expected_pool=None): + if expected_pool is not None and pool is not expected_pool: + return False + if expected_host is not None and pool.host is not expected_host: + return False + if expected_endpoint is not None: + pool_endpoint = getattr(pool, 'endpoint', None) + if pool_endpoint is None: + pool_endpoint = pool.host.endpoint + if not self._endpoints_match(pool_endpoint, expected_endpoint): + return False + return True + + @staticmethod + def _get_pool_endpoint(pool, default_endpoint): + endpoint = getattr(pool, 'endpoint', None) + connections = getattr(pool, '_connections', None) + if isinstance(connections, Mapping): + for connection in connections.values(): + if getattr(connection, 'original_endpoint', None) is None: + connection_endpoint = getattr(connection, 'endpoint', None) + if connection_endpoint is not None: + return connection_endpoint + elif endpoint is None: + endpoint = connection.original_endpoint + return endpoint if endpoint is not None else default_endpoint + + def _get_pool_by_host_identity(self, host, expected_host=None, expected_endpoint=None): + pool = self._pools.get(host) + if (pool is not None and pool.host is host and self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): + return pool + + with self._lock: + for pool_host, pool in self._pools.items(): + if (pool_host is host and self._pool_matches_expected( + pool, expected_host=expected_host, expected_endpoint=expected_endpoint)): + return pool + return None + + def _reuse_or_invalidate_pool_creation(self, host, pool_creation_future): + with self._lock: + current_distance = self._profile_manager.distance(host) + pool_creation_state = self._get_pool_creation_state(host) + if (pool_creation_state.creation_epoch is not None and + pool_creation_state.future is pool_creation_future): + with host.lock: + endpoint_changed = not self._endpoints_match(host.endpoint, pool_creation_state.endpoint) + if endpoint_changed: + self._invalidate_pool_creation( + host, expected_endpoint=pool_creation_state.endpoint) + if current_distance == HostDistance.IGNORED: + self._invalidate_pool_creation(host) + elif not endpoint_changed: + return pool_creation_future + return None + def update_created_pools(self): """ When the set of live nodes change, the loadbalancer will change its @@ -3309,14 +4438,36 @@ def update_created_pools(self): futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) - pool = self._pools.get(host) + with self._lock: + with host.lock: + host_endpoint = host.endpoint + host_is_up = host.is_up + pool = self._get_pool_by_host_identity( + host, expected_endpoint=host_endpoint) + any_pool = pool or self._get_pool_by_host_identity(host) + pool_creation_state = self._get_pool_creation_state(host) + pool_creation_future = ( + pool_creation_state.future + if pool_creation_state.creation_epoch is not None else None) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. - if distance != HostDistance.IGNORED and host.is_up in (True, None): - future = self.add_or_renew_pool(host, False) + if distance != HostDistance.IGNORED: + if pool_creation_future is not None: + # on_up() keeps host.is_up False until this future succeeds. + future = self._reuse_or_invalidate_pool_creation( + host, pool_creation_future) + elif host_is_up in (True, None): + future = self.add_or_renew_pool(host, False) + elif any_pool is not None: + future = self.remove_pool(host) + elif pool_creation_future is not None: + future = self._reuse_or_invalidate_pool_creation( + host, pool_creation_future) + elif any_pool is not None: + future = self.remove_pool(host) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: @@ -3327,12 +4478,12 @@ def update_created_pools(self): futures.add(future) return futures - def on_down(self, host): + def on_down(self, host, expected_endpoint=None): """ Called by the parent Cluster instance when a node is marked down. Only intended for internal use. """ - future = self.remove_pool(host) + future = self.remove_pool(host, expected_endpoint=expected_endpoint) if future: future.add_done_callback(lambda f: self.update_created_pools()) @@ -3881,23 +5032,36 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids.add(host_id) found_endpoints.add(endpoint) - host = self._cluster.metadata.get_host(endpoint) + host_by_endpoint = self._cluster.metadata.get_host(endpoint) + host_by_id = self._cluster.metadata.get_host_by_host_id(host_id) + host = host_by_id or host_by_endpoint datacenter = row.get("data_center") rack = row.get("rack") - if host is None: - host = self._cluster.metadata.get_host_by_host_id(host_id) - if host and host.endpoint != endpoint: - log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) - + if host is not None and not self._cluster._endpoints_match(host.endpoint, endpoint): + log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + with host.lock: old_endpoint = host.endpoint + self._cluster.profile_manager.on_down(host) + self.on_down(host) + self._cluster.on_down( + host, is_host_addition=False, expect_host_to_be_down=True, + expected_endpoint=old_endpoint, + profile_manager_already_notified=True, + control_connection_already_notified=True) + + with host.lock: + if not self._cluster._endpoints_match(host.endpoint, old_endpoint): + log.debug("[control connection] Not updating host ip from %s to %s for (%s); " + "endpoint changed to %s", + old_endpoint, endpoint, host_id, host.endpoint) + continue host.endpoint = endpoint self._cluster.metadata.update_host(host, old_endpoint) - self._cluster.on_up(host) + self._cluster.on_up(host, expected_endpoint=endpoint) if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) @@ -4018,13 +5182,15 @@ def _handle_status_change(self, event): change_type = event["change_type"] addr, port = event["address"] host = self._cluster.metadata.get_host(addr, port) + expected_endpoint = host.endpoint if host is not None else event["address"] if change_type == "UP": delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: - self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) + self._cluster.scheduler.schedule_unique( + delay, self._cluster.on_up, host, expected_endpoint=expected_endpoint) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. @@ -4032,7 +5198,7 @@ def _handle_status_change(self, event): # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: # this will be run by the scheduler - self._cluster.on_down(host, is_host_addition=False) + self._cluster.on_down(host, is_host_addition=False, expected_endpoint=expected_endpoint) def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: """ @@ -4226,9 +5392,10 @@ def _signal_error(self): # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( - host, self._connection.last_error, is_host_addition=False) - return + if self._cluster.signal_connection_failure( + host, self._connection.last_error, is_host_addition=False, + expected_endpoint=self._connection.endpoint): + return # if the connection is not defunct or the host already left, reconnect # manually @@ -4289,6 +5456,7 @@ class _Scheduler(Thread): def __init__(self, executor): self._queue = queue.PriorityQueue() self._scheduled_tasks = set() + self._scheduled_tasks_lock = Lock() self._count = count() self._executor = executor @@ -4303,26 +5471,56 @@ def shutdown(self): # this can happen on interpreter shutdown pass self.is_shutdown = True - self._queue.put_nowait((0, 0, None)) + self._queue.put_nowait((0, 0, None, None)) self.join() def schedule(self, delay, fn, *args, **kwargs): - self._insert_task(delay, (fn, args, tuple(kwargs.items()))) + task = (fn, args, tuple(kwargs.items())) + self._insert_task(delay, self._task_key(fn, args, kwargs), task) def schedule_unique(self, delay, fn, *args, **kwargs): task = (fn, args, tuple(kwargs.items())) - if task not in self._scheduled_tasks: - self._insert_task(delay, task) - else: - log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + task_key = self._task_key(fn, args, kwargs) + self._insert_task(delay, task_key, task, unique=True) + + @staticmethod + def _freeze_task_arg(value): + if isinstance(value, Host): + return (Host, id(value)) + if isinstance(value, EndPoint): + return (EndPoint, Cluster._endpoint_key(value)) + if isinstance(value, tuple): + return tuple(_Scheduler._freeze_task_arg(item) for item in value) + if isinstance(value, list): + return ('list', tuple(_Scheduler._freeze_task_arg(item) for item in value)) + if isinstance(value, Mapping): + return ( + 'dict', + tuple((key, _Scheduler._freeze_task_arg(val)) for key, val in value.items()) + ) + return value + + @classmethod + def _task_key(cls, fn, args, kwargs): + return ( + fn, + tuple(cls._freeze_task_arg(arg) for arg in args), + tuple((key, cls._freeze_task_arg(val)) for key, val in kwargs.items()), + ) + + def _insert_task(self, delay, task_key, task, unique=False): + with self._scheduled_tasks_lock: + if self.is_shutdown: + log.debug("Ignoring scheduled task after shutdown: %r", task) + return + + if unique and task_key in self._scheduled_tasks: + log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + return - def _insert_task(self, delay, task): - if not self.is_shutdown: run_at = time.time() + delay - self._scheduled_tasks.add(task) - self._queue.put_nowait((run_at, next(self._count), task)) - else: - log.debug("Ignoring scheduled task after shutdown: %r", task) + self._scheduled_tasks.add(task_key) + self._queue.put_nowait((run_at, next(self._count), task_key, task)) def run(self): while True: @@ -4331,19 +5529,20 @@ def run(self): try: while True: - run_at, i, task = self._queue.get(block=True, timeout=None) + run_at, i, task_key, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: if task: log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): - self._scheduled_tasks.discard(task) + with self._scheduled_tasks_lock: + self._scheduled_tasks.discard(task_key) fn, args, kwargs = task kwargs = dict(kwargs) future = self._executor.submit(fn, *args, **kwargs) future.add_done_callback(self._log_if_failed) else: - self._queue.put_nowait((run_at, i, task)) + self._queue.put_nowait((run_at, i, task_key, task)) break except queue.Empty: pass @@ -4428,6 +5627,7 @@ class ResponseFuture(object): _errbacks = None _current_host = None _connection = None + _connection_pool = None _query_retries = 0 _start_time = None _metrics = None @@ -4524,7 +5724,7 @@ def _on_timeout(self, _attempts=0): # Capture connection stats before pool.return_connection() can alter state conn_in_flight = self._connection.in_flight - pool = self.session._pools.get(self._current_host) + pool = self._connection_pool if pool and not pool.is_shutdown: # Do not return the stream ID to the pool yet. We cannot reuse it # because the node might still be processing the query and will @@ -4607,7 +5807,24 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message - pool = self.session._pools.get(host) + expected_endpoint = None + if isinstance(host, Host): + with host.lock: + expected_endpoint = host.endpoint + pool = self.session._get_pool_by_host_identity( + host, expected_endpoint=expected_endpoint) + else: + pool = self.session._get_pool_by_host_identity(host) + + if pool and expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None @@ -4624,7 +5841,25 @@ def _query(self, host, message=None, cb=None): connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table) else: connection, request_id = pool.borrow_connection(timeout=2.0) + + if expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + try: + with connection.lock: + connection.request_ids.append(request_id) + pool.return_connection(connection) + finally: + connection = None + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + self._connection = connection + self._connection_pool = pool + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..35cb107ef5 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -16,7 +16,7 @@ Connection pooling and host management. """ from concurrent.futures import Future -from functools import total_ordering +from functools import partial, total_ordering import logging import socket import time @@ -45,6 +45,54 @@ class NoConnectionsAvailable(Exception): pass +def _current_host_endpoint(host): + host_lock = getattr(host, "lock", None) + if host_lock is not None: + try: + with host_lock: + return host.endpoint + except (AttributeError, TypeError): + pass + return host.endpoint + + +def _endpoints_match(cluster, endpoint, expected_endpoint): + matches = cluster._endpoints_match(endpoint, expected_endpoint) + if isinstance(matches, bool): + return matches + # Mocked clusters can return a bare Mock here; fall back to equality. + return endpoint == expected_endpoint + + +def _host_is_current_for_endpoint(cluster, host, expected_endpoint): + metadata = getattr(cluster, "metadata", None) + hosts = getattr(metadata, "_hosts", None) + host_id_by_endpoint = getattr(metadata, "_host_id_by_endpoint", None) + if not isinstance(hosts, dict) or not isinstance(host_id_by_endpoint, dict): + return True + + def check_mapping(): + try: + mapped_host_id = host_id_by_endpoint.get(expected_endpoint) + except TypeError: + return True + if mapped_host_id is None: + return hosts.get(getattr(host, "host_id", None)) is host + return hosts.get(mapped_host_id) is host + + metadata_lock = getattr(metadata, "_hosts_lock", None) + if metadata_lock is not None: + with metadata_lock: + return check_mapping() + return check_mapping() + + +def _host_matches_expected_endpoint(cluster, host, expected_endpoint): + current_endpoint = _current_host_endpoint(host) + return (_endpoints_match(cluster, current_endpoint, expected_endpoint) and + _host_is_current_for_endpoint(cluster, host, expected_endpoint)) + + @total_ordering class Host(object): """ @@ -163,8 +211,6 @@ class Host(object): _reconnection_handler = None lock = None - _currently_handling_node_up = False - sharding_info = None def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): @@ -223,13 +269,15 @@ def signal_connection_failure(self, connection_exc): def is_currently_reconnecting(self): return self._reconnection_handler is not None - def get_and_set_reconnection_handler(self, new_handler): + def get_and_set_reconnection_handler(self, new_handler, expected_handler=None): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ with self.lock: old = self._reconnection_handler + if expected_handler is not None and old is not expected_handler: + return None self._reconnection_handler = new_handler return old @@ -351,16 +399,24 @@ def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *a self.on_up = on_up self.host = host self.connection_factory = connection_factory + self.callback_kwargs.setdefault('expected_handler', self) def try_reconnect(self): - return self.connection_factory() + connection_factory = self.connection_factory + endpoint = _current_host_endpoint(self.host) + if isinstance(connection_factory, partial): + return connection_factory.func( + endpoint, *connection_factory.args[1:], + **(connection_factory.keywords or {})) + return connection_factory() def on_reconnection(self, connection): log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) if self.is_host_addition: self.on_add(self.host) else: - self.on_up(self.host) + self.on_up(self.host, expected_endpoint=getattr(connection, "endpoint", None), + expected_reconnector=self) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): @@ -379,6 +435,7 @@ class HostConnection(object): """ host = None + endpoint = None host_distance = None is_shutdown = False shutdown_on_error = False @@ -393,8 +450,9 @@ class HostConnection(object): tablets_routing_v1 = False - def __init__(self, host, host_distance, session): + def __init__(self, host, host_distance, session, endpoint=None): self.host = host + self.endpoint = endpoint if endpoint is not None else host.endpoint self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() @@ -428,7 +486,7 @@ def __init__(self, host, host_distance, session): return log.debug("Initializing connection for host %s", self.host) - first_connection = session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + first_connection = session.cluster.connection_factory(self.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.features.shard_id) self._connections[first_connection.features.shard_id] = first_connection self._keyspace = session.keyspace @@ -551,18 +609,34 @@ def return_connection(self, connection, stream_was_orphaned=False): return is_down = False + stale_endpoint_failure = False if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.signal_connection_failure(connection.last_error) + is_down, stale_endpoint_failure = \ + self._signal_connection_failure_if_current(connection) connection.signaled_error = True + if stale_endpoint_failure: + # Drop only this stale pool; endpoint reuse may already belong to + # a replacement host instance. + future = self._session.remove_pool( + self.host, expected_host=self.host, + expected_endpoint=self.endpoint, expected_pool=self) + if future: + future.add_done_callback(lambda f: self._session.update_created_pools()) + with self._lock: + self._connections.pop(connection.features.shard_id, None) + connection.close() + return + if self.shutdown_on_error and not is_down: is_down = True if is_down: self.shutdown() - self._session.cluster.on_down(self.host, is_host_addition=False) + self._session.cluster.on_down(self.host, is_host_addition=False, + expected_endpoint=self.endpoint) else: connection.close() with self._lock: @@ -594,7 +668,76 @@ def on_orphaned_stream_released(self): with self._stream_available_condition: self._stream_available_condition.notify() + def _is_current_pool_for_endpoint_locked(self, expected_endpoint): + pools = getattr(self._session, "_pools", None) + if not isinstance(pools, dict): + return True + + for pool_host, host_pool in pools.items(): + if pool_host is not self.host or host_pool is not self: + continue + pool_endpoint = getattr(host_pool, 'endpoint', None) + if pool_endpoint is None: + pool_endpoint = host_pool.host.endpoint + return _endpoints_match( + self._session.cluster, pool_endpoint, expected_endpoint) + return False + + def _signal_connection_failure_if_current(self, connection): + session_lock = getattr(self._session, "_lock", None) + if session_lock is None: + with self.host.lock: + return self._signal_connection_failure_locked(connection) + + with session_lock: + with self.host.lock: + if not self._is_current_pool_for_endpoint_locked(self.endpoint): + log.debug("Ignoring stale connection failure for host %s; pool was replaced for %s", + self.host, self.endpoint) + return False, True + return self._signal_connection_failure_locked(connection) + + def _signal_connection_failure_locked(self, connection): + endpoint_matches = _endpoints_match( + self._session.cluster, self.host.endpoint, self.endpoint) + host_is_current = (endpoint_matches and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + self.endpoint)) + if not endpoint_matches: + log.debug("Ignoring stale connection failure for host %s; endpoint changed from %s", + self.host, self.endpoint) + return False, True + if not host_is_current: + log.debug("Ignoring stale connection failure for host %s; endpoint reassigned from %s", + self.host, self.endpoint) + return False, True + return self.host.signal_connection_failure(connection.last_error), False + + def _remove_stale_pool(self, expected_endpoint): + future = self._session.remove_pool( + self.host, expected_host=self.host, + expected_endpoint=expected_endpoint, expected_pool=self) + if future: + future.add_done_callback(lambda f: self._session.update_created_pools()) + def _replace(self, connection): + expected_endpoint = self.endpoint + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + self._remove_stale_pool(expected_endpoint) + with self._lock: + self._is_replacing = False + with self._stream_available_condition: + self._stream_available_condition.notify() + return + + direct_reconnect = not (self.host.sharding_info and + not self._session.cluster.shard_aware_options.disable) + replacement_connection = None + keyspace = None with self._lock: if self.is_shutdown: return @@ -603,22 +746,71 @@ def _replace(self, connection): try: if connection.features.shard_id in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if direct_reconnect: + keyspace = self._keyspace + replacement_connection = self._session.cluster.connection_factory( + expected_endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + else: self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) - else: - connection = self._session.cluster.connection_factory(self.host.endpoint, - on_orphaned_stream_released=self.on_orphaned_stream_released) - if self._keyspace: - connection.set_keyspace_blocking(self._keyspace) - self._connections[connection.features.shard_id] = connection except Exception: - log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) + log.warning("Failed reconnecting %s. Retrying." % (expected_endpoint,)) self._session.submit(self._replace, connection) - else: + return + + if not direct_reconnect: self._is_replacing = False with self._stream_available_condition: self._stream_available_condition.notify() + return + + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + replacement_connection.close() + self._remove_stale_pool(expected_endpoint) + with self._lock: + self._is_replacing = False + with self._stream_available_condition: + self._stream_available_condition.notify() + return + + if keyspace: + try: + replacement_connection.set_keyspace_blocking(keyspace) + except Exception: + log.warning("Failed reconnecting %s. Retrying." % (expected_endpoint,)) + replacement_connection.close() + self._session.submit(self._replace, connection) + return + + stale_endpoint = False + with self._lock: + if self.is_shutdown: + replacement_connection.close() + self._is_replacing = False + return + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, expected_endpoint)) + if not stale_endpoint: + self._connections[replacement_connection.features.shard_id] = replacement_connection + self._is_replacing = False + if stale_endpoint: + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + replacement_connection.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return + with self._stream_available_condition: + self._stream_available_condition.notify() def shutdown(self): log.debug("Shutting down connections to %s", self.host) @@ -676,17 +868,17 @@ def disable_advanced_shard_aware(self, secs): log.warning("disabling advanced_shard_aware for %i seconds, could be that this client is behind NAT?", secs) self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until) - def _get_shard_aware_endpoint(self): + def _get_shard_aware_endpoint(self, endpoint=None): + endpoint = self.endpoint if endpoint is None else endpoint if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ self._session.cluster.shard_aware_options.disable_shardaware_port: return None - endpoint = None if self._session.cluster.ssl_options and self.host.sharding_info.shard_aware_port_ssl: - endpoint = copy.copy(self.host.endpoint) + endpoint = copy.copy(endpoint) endpoint._port = self.host.sharding_info.shard_aware_port_ssl elif self.host.sharding_info.shard_aware_port: - endpoint = copy.copy(self.host.endpoint) + endpoint = copy.copy(endpoint) endpoint._port = self.host.sharding_info.shard_aware_port return endpoint @@ -709,140 +901,180 @@ def _open_connection_to_missing_shard(self, shard_id): the smaller the chance that further connections will be assigned to that shard. """ - with self._lock: - if self.is_shutdown: + expected_endpoint = self.endpoint + try: + with self._lock: + if self.is_shutdown: + return + + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() return - shard_aware_endpoint = self._get_shard_aware_endpoint() - log.debug("shard_aware_endpoint=%r", shard_aware_endpoint) - if shard_aware_endpoint: - try: - conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, - shard_id=shard_id, - total_shards=self.host.sharding_info.shards_count) - conn.original_endpoint = self.host.endpoint - except Exception as exc: - log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) - raise - else: - conn = self._session.cluster.connection_factory(self.host.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) - log.debug( - "Received a connection %s for shard_id=%i on host %s", - id(conn), - conn.features.shard_id if conn.features.shard_id is not None else -1, - self.host) - if self.is_shutdown: - log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn)) - conn.close() - return + shard_aware_endpoint = self._get_shard_aware_endpoint(expected_endpoint) + log.debug("shard_aware_endpoint=%r", shard_aware_endpoint) + if shard_aware_endpoint: + try: + conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, + shard_id=shard_id, + total_shards=self.host.sharding_info.shards_count) + conn.original_endpoint = expected_endpoint + except Exception as exc: + log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) + raise + else: + conn = self._session.cluster.connection_factory(expected_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) - if shard_aware_endpoint and shard_id != conn.features.shard_id: - # connection didn't land on expected shared - # assuming behind a NAT, disabling advanced shard aware for a while - self.disable_advanced_shard_aware(10 * 60) + if not _host_matches_expected_endpoint( + self._session.cluster, self.host, expected_endpoint): + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + conn.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return - old_conn = self._connections.get(conn.features.shard_id) - if old_conn is None or old_conn.orphaned_threshold_reached: log.debug( - "New connection (%s) created to shard_id=%i on host %s", + "Received a connection %s for shard_id=%i on host %s", id(conn), - conn.features.shard_id, - self.host - ) - old_conn = None - with self._lock: - is_shutdown = self.is_shutdown - if not is_shutdown: - if conn.features.shard_id in self._connections: - # Move the current connection to the trash and use the new one from now on - old_conn = self._connections[conn.features.shard_id] + conn.features.shard_id if conn.features.shard_id is not None else -1, + self.host) + if self.is_shutdown: + log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn)) + conn.close() + return + + if shard_aware_endpoint and shard_id != conn.features.shard_id: + # connection didn't land on expected shared + # assuming behind a NAT, disabling advanced shard aware for a while + self.disable_advanced_shard_aware(10 * 60) + + old_conn = self._connections.get(conn.features.shard_id) + if old_conn is None or old_conn.orphaned_threshold_reached: + log.debug( + "New connection (%s) created to shard_id=%i on host %s", + id(conn), + conn.features.shard_id, + self.host + ) + old_conn = None + with self._lock: + is_shutdown = self.is_shutdown + if not is_shutdown: + if conn.features.shard_id in self._connections: + # Move the current connection to the trash and use the new one from now on + old_conn = self._connections[conn.features.shard_id] + log.debug( + "Replacing overloaded connection (%s) with (%s) for shard %i for host %s", + id(old_conn), + id(conn), + conn.features.shard_id, + self.host + ) + if self._keyspace: + conn.set_keyspace_blocking(self._keyspace) + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + expected_endpoint)) + if not stale_endpoint: + self._connections[conn.features.shard_id] = conn + + if is_shutdown: + conn.close() + return + if stale_endpoint: + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + conn.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return + + if old_conn is not None: + remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) + if remaining == 0: log.debug( - "Replacing overloaded connection (%s) with (%s) for shard %i for host %s", + "Immediately closing the old connection (%s) for shard %i on host %s", id(old_conn), - id(conn), - conn.features.shard_id, + old_conn.features.shard_id, self.host ) - if self._keyspace: - conn.set_keyspace_blocking(self._keyspace) - self._connections[conn.features.shard_id] = conn - - if is_shutdown: - conn.close() - return - - if old_conn is not None: - remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) - if remaining == 0: - log.debug( - "Immediately closing the old connection (%s) for shard %i on host %s", - id(old_conn), - old_conn.features.shard_id, - self.host - ) - old_conn.close() - else: + old_conn.close() + else: + log.debug( + "Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining", + id(old_conn), + old_conn.features.shard_id, + self.host, + remaining, + ) + with self._lock: + is_shutdown = self.is_shutdown + if not is_shutdown: + self._trash.add(old_conn) + if is_shutdown: + conn.close() + num_missing_or_needing_replacement = self.num_missing_or_needing_replacement + log.debug( + "Connected to %s/%i shards on host %s (%i missing or needs replacement)", + len(self._connections), + self.host.sharding_info.shards_count, + self.host, + num_missing_or_needing_replacement + ) + if num_missing_or_needing_replacement == 0: log.debug( - "Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining", - id(old_conn), - old_conn.features.shard_id, + "All shards of host %s have at least one connection, closing %i excess connections", self.host, - remaining, + len(self._excess_connections) ) - with self._lock: - is_shutdown = self.is_shutdown - if not is_shutdown: - self._trash.add(old_conn) - if is_shutdown: - conn.close() - num_missing_or_needing_replacement = self.num_missing_or_needing_replacement - log.debug( - "Connected to %s/%i shards on host %s (%i missing or needs replacement)", - len(self._connections), - self.host.sharding_info.shards_count, - self.host, - num_missing_or_needing_replacement - ) - if num_missing_or_needing_replacement == 0: + self._close_excess_connections() + elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0: log.debug( - "All shards of host %s have at least one connection, closing %i excess connections", - self.host, - len(self._excess_connections) + "All shards are already covered, closing newly opened excess connection %s for host %s", + id(self), + self.host ) - self._close_excess_connections() - elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0: - log.debug( - "All shards are already covered, closing newly opened excess connection %s for host %s", - id(self), - self.host - ) - conn.close() - else: - if len(self._excess_connections) >= self._excess_connection_limit: + conn.close() + else: + if len(self._excess_connections) >= self._excess_connection_limit: + log.debug( + "After connection %s is created excess connection pool size limit (%i) reached for host %s, closing all %i of them", + id(conn), + self._excess_connection_limit, + self.host, + len(self._excess_connections) + ) + self._close_excess_connections() + log.debug( - "After connection %s is created excess connection pool size limit (%i) reached for host %s, closing all %i of them", + "Putting a connection %s to shard %i to the excess pool of host %s", id(conn), - self._excess_connection_limit, - self.host, - len(self._excess_connections) + conn.features.shard_id, + self.host ) - self._close_excess_connections() - - log.debug( - "Putting a connection %s to shard %i to the excess pool of host %s", - id(conn), - conn.features.shard_id, - self.host - ) - close_connection = False - with self._lock: - if self.is_shutdown: - close_connection = True - else: - self._excess_connections.add(conn) - if close_connection: - conn.close() - self._connecting.discard(shard_id) + close_connection = False + with self._lock: + if self.is_shutdown: + close_connection = True + else: + self._excess_connections.add(conn) + if close_connection: + conn.close() + finally: + self._connecting.discard(shard_id) def _open_connections_for_all_shards(self, skip_shard_id=None): """ @@ -867,7 +1099,7 @@ def _open_connections_for_all_shards(self, skip_shard_id=None): self._trash = set() if trash_conns is not None: - for conn in self._trash: + for conn in trash_conns: conn.close() def _set_keyspace_for_all_conns(self, keyspace, callback): @@ -920,5 +1152,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..70fdf81920 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -16,14 +16,18 @@ import logging import socket -from unittest.mock import patch, Mock +from concurrent.futures import Future +from threading import Event, Lock, RLock, Thread +from unittest.mock import patch, Mock, ANY import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.pool import Host +from cassandra.connection import ClientRoutesEndPoint, ConnectionException, DefaultEndPoint, SniEndPoint +from cassandra.metadata import Metadata +from cassandra.pool import Host, HostConnection, _HostReconnectionHandler from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools @@ -33,6 +37,37 @@ log = logging.getLogger(__name__) + +class _ImmediateExecutor(object): + + def submit(self, fn, *args, **kwargs): + future = Future() + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + +class _QueuedExecutor(object): + + def __init__(self): + self.submissions = [] + + def submit(self, fn, *args, **kwargs): + future = Future() + self.submissions.append((future, fn, args, kwargs)) + return future + + def run_next(self, index=0): + future, fn, args, kwargs = self.submissions.pop(index) + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -229,6 +264,27 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_reconnector_connection_factory_recomputes_authenticator_after_endpoint_swap(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + auth_provider = Mock() + auth_provider.new_authenticator.side_effect = lambda address: 'auth-%s' % (address,) + + with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: + cluster = Cluster(auth_provider=auth_provider) + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + connection_factory = cluster._make_connection_factory(host) + handler = _HostReconnectionHandler( + host, connection_factory, False, Mock(), Mock(), Mock(), iter([0]), + host.get_and_set_reconnection_handler, new_handler=None) + + host.endpoint = new_endpoint + conn = handler.try_reconnect() + + assert conn == 'connection' + assert factory.call_args.args[0] == new_endpoint + assert factory.call_args.kwargs['authenticator'] == 'auth-127.0.0.2' + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -245,6 +301,2229 @@ def test_event_delay_timing(self, *_): sched.schedule(0, lambda: None) sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" + @patch('cassandra.cluster._Scheduler.run') # don't actually run the thread + def test_schedule_unique_keeps_client_route_events_for_distinct_ports(self, *_): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host = Host(new_endpoint, SimpleConvictionPolicy, host_id=host_id) + scheduled_fn = Mock() + + sched = _Scheduler(Mock()) + sched.schedule_unique( + 30, scheduled_fn, host, expected_endpoint=old_endpoint) + sched.schedule_unique( + 30, scheduled_fn, host, expected_endpoint=new_endpoint) + + assert len(sched._scheduled_tasks) == 2 + + +class SessionPoolRaceTest(unittest.TestCase): + + class _EndpointSwapIfPoolUnpublishedOnFirstExitLock(object): + + def __init__(self, host, new_endpoint, pool_is_published): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._pool_is_published = pool_is_published + self._exits = 0 + self.pool_was_published_on_first_exit = None + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self.pool_was_published_on_first_exit = self._pool_is_published() + if not self.pool_was_published_on_first_exit: + self._host.endpoint = self._new_endpoint + + class _DuplicatePoolEntries(object): + + def __init__(self, entries): + self._entries = entries + + def __len__(self): + return len(self._entries) + + def items(self): + return list(self._entries) + + @staticmethod + def _make_host(address): + return Host(address, SimpleConvictionPolicy, host_id=uuid.uuid4()) + + @staticmethod + def _make_pool(host, distance, session, endpoint=None): + pool = Mock() + pool.host = host + pool.endpoint = endpoint if endpoint is not None else host.endpoint + pool.host_distance = distance + pool._keyspace = session.keyspace + pool.is_shutdown = False + return pool + + @staticmethod + def _make_cluster_and_session(hosts): + executor = _QueuedExecutor() + + cluster = Cluster.__new__(Cluster) + cluster.is_shutdown = False + cluster.profile_manager = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.control_connection = Mock() + cluster.metadata = Mock() + cluster.metadata.all_hosts.return_value = hosts + cluster._listeners = set() + cluster._listener_lock = Lock() + cluster.executor = executor + cluster._prepare_all_queries = Mock() + + session = Session.__new__(Session) + session.cluster = cluster + session.is_shutdown = False + session.keyspace = None + session._lock = RLock() + session._pools = {} + session._profile_manager = cluster.profile_manager + cluster.sessions = set([session]) + + return cluster, session, executor + + def test_update_created_pools_reuses_in_flight_add_for_issue_317(self): + first_host = self._make_host("127.0.0.1") + second_host = self._make_host("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session( + [first_host, second_host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + Cluster.on_add(cluster, first_host) + Cluster.on_add(cluster, second_host) + + assert len(executor.submissions) == 2 + + executor.run_next(1) + + assert second_host.is_up + assert len(executor.submissions) == 1 + + executor.run_next() + + assert len(created_pools) == 2 + assert session._pools[first_host].host is first_host + assert session._pools[second_host].host is second_host + for pool in created_pools: + pool.shutdown.assert_not_called() + + def test_add_or_renew_pool_returns_in_flight_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + + with patch("cassandra.cluster.HostConnection", + side_effect=self._make_pool): + first_future = session.add_or_renew_pool(host, is_host_addition=True) + second_future = session.add_or_renew_pool(host, is_host_addition=False) + + assert second_future is first_future + assert len(executor.submissions) == 1 + + executor.run_next() + + assert session._pools[host].host is host + + def test_update_created_pools_keeps_in_flight_up_pool_for_down_host(self): + host = self._make_host("127.0.0.1") + host.set_down() + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert session.update_created_pools() == set([future]) + + executor.run_next() + + assert future.result() is True + assert session._pools[host].host is host + created_pools[0].shutdown.assert_not_called() + + def test_update_created_pools_rechecks_distance_before_invalidating_pool_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + distance_calls = [] + future_holder = {} + + def distance(changed_host): + distance_calls.append(changed_host) + if len(distance_calls) == 1: + future_holder["future"] = session.add_or_renew_pool( + host, is_host_addition=False) + return HostDistance.IGNORED + return HostDistance.LOCAL + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + cluster.profile_manager.distance.side_effect = distance + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + assert session.update_created_pools() == set([future_holder["future"]]) + + executor.run_next() + + assert future_holder["future"].result() is True + assert session._pools[host].host is host + created_pools[0].shutdown.assert_not_called() + + def test_update_created_pools_rechecks_distance_before_reusing_pool_creation(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + cluster.profile_manager.distance.side_effect = [ + HostDistance.LOCAL, HostDistance.IGNORED] + + assert session.update_created_pools() == set() + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_update_created_pools_invalidates_creation_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + host.endpoint = DefaultEndPoint("127.0.0.2") + + assert session.update_created_pools() == set() + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_add_or_renew_pool_invalidates_creation_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + _, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + old_future = session.add_or_renew_pool( + host, is_host_addition=False) + host.endpoint = new_endpoint + + new_future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert new_future is not old_future + assert len(executor.submissions) == 2 + + executor.run_next() + executor.run_next() + + assert old_future.result() is False + assert new_future.result() is True + assert created_pools[0].endpoint == old_endpoint + assert created_pools[1].endpoint == new_endpoint + created_pools[0].shutdown.assert_called_once_with() + created_pools[1].shutdown.assert_not_called() + assert session._pools[host] is created_pools[1] + + def test_on_up_does_not_publish_replacement_endpoint_pool_after_endpoint_swap(self): + host = self._make_host("127.0.0.1") + host.set_down() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + original_add_or_renew_pool = Session.add_or_renew_pool.__get__( + session, Session) + + def add_after_endpoint_swap(host, *args, **kwargs): + host.endpoint = new_endpoint + return original_add_or_renew_pool(host, *args, **kwargs) + + session.add_or_renew_pool = Mock(side_effect=add_after_endpoint_swap) + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + Cluster.on_up(cluster, host) + while executor.submissions: + executor.run_next() + + assert created_pools == [] + assert session._pools == {} + session.add_or_renew_pool.assert_called_once_with( + host, is_host_addition=False, + allow_retry_after_auth_failure=True, + expected_endpoint=old_endpoint) + + def test_pool_creation_publishes_before_endpoint_lock_is_released(self): + host = self._make_host("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + host.lock = self._EndpointSwapIfPoolUnpublishedOnFirstExitLock( + host, new_endpoint, lambda: bool(session._pools)) + + executor.run_next() + + assert host.lock.pool_was_published_on_first_exit is True + assert future.result() is True + assert session._pools[host] is created_pools[0] + created_pools[0].shutdown.assert_not_called() + + def test_update_created_pools_removes_stale_pool_for_down_host_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + host.set_down() + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = stale_pool + + host.endpoint = DefaultEndPoint("127.0.0.2") + + futures = session.update_created_pools() + + assert len(futures) == 1 + executor.run_next() + assert session._pools == {} + stale_pool.shutdown.assert_called_once_with() + + def test_update_created_pools_replaces_pool_after_endpoint_change(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + session._pools[host] = stale_pool + created_pools = [] + + host.endpoint = DefaultEndPoint("127.0.0.2") + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + futures = session.update_created_pools() + + assert len(futures) == 1 + executor.run_next() + + assert created_pools[0].endpoint == host.endpoint + assert session._pools[host] is created_pools[0] + stale_pool.shutdown.assert_called_once_with() + + def test_removed_in_flight_pool_is_not_published(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_stale_host_pool_creation_does_not_publish_to_replacement_host(self): + host_id = uuid.uuid4() + stale_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, + host_id=host_id) + replacement_host = Host(DefaultEndPoint("127.0.0.2"), + SimpleConvictionPolicy, host_id=host_id) + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + cluster.metadata = Metadata() + cluster.metadata.add_or_return_host(replacement_host) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + created_pools[0].shutdown.assert_called_once_with() + + def test_stale_host_pool_creation_does_not_replace_same_endpoint_host(self): + endpoint = DefaultEndPoint("127.0.0.1") + stale_host = Host(endpoint, SimpleConvictionPolicy, + host_id=uuid.uuid4()) + replacement_host = Host(endpoint, SimpleConvictionPolicy, + host_id=uuid.uuid4()) + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + cluster.metadata = Metadata() + cluster.metadata.add_or_return_host(replacement_host) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + session._pools[replacement_host] = replacement_pool + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + executor.run_next() + + assert future.result() is False + assert session._pools[replacement_host] is replacement_pool + replacement_pool.shutdown.assert_not_called() + created_pools[0].shutdown.assert_called_once_with() + + def test_remove_pool_expected_host_mismatch_invalidates_stale_creation(self): + stale_host = self._make_host("127.0.0.1") + replacement_host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + created_pools = [] + session._pools[replacement_host] = replacement_pool + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + stale_host, is_host_addition=False) + + assert session.remove_pool( + stale_host, expected_host=stale_host) is None + + executor.run_next() + + assert future.result() is False + assert session._pools[replacement_host] is replacement_pool + replacement_pool.shutdown.assert_not_called() + created_pools[0].shutdown.assert_called_once_with() + + def test_remove_pool_finds_pool_after_host_endpoint_changes(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools[host] = pool + + host.endpoint = DefaultEndPoint("127.0.0.2") + + future = session.remove_pool(host, expected_host=host) + executor.run_next() + + assert session._pools == {} + pool.shutdown.assert_called_once_with() + assert future.done() + + def test_remove_pool_prefers_identity_after_endpoint_rewrite(self): + stale_host = self._make_host("127.0.0.1") + replacement_host = self._make_host("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session( + [replacement_host]) + stale_pool = self._make_pool(stale_host, HostDistance.LOCAL, session) + replacement_pool = self._make_pool( + replacement_host, HostDistance.LOCAL, session) + session._pools[stale_host] = stale_pool + + stale_host.endpoint = replacement_host.endpoint + session._pools[replacement_host] = replacement_pool + + assert stale_host == replacement_host + + session.remove_pool(stale_host, expected_host=stale_host) + executor.run_next() + + assert session._pools[replacement_host] is replacement_pool + stale_pool.shutdown.assert_called_once_with() + replacement_pool.shutdown.assert_not_called() + + def test_remove_pool_expected_endpoint_preserves_replacement_pool(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + stale_pool = self._make_pool(host, HostDistance.LOCAL, session) + + host.endpoint = DefaultEndPoint("127.0.0.2") + replacement_pool = self._make_pool(host, HostDistance.LOCAL, session) + session._pools = self._DuplicatePoolEntries([ + (host, stale_pool), + (host, replacement_pool), + ]) + + assert len(session._pools) == 2 + + session.remove_pool( + host, expected_host=host, expected_endpoint=old_endpoint) + executor.run_next() + + assert session._pools[host] is replacement_pool + stale_pool.shutdown.assert_called_once_with() + replacement_pool.shutdown.assert_not_called() + + def test_remove_pool_expected_endpoint_preserves_replacement_creation(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + host.endpoint = DefaultEndPoint("127.0.0.2") + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool( + host, is_host_addition=False) + + assert session.remove_pool( + host, expected_host=host, + expected_endpoint=old_endpoint) is None + + executor.run_next() + + assert future.result() is True + assert session._pools[host] is created_pools[0] + assert created_pools[0].endpoint == host.endpoint + created_pools[0].shutdown.assert_not_called() + + def test_stale_host_connection_cleanup_after_endpoint_flip_back_preserves_current_pool(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + replacement_endpoint = DefaultEndPoint("127.0.0.2") + cluster, session, executor = self._make_cluster_and_session([host]) + + host.endpoint = replacement_endpoint + host.endpoint = old_endpoint + current_pool = self._make_pool( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + session._pools[host] = current_pool + + stale_pool = HostConnection.__new__(HostConnection) + stale_pool.host = host + stale_pool.endpoint = old_endpoint + stale_pool._session = session + + stale_pool._remove_stale_pool(old_endpoint) + + assert session._get_pool_by_host_identity(host) is current_pool + current_pool.shutdown.assert_not_called() + + def test_add_or_renew_pool_tags_pool_with_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + pool = self._make_pool(host, distance, pool_session, endpoint) + created_pools.append(pool) + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool) as host_connection: + future = session.add_or_renew_pool( + host, is_host_addition=False) + + executor.run_next() + + assert future.result() is False + host_connection.assert_called_once_with( + host, HostDistance.LOCAL, session, endpoint=old_endpoint) + assert created_pools[0].endpoint == old_endpoint + created_pools[0].shutdown.assert_called_once_with() + + def test_add_or_renew_pool_auth_failure_reports_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + def fail_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + raise AuthenticationFailed("failed") + + with patch("cassandra.cluster.HostConnection", side_effect=fail_pool): + future = session.add_or_renew_pool(host, is_host_addition=False) + executor.run_next() + + assert future.result() is False + args, kwargs = cluster.signal_connection_failure.call_args + conn_exc = args[1] + assert isinstance(conn_exc, ConnectionException) + assert conn_exc.endpoint == old_endpoint + assert isinstance(conn_exc.__cause__, AuthenticationFailed) + assert args == (host, conn_exc, False) + assert kwargs == {"expected_endpoint": old_endpoint} + + def test_add_or_renew_pool_failure_reports_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + pool_error = RuntimeError("failed") + + def fail_pool(host, distance, pool_session, endpoint=None): + host.endpoint = DefaultEndPoint("127.0.0.2") + raise pool_error + + with patch("cassandra.cluster.HostConnection", side_effect=fail_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + executor.run_next() + + assert future.result() is False + args, kwargs = cluster.signal_connection_failure.call_args + conn_exc = args[1] + assert conn_exc is pool_error + assert conn_exc.endpoint == old_endpoint + assert args == (host, conn_exc, True) + assert kwargs == { + "expect_host_to_be_down": True, + "expected_endpoint": old_endpoint + } + + def test_removed_in_flight_pool_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + with patch("cassandra.cluster.HostConnection", + side_effect=ConnectionException("failed")): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + cluster.signal_connection_failure.assert_not_called() + + def test_removed_in_flight_pool_auth_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.signal_connection_failure = Mock() + + with patch("cassandra.cluster.HostConnection", + side_effect=AuthenticationFailed("failed")): + future = session.add_or_renew_pool(host, is_host_addition=True) + session.remove_pool(host) + + executor.run_next() + + assert future.result() is False + cluster.signal_connection_failure.assert_not_called() + + def test_stale_keyspace_failure_does_not_signal_down(self): + host = self._make_host("127.0.0.1") + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.connect_timeout = 1 + cluster.on_down = Mock() + session.keyspace = "ks" + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + pool._keyspace = None + + def set_keyspace(keyspace, callback): + session.remove_pool(host) + callback(pool, [Exception("failed")]) + + pool._set_keyspace_for_all_conns.side_effect = set_keyspace + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + cluster.on_down.assert_not_called() + + def test_keyspace_failure_signals_down_for_creation_endpoint(self): + host = self._make_host("127.0.0.1") + old_endpoint = host.endpoint + cluster, session, executor = self._make_cluster_and_session([host]) + cluster.connect_timeout = 1 + cluster.on_down = Mock() + session.keyspace = "ks" + created_pools = [] + + def make_pool(host, distance, pool_session, endpoint=None): + pool = self._make_pool(host, distance, pool_session, endpoint) + pool._keyspace = None + created_pools.append(pool) + + def set_keyspace(keyspace, callback): + host.endpoint = DefaultEndPoint("127.0.0.2") + callback(pool, [Exception("failed")]) + + pool._set_keyspace_for_all_conns.side_effect = set_keyspace + return pool + + with patch("cassandra.cluster.HostConnection", side_effect=make_pool): + future = session.add_or_renew_pool(host, is_host_addition=True) + + executor.run_next() + + assert future.result() is False + assert session._pools == {} + cluster.on_down.assert_called_once_with( + host, True, expected_endpoint=old_endpoint) + created_pools[0].shutdown.assert_called_once_with() + + +class HostStateRaceTest(unittest.TestCase): + + class _EndpointSwapOnFirstExitLock(object): + + def __init__(self, host, new_endpoint): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._exits = 0 + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self._host.endpoint = self._new_endpoint + + + @staticmethod + def _make_host(): + return Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + @staticmethod + def _make_cluster(session=None, listener=None): + cluster = Cluster.__new__(Cluster) + cluster.is_shutdown = False + cluster.profile_manager = Mock() + cluster.control_connection = Mock() + cluster.sessions = set([session] if session else []) + cluster._listeners = set([listener] if listener else []) + cluster._listener_lock = Lock() + cluster.executor = _ImmediateExecutor() + cluster._start_reconnector = Mock() + cluster._discount_down_events = False + return cluster + + @staticmethod + def _make_session_with_pool(host, pool): + session = Session.__new__(Session) + session._lock = Lock() + session._pools = {host: pool} + session.cluster = Mock() + session.cluster.metadata.all_hosts.return_value = [] + session.submit = _ImmediateExecutor().submit + return session + + def test_discount_down_event_does_not_hold_host_lock_while_scanning_pools(self): + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = old_endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + # Mutating the endpoint after insertion forces an identity scan under + # session._lock; the down path must not hold host.lock during that scan. + host.endpoint = DefaultEndPoint("127.0.0.2") + entered_distance = Event() + release_distance = Event() + blocked_on_host = [] + thread_errors = [] + + def distance(_host): + entered_distance.set() + release_distance.wait(1) + return HostDistance.LOCAL + + def hold_session_then_try_host(): + with session._lock: + entered_distance.wait(1) + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + release_distance.set() + + def run_on_down(): + try: + Cluster.on_down(cluster, host, is_host_addition=False) + except Exception as exc: + thread_errors.append(exc) + + cluster.profile_manager.distance.side_effect = distance + worker = Thread(target=hold_session_then_try_host) + runner = Thread(target=run_on_down) + worker.start() + runner.start() + worker.join(2) + runner.join(2) + + assert not thread_errors + assert not runner.is_alive() + assert blocked_on_host == [False] + + def test_signal_connection_failure_does_not_hold_host_lock_while_scanning_pools(self): + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = old_endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + # Mutating the endpoint after insertion forces an identity scan under + # session._lock. signal_connection_failure() must not keep host.lock + # held while it delegates to the down path that performs that scan. + host.endpoint = DefaultEndPoint("127.0.0.2") + entered_distance = Event() + release_distance = Event() + blocked_on_host = [] + thread_errors = [] + + def distance(_host): + entered_distance.set() + release_distance.wait(1) + return HostDistance.LOCAL + + def hold_session_then_try_host(): + with session._lock: + entered_distance.wait(1) + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + release_distance.set() + + def run_signal_connection_failure(): + try: + Cluster.signal_connection_failure( + cluster, host, ConnectionException("failed"), + is_host_addition=False) + except Exception as exc: + thread_errors.append(exc) + + cluster.profile_manager.distance.side_effect = distance + worker = Thread(target=hold_session_then_try_host) + runner = Thread(target=run_signal_connection_failure) + worker.start() + runner.start() + worker.join(2) + runner.join(2) + + assert not thread_errors + assert not runner.is_alive() + assert blocked_on_host == [False] + + def test_discount_down_event_applies_to_current_expected_endpoint(self): + host = self._make_host() + host.set_up() + endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=endpoint) + + assert host.is_up + cluster.on_down_potentially_blocking.assert_not_called() + + def test_forced_down_is_not_discounted_by_connected_pool(self): + host = self._make_host() + host.set_up() + endpoint = host.endpoint + pool = Mock() + pool.host = host + pool.endpoint = endpoint + pool.open_count = 1 + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + cluster._discount_down_events = True + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + cluster.on_down_potentially_blocking = Mock(return_value=None) + session.cluster = cluster + + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True, expected_endpoint=endpoint) + + assert not host.is_up + cluster.on_down_potentially_blocking.assert_called_once_with( + host, False, ANY, endpoint, False, False) + + @staticmethod + def _state(cluster, host): + return cluster._get_host_liveness_state(host) + + @classmethod + def _reserve_down_handling(cls, cluster, host): + with host.lock: + state = cls._state(cluster, host) + state.epoch += 1 + state.pending_up_epoch = None + host.set_down() + state.down_epoch = state.epoch + return state.down_epoch + + def test_stale_down_handling_is_ignored_after_host_is_up(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + with host.lock: + state = self._state(cluster, host) + state.down_epoch = None + state.epoch += 1 + host.set_up() + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_stale_generic_down_handling_uses_original_endpoint_after_endpoint_swap(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down(cluster, host, is_host_addition=False) + host.endpoint = new_endpoint + + executor.run_next() + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_unreserved_down_handling_is_ignored_during_host_up_handling(self): + session = Mock() + cluster = self._make_cluster(session=session) + host = self._make_host() + host.set_down() + self._state(cluster, host).up_epoch = 0 + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False) + + session.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_reserved_down_handling_after_endpoint_swap_only_removes_stale_pool(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + down_epoch = self._reserve_down_handling(cluster, host) + + host.endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch, + expected_endpoint=old_endpoint) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_reserved_down_handling_after_endpoint_swap_removes_stale_pool(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host.endpoint = old_endpoint + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + host.lock = self._EndpointSwapOnFirstExitLock(host, new_endpoint) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_expected_endpoint_down_listener_is_not_called_under_host_lock(self): + session = Mock() + cluster = self._make_cluster(session=session) + host = self._make_host() + host.set_up() + expected_endpoint = host.endpoint + down_epoch = self._reserve_down_handling(cluster, host) + blocked_on_host = [] + + class Listener(object): + + def on_down(self, _host): + def try_host_lock(): + acquired = host.lock.acquire(timeout=0.2) + blocked_on_host.append(not acquired) + if acquired: + host.lock.release() + + worker = Thread(target=try_host_lock) + worker.start() + worker.join(1) + + cluster._listeners = set([Listener()]) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch, + expected_endpoint=expected_endpoint) + + self.assertEqual(blocked_on_host, [False]) + + def test_endpoint_match_preserves_endpoint_specific_identity(self): + proxy_endpoint = SniEndPoint("proxy.example.com", "node-a", port=9042) + other_proxy_endpoint = SniEndPoint("proxy.example.com", "node-b", port=9042) + assert not Cluster._endpoints_match(proxy_endpoint, other_proxy_endpoint) + + old_endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + def test_auth_failure_quarantine_preserves_endpoint_specific_identity(self): + host = self._make_host() + host.endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + + Cluster._set_non_retryable_auth_failure(host, True) + host.endpoint = ClientRoutesEndPoint( + uuid.uuid4(), Mock(), "127.0.0.1", original_port=9042) + + assert not Cluster._has_non_retryable_auth_failure(host) + + def test_noop_down_during_up_handling_does_not_supersede_up(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + up_epoch = state.up_epoch + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert state.epoch == up_epoch + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + assert host.is_up + assert state.up_epoch is None + + def test_on_up_waits_for_all_pool_futures_when_one_is_already_done(self): + completed_pool_future = Future() + completed_pool_future.set_result(True) + pending_pool_future = Future() + completed_session = Mock() + completed_session.add_or_renew_pool.return_value = completed_pool_future + pending_session = Mock() + pending_session.add_or_renew_pool.return_value = pending_pool_future + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.sessions = [completed_session, pending_session] + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + + assert not host.is_up + listener.on_up.assert_not_called() + + pending_pool_future.set_result(False) + + assert not host.is_up + listener.on_up.assert_not_called() + cluster._start_reconnector.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=host.endpoint) + + def test_newer_forced_down_during_up_handling_is_preserved(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + assert session.remove_pool.call_count == 1 + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + assert state.epoch > first_up_epoch + assert state.up_epoch == first_up_epoch + assert not host.is_up + + pool_future.set_result(True) + + listener.on_up.assert_not_called() + assert session.remove_pool.call_count == 1 + assert not host.is_up + assert state.up_epoch is None + + def test_stale_failed_up_callback_does_not_cleanup_newer_down(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + pool_future.set_exception(RuntimeError("pool failed after newer down")) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + listener.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + + def test_failed_up_callback_rechecks_before_cleanup(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + + def force_down_before_cleanup(message, *args, **kwargs): + if message.startswith("Connection pool could not be created"): + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True) + + with patch("cassandra.cluster.log.debug", side_effect=force_down_before_cleanup): + pool_future.set_result(False) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + assert session.remove_pool.call_count == 1 + listener.on_up.assert_not_called() + assert not host.is_up + assert state.up_epoch is None + assert state.down_epoch is None + + def test_failed_up_callback_after_endpoint_swap_does_not_signal_down(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + old_endpoint = host.endpoint + + host.endpoint = DefaultEndPoint("127.0.0.2") + pool_future.set_exception(RuntimeError("pool failed after endpoint swap")) + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert session.remove_pool.call_count == 2 + session.remove_pool.assert_any_call( + host, expected_host=host, expected_endpoint=old_endpoint) + assert not host.is_up + assert state.up_epoch is None + + def test_forced_down_during_up_handling_is_not_hidden_by_reconnector(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_down() + old_reconnector = Mock() + host._reconnection_handler = old_reconnector + original_get_reconnector = Cluster._get_reconnector_for_current_up_handling + + def force_down_before_reconnector_is_cleared(h, up_epoch, **kwargs): + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + return original_get_reconnector(cluster, h, up_epoch, **kwargs) + + cluster._get_reconnector_for_current_up_handling = Mock( + side_effect=force_down_before_reconnector_is_cleared) + + Cluster.on_up(cluster, host) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + old_reconnector.cancel.assert_called_once_with() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_forced_down_while_reconnecting_runs_new_down_handling(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_down() + host._reconnection_handler = Mock() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + assert self._state(cluster, host).down_epoch is None + + def test_newer_down_before_up_side_effects_suppresses_stale_up(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + original_superseded_check = Cluster._up_handling_was_superseded + checked = [] + + def force_down_before_first_superseded_check(h, up_epoch): + if not checked: + checked.append(True) + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + return original_superseded_check(cluster, h, up_epoch) + + cluster._up_handling_was_superseded = Mock( + side_effect=force_down_before_first_superseded_check) + + Cluster.on_up(cluster, host) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_up_during_down_superseding_in_flight_up_is_replayed(self): + first_pool_future = Future() + second_pool_future = Future() + session = Mock() + session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + assert state.epoch > first_up_epoch + assert state.up_epoch == first_up_epoch + assert state.pending_up_epoch == state.epoch + + first_pool_future.set_result(True) + + assert session.add_or_renew_pool.call_count == 2 + assert state.up_epoch == state.epoch + assert state.pending_up_epoch is None + listener.on_up.assert_not_called() + + second_pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + assert host.is_up + assert state.up_epoch is None + + def test_stale_superseded_up_cleanup_does_not_run_after_newer_down(self): + first_pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = first_pool_future + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + Cluster.on_up(cluster, host) + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + cleanup_calls = [] + + def signal_up_during_stale_cleanup(h, **kwargs): + cleanup_calls.append(h) + return None + + session.remove_pool.side_effect = signal_up_during_stale_cleanup + + first_pool_future.set_result(True) + + assert cleanup_calls == [] + assert session.add_or_renew_pool.call_count == 1 + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + assert not host.is_up + + def test_sync_up_failure_replays_queued_up(self): + session = Mock() + session.add_or_renew_pool.return_value = None + cluster = self._make_cluster(session=session) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + fail_first_on_up = [True] + + def queue_up_then_fail(h): + if not fail_first_on_up[0]: + return + fail_first_on_up[0] = False + Cluster.on_down( + cluster, h, is_host_addition=False, expect_host_to_be_down=True) + Cluster.on_up(cluster, h) + state = self._state(cluster, h) + assert state.pending_up_epoch == state.epoch + raise RuntimeError("up failed") + + cluster.profile_manager.on_up.side_effect = queue_up_then_fail + + with pytest.raises(RuntimeError): + Cluster.on_up(cluster, host) + + assert cluster.profile_manager.on_up.call_count == 2 + cluster.control_connection.on_up.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with( + host, is_host_addition=False, + allow_retry_after_auth_failure=True, + expected_endpoint=host.endpoint) + assert host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_old_up_callback_does_not_clear_replayed_up_handling(self): + first_pool_future = Future() + second_pool_future = Future() + session = Mock() + session.add_or_renew_pool.side_effect = [first_pool_future, second_pool_future] + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + + listener.on_up.side_effect = lambda h: Cluster.on_down( + cluster, h, is_host_addition=False) + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_up(cluster, host) + state = self._state(cluster, host) + first_up_epoch = state.up_epoch + + first_pool_future.set_result(True) + + assert session.add_or_renew_pool.call_count == 2 + assert not host.is_up + assert state.up_epoch != first_up_epoch + + listener.on_up.side_effect = None + second_pool_future.set_result(True) + + assert listener.on_up.call_count == 2 + assert host.is_up + assert state.up_epoch is None + + def test_stale_reconnector_success_does_not_clear_newer_reconnector(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + cluster = self._make_cluster() + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.endpoint = new_endpoint + new_reconnector = Mock() + host._reconnection_handler = new_reconnector + connection = Mock(endpoint=old_endpoint) + handler = _HostReconnectionHandler( + host, Mock(return_value=connection), False, Mock(), cluster.on_up, Mock(), + iter([0]), host.get_and_set_reconnection_handler, new_handler=None) + + handler.run() + + assert host._reconnection_handler is new_reconnector + assert not host.is_up + + def test_superseded_up_cleanup_preserves_replacement_host_pool(self): + stale_host = self._make_host() + replacement_host = self._make_host() + replacement_pool = Mock(host=replacement_host) + session = self._make_session_with_pool(replacement_host, replacement_pool) + cluster = self._make_cluster(session=session) + + assert stale_host == replacement_host + assert stale_host.host_id != replacement_host.host_id + + cluster._cleanup_superseded_up_handling(stale_host) + + assert session._pools.get(replacement_host) is replacement_pool + replacement_pool.shutdown.assert_not_called() + + def test_superseded_up_cleanup_removes_matching_host_pool(self): + host = self._make_host() + pool = Mock(host=host) + session = self._make_session_with_pool(host, pool) + cluster = self._make_cluster(session=session) + + cluster._cleanup_superseded_up_handling(host) + + assert session._pools == {} + pool.shutdown.assert_called_once_with() + + def test_down_during_up_listener_is_handled(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_down() + listener.on_up.side_effect = lambda h: Cluster.on_down( + cluster, h, is_host_addition=False) + + Cluster.on_up(cluster, host) + assert self._state(cluster, host).up_epoch is not None + + pool_future.set_result(True) + + listener.on_up.assert_called_once_with(host) + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).down_epoch is None + + def test_current_down_handling_still_removes_pools_and_reconnects(self): + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + assert self._state(cluster, host).down_epoch is None + + def test_remove_during_down_listener_does_not_start_reconnector(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.metadata = Mock() + cluster.metadata.remove_host.return_value = True + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + listener.on_down.side_effect = lambda h: Cluster.remove_host(cluster, h) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.metadata.remove_host.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + listener.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_queued_down_handling_after_remove_does_not_start_reconnector(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.metadata = Mock() + cluster.metadata.remove_host.return_value = True + host = self._make_host() + host.set_up() + + Cluster.on_down(cluster, host, is_host_addition=False) + assert len(executor.submissions) == 1 + + Cluster.remove_host(cluster, host) + executor.run_next() + + cluster.metadata.remove_host.assert_called_once_with(host) + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_not_called() + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).down_epoch is None + + def test_start_reconnector_rechecks_down_epoch_before_installing_handler(self): + cluster = self._make_cluster() + cluster.reconnection_policy = Mock() + cluster.reconnection_policy.new_schedule.return_value = iter([0]) + cluster.scheduler = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + state = self._state(cluster, host) + state.down_epoch = 1 + + def clear_down_epoch(h): + with h.lock: + self._state(cluster, h).down_epoch = None + return Mock() + + cluster._make_connection_factory = Mock(side_effect=clear_down_epoch) + Cluster._start_reconnector( + cluster, host, is_host_addition=False, expected_down_epoch=1) + + assert host._reconnection_handler is None + cluster.scheduler.schedule.assert_not_called() + + def test_on_up_queues_after_down_is_submitted_before_worker_runs(self): + executor = _QueuedExecutor() + session = Mock() + session.add_or_renew_pool.return_value = None + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + Cluster.on_down(cluster, host, is_host_addition=False) + state = self._state(cluster, host) + + assert len(executor.submissions) == 1 + assert state.down_epoch == state.epoch + + Cluster.on_up(cluster, host) + + assert state.pending_up_epoch == state.epoch + assert state.up_epoch is None + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + + executor.run_next() + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.on_down.assert_called_once_with( + host, expected_endpoint=host.endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_on_up_stays_queued_after_endpoint_update_before_down_worker_runs(self): + executor = _QueuedExecutor() + session = Mock() + session.add_or_renew_pool.return_value = None + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + state = self._state(cluster, host) + old_endpoint = host.endpoint + + host.endpoint = DefaultEndPoint("127.0.0.2") + + assert self._state(cluster, host) is state + + Cluster.on_up(cluster, host) + + assert state.down_epoch == state.epoch + assert state.pending_up_epoch == state.epoch + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + + executor.run_next() + + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + session.on_down.assert_called_once_with( + host, expected_endpoint=old_endpoint) + listener.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_down_for_replacement_endpoint_during_pending_old_down_is_handled(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=old_endpoint) + state = self._state(cluster, host) + assert state.down_epoch == state.epoch + + host.endpoint = new_endpoint + Cluster.on_up(cluster, host) + assert state.pending_up_epoch == state.epoch + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=new_endpoint) + + executor.run_next() + + assert len(executor.submissions) == 1 + executor.run_next() + + session.on_down.assert_any_call( + host, expected_endpoint=old_endpoint) + session.on_down.assert_any_call( + host, expected_endpoint=new_endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=new_endpoint) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_forced_down_for_replacement_endpoint_during_old_down_is_handled(self): + executor = _QueuedExecutor() + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint("127.0.0.2") + + Cluster.on_down( + cluster, host, is_host_addition=False, + expected_endpoint=old_endpoint) + state = self._state(cluster, host) + assert state.down_epoch == state.epoch + + host.endpoint = new_endpoint + Cluster.on_down( + cluster, host, is_host_addition=False, + expect_host_to_be_down=True, expected_endpoint=new_endpoint) + + executor.run_next() + + assert len(executor.submissions) == 1 + executor.run_next() + + session.on_down.assert_any_call( + host, expected_endpoint=old_endpoint) + session.on_down.assert_any_call( + host, expected_endpoint=new_endpoint) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=new_endpoint) + assert not host.is_up + assert state.down_epoch is None + assert state.up_epoch is None + assert state.pending_up_epoch is None + + def test_later_down_before_worker_runs_does_not_skip_pool_cleanup(self): + executor = _QueuedExecutor() + host = self._make_host() + host.set_up() + pool = Mock() + pool.host = host + pool.endpoint = host.endpoint + session = self._make_session_with_pool(host, pool) + session._lock = RLock() + cluster = self._make_cluster(session=session) + cluster.executor = executor + cluster.metadata = Mock() + cluster.metadata.all_hosts.return_value = [host] + session.cluster = cluster + session._profile_manager = cluster.profile_manager + + Cluster.on_down(cluster, host, is_host_addition=False) + Cluster.on_up(cluster, host) + Cluster.on_down(cluster, host, is_host_addition=False) + + future = executor.run_next() + + assert future.exception() is None + assert session._pools == {} + pool.shutdown.assert_called_once_with() + + def test_up_signal_waits_until_submitted_down_handling_finishes(self): + executor = _QueuedExecutor() + events = [] + session = Mock() + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster.executor = executor + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + + cluster.profile_manager.on_down.side_effect = lambda h: events.append("profile_down") + cluster.control_connection.on_down.side_effect = lambda h: events.append("control_down") + session.on_down.side_effect = lambda h, **kwargs: events.append("session_down") + listener.on_down.side_effect = lambda h: events.append("listener_down") + cluster._start_reconnector.side_effect = lambda h, is_host_addition, **kwargs: events.append("reconnector") + session.remove_pool.side_effect = lambda h, **kwargs: events.append("remove_pool") + cluster.profile_manager.on_up.side_effect = lambda h: events.append("profile_up") + cluster.control_connection.on_up.side_effect = lambda h: events.append("control_up") + session.add_or_renew_pool.side_effect = lambda h, is_host_addition, **kwargs: events.append("add_pool") + + Cluster.on_down(cluster, host, is_host_addition=False) + Cluster.on_up(cluster, host) + + assert events == [] + + executor.run_next() + + assert events == [ + "profile_down", + "control_down", + "session_down", + "listener_down", + "reconnector", + "remove_pool", + "profile_up", + "control_up", + "add_pool", + ] + assert host.is_up + + def test_on_up_queues_when_down_handling_is_active(self): + cluster = self._make_cluster() + cluster._prepare_all_queries = Mock() + host = self._make_host() + host.set_down() + down_epoch = self._reserve_down_handling(cluster, host) + + Cluster.on_up(cluster, host) + + cluster._prepare_all_queries.assert_not_called() + state = self._state(cluster, host) + assert state.down_epoch == down_epoch + assert state.up_epoch is None + assert state.pending_up_epoch == state.epoch + + def test_on_up_during_down_handling_is_replayed_for_ignored_host(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + listener.on_down.side_effect = lambda h: Cluster.on_up(cluster, h) + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + listener.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with(host, False, expected_down_epoch=ANY) + cluster.profile_manager.on_up.assert_called_once_with(host) + cluster.control_connection.on_up.assert_called_once_with(host) + assert host.is_up + assert self._state(cluster, host).down_epoch is None + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_later_down_during_down_handling_invalidates_queued_up(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + def queue_up_then_down(h): + Cluster.on_up(cluster, h) + assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch + Cluster.on_down(cluster, h, is_host_addition=False) + + listener.on_down.side_effect = queue_up_then_down + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).down_epoch is None + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_remove_during_down_handling_invalidates_queued_up(self): + listener = Mock() + cluster = self._make_cluster(listener=listener) + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_up() + down_epoch = self._reserve_down_handling(cluster, host) + + def queue_up_then_remove(h): + Cluster.on_up(cluster, h) + assert self._state(cluster, h).pending_up_epoch == self._state(cluster, h).epoch + Cluster.on_remove(cluster, h) + + listener.on_down.side_effect = queue_up_then_remove + + Cluster.on_down_potentially_blocking( + cluster, host, is_host_addition=False, down_epoch=down_epoch) + + cluster.profile_manager.on_remove.assert_called_once_with(host) + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert self._state(cluster, host).up_epoch is None + assert self._state(cluster, host).pending_up_epoch is None + + def test_stale_queued_up_replay_is_ignored_after_newer_down_event(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 1 + state.pending_up_epoch = 0 + + cluster._handle_pending_node_up(host, 0) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + assert state.up_epoch is None + + def test_stale_queued_up_replay_preserves_newer_pending_up(self): + cluster = self._make_cluster() + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.down_epoch = 2 + state.pending_up_epoch = 2 + + cluster._handle_pending_node_up(host, 1) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert state.pending_up_epoch == 2 + assert state.up_epoch is None + + def test_down_after_pending_up_pop_invalidates_replay(self): + cluster = self._make_cluster() + cluster.profile_manager.distance.return_value = HostDistance.IGNORED + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.pending_up_epoch = 2 + + with host.lock: + pending_up_epoch = cluster._pop_pending_node_up_if_ready(host) + + assert pending_up_epoch == (2, None) + assert state.pending_up_epoch == 2 + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert state.epoch == 3 + assert state.pending_up_epoch is None + + cluster._handle_pending_node_up(host, pending_up_epoch) + + cluster.profile_manager.on_up.assert_not_called() + cluster.control_connection.on_up.assert_not_called() + assert not host.is_up + + def test_down_for_down_host_with_pending_up_only_invalidates_pending_up(self): + cluster = self._make_cluster() + cluster.executor = Mock() + host = self._make_host() + host.set_down() + state = self._state(cluster, host) + state.epoch = 2 + state.down_epoch = 2 + state.pending_up_epoch = 2 + + Cluster.on_down(cluster, host, is_host_addition=False) + + cluster.executor.submit.assert_not_called() + assert state.down_epoch == 2 + assert state.pending_up_epoch is None + assert state.epoch == 3 + + def test_auth_failure_for_unknown_host_does_not_start_down_handling(self): + cluster = self._make_cluster() + host = self._make_host() + + is_down = cluster.signal_connection_failure( + host, AuthenticationFailed("bad credentials"), is_host_addition=False) + + assert is_down + assert host.is_up is None + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_wrapped_auth_failure_for_unknown_host_does_not_start_down_handling(self): + cluster = self._make_cluster() + host = self._make_host() + auth_exc = AuthenticationFailed("bad credentials") + conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc.__cause__ = auth_exc + + is_down = cluster.signal_connection_failure( + host, conn_exc, is_host_addition=False) + + assert is_down + assert host.is_up is None + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_connection_failure_after_endpoint_swap_is_ignored(self): + cluster = self._make_cluster() + host = self._make_host() + host.set_up() + old_endpoint = host.endpoint + host.endpoint = DefaultEndPoint("127.0.0.2") + + is_down = cluster.signal_connection_failure( + host, ConnectionException("failed", endpoint=old_endpoint), + is_host_addition=False, expected_endpoint=old_endpoint) + + assert not is_down + assert host.is_up + cluster.profile_manager.on_down.assert_not_called() + cluster.control_connection.on_down.assert_not_called() + cluster._start_reconnector.assert_not_called() + + def test_auth_failed_up_pool_for_unknown_host_rolls_back_without_reconnector(self): + pool_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pool_future + listener = Mock() + cluster = self._make_cluster(session=session, listener=listener) + cluster._prepare_all_queries = Mock() + cluster.profile_manager.distance.return_value = HostDistance.LOCAL + host = self._make_host() + auth_exc = AuthenticationFailed("bad credentials") + conn_exc = ConnectionException(str(auth_exc), endpoint=host) + conn_exc.__cause__ = auth_exc + + Cluster.on_up(cluster, host) + is_down = cluster.signal_connection_failure( + host, conn_exc, is_host_addition=False) + + assert is_down + assert host.is_up is None + session.remove_pool.reset_mock() + + pool_future.set_result(False) + + assert host.is_up is None + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=host.endpoint) + listener.on_up.assert_not_called() + cluster._start_reconnector.assert_not_called() + assert self._state(cluster, host).up_epoch is None + + def test_real_down_for_unknown_host_marks_host_down(self): + cluster = self._make_cluster() + host = self._make_host() + + Cluster.on_down(cluster, host, is_host_addition=False) + + assert host.is_up is False + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + + def test_expected_down_for_unknown_host_marks_host_down(self): + cluster = self._make_cluster() + host = self._make_host() + + Cluster.on_down( + cluster, host, is_host_addition=False, expect_host_to_be_down=True) + + assert host.is_up is False + cluster.profile_manager.on_down.assert_called_once_with(host) + cluster.control_connection.on_down.assert_called_once_with(host) + cluster._start_reconnector.assert_called_once_with( + host, False, expected_down_epoch=ANY, + expected_endpoint=host.endpoint) + class SessionTest(unittest.TestCase): def setUp(self): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..560d5290bf 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -13,15 +13,20 @@ # limitations under the License. import unittest +import uuid -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import (Cluster, ControlConnection, _Scheduler, ProfileManager, + EXEC_PROFILE_DEFAULT, ExecutionProfile) +from cassandra.metadata import Metadata from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import (ClientRoutesEndPoint, ConnectionException, + EndPoint, DefaultEndPoint, + DefaultEndPointFactory, SniEndPoint) from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -111,6 +116,7 @@ def __init__(self): self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(RoundRobinPolicy()) self.endpoint_factory = DefaultEndPointFactory().configure(self) self.ssl_options = None + self.down_expected_endpoint = None def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None): host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack, host_id=host_id) @@ -121,11 +127,17 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, def remove_host(self, host): pass - def on_up(self, host): + def _endpoints_match(self, endpoint, expected_endpoint): + return Cluster._endpoints_match(endpoint, expected_endpoint) + + def on_up(self, host, expected_endpoint=None): pass - def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + def on_down(self, host, is_host_addition, expect_host_to_be_down=False, + expected_endpoint=None, profile_manager_already_notified=False, + control_connection_already_notified=False): self.down_host = host + self.down_expected_endpoint = expected_endpoint def _node_meta_results(local_results, peer_results): @@ -143,6 +155,45 @@ def _node_meta_results(local_results, peer_results): return peer_response, local_response +class RunOnResultFuture(Future): + + def __init__(self, fn, args, kwargs): + Future.__init__(self) + self._fn = fn + self._args = args + self._kwargs = kwargs + + def _run_once(self): + if self.done(): + return + try: + self.set_result(self._fn(*self._args, **self._kwargs)) + except Exception as exc: + self.set_exception(exc) + + def result(self, timeout=None): + self._run_once() + return Future.result(self, timeout) + + +class RunOnResultExecutor(object): + + def __init__(self): + self.futures = [] + + def submit(self, fn, *args, **kwargs): + future = RunOnResultFuture(fn, args, kwargs) + self.futures.append(future) + return future + + def run_all(self): + for future in tuple(self.futures): + future._run_once() + + def shutdown(self): + self.run_all() + + class MockConnection(object): is_defunct = False @@ -378,6 +429,177 @@ def test_change_ip(self): assert 3 == len(self.cluster.metadata.all_hosts()) + def test_change_client_route_endpoint_when_only_port_changes(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + self.cluster.endpoint_factory = Mock() + self.cluster.endpoint_factory.create.return_value = new_endpoint + self.cluster.on_down = Mock() + self.cluster.on_up = Mock() + self.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "rpc_port", "peer", "data_center", "rack", "host_id"], + [["127.0.0.1", 9142, "127.0.0.1", "dc1", "rack1", host_id]])) + + self.control_connection._refresh_node_list_and_token_map( + self.connection, preloaded_results=preloaded_results) + + assert Cluster._endpoints_match(host.endpoint, new_endpoint) + self.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expect_host_to_be_down=True, + expected_endpoint=old_endpoint, + profile_manager_already_notified=True, + control_connection_already_notified=True) + self.cluster.on_up.assert_called_once_with( + host, expected_endpoint=new_endpoint) + + def test_endpoint_change_preserves_live_policy_hosts_when_down_handler_runs_late(self): + old_endpoint = DefaultEndPoint("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + host_id = uuid.uuid4() + policy = RoundRobinPolicy() + cluster = Cluster(contact_points=[], load_balancing_policy=policy) + original_executor = cluster.executor + cluster.executor = RunOnResultExecutor() + cluster._start_reconnector = Mock() + + try: + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.endpoint_factory = Mock() + cluster.endpoint_factory.create.return_value = new_endpoint + cluster.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "peer", "data_center", "rack", "host_id"], + [["127.0.0.2", "127.0.0.2", "dc1", "rack1", host_id]])) + + cluster.control_connection._refresh_node_list_and_token_map( + Mock(), preloaded_results=preloaded_results) + cluster.executor.run_all() + + assert list(policy._live_hosts) == [host] + assert host in policy._live_hosts + finally: + cluster.executor = original_executor + cluster.shutdown() + + def test_endpoint_change_reconnects_control_connection_when_down_handler_runs_late(self): + old_endpoint = DefaultEndPoint("127.0.0.1") + new_endpoint = DefaultEndPoint("127.0.0.2") + host_id = uuid.uuid4() + cluster = Cluster(contact_points=[]) + original_executor = cluster.executor + cluster.executor = RunOnResultExecutor() + cluster._start_reconnector = Mock() + cluster.control_connection._connection = Mock(endpoint=old_endpoint) + cluster.control_connection.reconnect = Mock() + + try: + host = Host( + old_endpoint, SimpleConvictionPolicy, + datacenter="dc1", rack="rack1", host_id=host_id) + host.set_up() + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.endpoint_factory = Mock() + cluster.endpoint_factory.create.return_value = new_endpoint + cluster.control_connection._token_meta_enabled = False + + preloaded_results = _node_meta_results( + local_results=([], []), + peer_results=( + ["rpc_address", "peer", "data_center", "rack", "host_id"], + [["127.0.0.2", "127.0.0.2", "dc1", "rack1", host_id]])) + + cluster.control_connection._refresh_node_list_and_token_map( + Mock(), preloaded_results=preloaded_results) + cluster.executor.run_all() + + cluster.control_connection.reconnect.assert_called_once_with() + finally: + cluster.control_connection._connection = None + cluster.executor = original_executor + cluster.shutdown() + + def test_stale_control_connection_failure_is_endpoint_fenced(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=host_id) + host.set_up() + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + host.endpoint = new_endpoint + self.cluster.metadata.update_host(host, old_endpoint) + assert self.cluster.metadata.get_host(old_endpoint) is host + + self.connection.endpoint = old_endpoint + self.connection.is_defunct = True + self.connection.last_error = ConnectionException( + "stale control connection failed", endpoint=old_endpoint) + self.cluster.signal_connection_failure = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, self.connection.last_error, is_host_addition=False, + expected_endpoint=old_endpoint) + + def test_stale_control_connection_failure_reconnects_when_cluster_ignores_signal(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), "127.0.0.1", original_port=9142) + + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=host_id) + host.set_up() + self.cluster.metadata = Metadata() + self.cluster.metadata.add_or_return_host(host) + host.endpoint = new_endpoint + self.cluster.metadata.update_host(host, old_endpoint) + + self.connection.endpoint = old_endpoint + self.connection.is_defunct = True + self.connection.last_error = ConnectionException( + "stale control connection failed", endpoint=old_endpoint) + self.cluster.signal_connection_failure = Mock(return_value=False) + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, self.connection.last_error, is_host_addition=False, + expected_endpoint=old_endpoint) + self.control_connection.reconnect.assert_called_once_with() def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): """ @@ -495,7 +717,8 @@ def test_handle_status_change(self): self.cluster.scheduler.reset_mock() self.control_connection._handle_status_change(event) host = self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.0')) - self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host) + self.cluster.scheduler.schedule_unique.assert_called_once_with( + ANY, self.cluster.on_up, host, expected_endpoint=host.endpoint) self.cluster.scheduler.schedule.reset_mock() event = { @@ -513,6 +736,33 @@ def test_handle_status_change(self): self.control_connection._handle_status_change(event) host = self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.0')) assert host is self.cluster.down_host + assert host.endpoint == self.cluster.down_expected_endpoint + + def test_handle_status_change_preserves_endpoint_identity(self): + host = self.cluster.metadata.hosts['uuid1'] + old_endpoint = SniEndPoint("192.168.1.0", "node-a", 9042) + new_endpoint = SniEndPoint("192.168.1.0", "node-b", 9042) + host.endpoint = old_endpoint + + event = { + 'change_type': 'UP', + 'address': ('192.168.1.0', 9042) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_status_change(event) + + # A raw event address would accept the replacement endpoint here. + assert Cluster._endpoints_match(new_endpoint, event['address']) + assert not Cluster._endpoints_match(new_endpoint, old_endpoint) + self.cluster.scheduler.schedule_unique.assert_called_once_with( + ANY, self.cluster.on_up, host, expected_endpoint=old_endpoint) + + self.cluster.on_down = Mock() + event['change_type'] = 'DOWN' + self.control_connection._handle_status_change(event) + + self.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=old_endpoint) def test_handle_schema_change(self): diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index f92bb53785..0a7c6d07eb 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -20,11 +20,12 @@ from cassandra.shard_info import _ShardingInfo import unittest -from threading import Thread, Event, Lock +from threading import Thread, Event, Lock, Condition from unittest.mock import Mock, NonCallableMagicMock, MagicMock -from cassandra.cluster import Session, ShardAwareOptions -from cassandra.connection import Connection +from cassandra.cluster import Cluster, Session, ShardAwareOptions +from cassandra.connection import ClientRoutesEndPoint, Connection, DefaultEndPoint +from cassandra.metadata import Metadata from cassandra.pool import HostConnection from cassandra.pool import Host, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy @@ -133,6 +134,7 @@ def test_spawn_when_at_max(self): def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False) @@ -153,6 +155,7 @@ def test_return_defunct_connection(self): def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False, @@ -174,6 +177,8 @@ def test_return_defunct_connection_on_down_host(self): if self.PoolImpl is HostConnection: # on shard aware implementation we use submit function regardless assert host.signal_connection_failure.call_args + session.cluster.on_down.assert_called_once_with( + host, is_host_addition=False, expected_endpoint=pool.endpoint) assert session.submit.called else: assert not session.submit.called @@ -182,6 +187,7 @@ def test_return_defunct_connection_on_down_host(self): def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') + host.lock = Lock() session = self.make_session() conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100, signaled_error=False, orphaned_threshold_reached=False) @@ -200,6 +206,135 @@ def test_return_closed_connection(self): assert session.submit.call_args assert not pool.is_shutdown + def test_return_defunct_connection_after_endpoint_swap_is_ignored(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Mock(spec=Host, address='ip1') + host.endpoint = old_endpoint + host.lock = Lock() + session = self.make_session() + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + host.endpoint = new_endpoint + conn.is_defunct = True + host.signal_connection_failure.return_value = True + pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + conn.close.assert_called_once_with() + assert not pool.is_shutdown + + def test_return_defunct_connection_after_client_route_endpoint_port_swap_is_ignored(self): + host_id = uuid.uuid4() + old_endpoint = ClientRoutesEndPoint( + host_id, Mock(), '127.0.0.1', original_port=9042) + new_endpoint = ClientRoutesEndPoint( + host_id, Mock(), '127.0.0.1', original_port=9142) + assert old_endpoint == new_endpoint + assert not Cluster._endpoints_match(old_endpoint, new_endpoint) + host = Mock(spec=Host, address='ip1') + host.endpoint = old_endpoint + host.lock = Lock() + session = self.make_session() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + host.endpoint = new_endpoint + conn.is_defunct = True + host.signal_connection_failure.return_value = True + pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + conn.close.assert_called_once_with() + assert not pool.is_shutdown + + def test_return_defunct_connection_after_endpoint_reassignment_is_ignored(self): + endpoint = DefaultEndPoint('127.0.0.1') + stale_host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + replacement_host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + stale_host.signal_connection_failure = Mock(return_value=True) + + session = self.make_session() + session.remove_pool.return_value = None + session.cluster.metadata = Metadata() + session.cluster.metadata.add_or_return_host(replacement_host) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100, + signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + pool = self.PoolImpl(stale_host, HostDistance.LOCAL, session) + + pool.borrow_connection(timeout=0.01) + conn.is_defunct = True + pool.return_connection(conn) + + stale_host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.submit.assert_not_called() + session.remove_pool.assert_called_once_with( + stale_host, expected_host=stale_host, + expected_endpoint=endpoint, expected_pool=pool) + conn.close.assert_called_once_with() + assert not pool.is_shutdown + + def test_return_defunct_connection_from_removed_pool_after_endpoint_flip_back_is_ignored(self): + endpoint = DefaultEndPoint('127.0.0.1') + host = Host(endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.signal_connection_failure = Mock(return_value=True) + + session = self.make_session() + session._lock = Lock() + session._pools = {} + session.remove_pool.return_value = None + session.cluster.metadata = Metadata() + session.cluster.metadata.add_or_return_host(host) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + conn = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100, + signaled_error=False, + orphaned_threshold_reached=False) + session.cluster.connection_factory.return_value = conn + + stale_pool = self.PoolImpl(host, HostDistance.LOCAL, session) + stale_pool.borrow_connection(timeout=0.01) + + # The old pool was already removed, then the host endpoint flipped back + # to the same endpoint and a replacement pool became current. + current_pool = Mock() + current_pool.host = host + current_pool.endpoint = endpoint + session._pools[host] = current_pool + + conn.is_defunct = True + stale_pool.return_connection(conn) + + host.signal_connection_failure.assert_not_called() + session.cluster.on_down.assert_not_called() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=endpoint, + expected_pool=stale_pool) + conn.close.assert_called_once_with() + assert not stale_pool.is_shutdown + def test_host_instantiations(self): """ Ensure Host fails if not initialized properly @@ -287,3 +422,121 @@ def mock_connection_factory(self, *args, **kwargs): # Cleanup executor with proper wait session.cluster.executor.shutdown(wait=True) + + def test_replace_retries_when_replacement_keyspace_set_fails(self): + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + initial_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection.set_keyspace_blocking.side_effect = RuntimeError( + "keyspace failed") + session.cluster.connection_factory.side_effect = [ + initial_connection, replacement_connection] + + pool = HostConnection(host, HostDistance.LOCAL, session) + pool._is_replacing = True + + pool._replace(initial_connection) + + assert session.submit.call_count == 1 + submitted_fn, submitted_connection = session.submit.call_args.args + assert submitted_fn == pool._replace + assert submitted_connection is initial_connection + + def test_replace_discards_replacement_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + initial_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.side_effect = [ + initial_connection, replacement_connection] + + pool = HostConnection(host, HostDistance.LOCAL, session) + pool._is_replacing = True + + pool._replace(initial_connection) + + replacement_connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert not pool._is_replacing + + def test_missing_shard_discards_connection_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.sharding_info = _ShardingInfo( + shard_id=0, shards_count=1, partitioner='', + sharding_algorithm='', sharding_ignore_msb=0, + shard_aware_port='', shard_aware_port_ssl='') + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster.ssl_options = None + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.return_value = connection + + pool = HostConnection.__new__(HostConnection) + pool.host = host + pool.endpoint = old_endpoint + pool.host_distance = HostDistance.LOCAL + pool.is_shutdown = False + pool._session = session + pool._lock = Lock() + pool._stream_available_condition = Condition(Lock()) + pool._connections = {} + pool._pending_connections = [] + pool._connecting = {0} + pool._excess_connections = set() + pool._trash = set() + pool._shard_connections_futures = [] + pool._keyspace = 'ks' + pool.advanced_shardaware_block_until = 0 + pool.tablets_routing_v1 = False + + pool._open_connection_to_missing_shard(0) + + connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert pool._connecting == set() diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..8aca91567a 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -13,14 +13,15 @@ # limitations under the License. import unittest +import uuid from collections import deque from threading import RLock from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion -from cassandra.connection import Connection, ConnectionException +from cassandra.cluster import Cluster, Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.connection import Connection, ConnectionException, DefaultEndPoint from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, @@ -28,8 +29,8 @@ RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, ProtocolHandler) -from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy -from cassandra.pool import NoConnectionsAvailable +from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy, SimpleConvictionPolicy +from cassandra.pool import Host, NoConnectionsAvailable from cassandra.query import SimpleStatement from tests.util import assertEqual, assertIsInstance import pytest @@ -37,6 +38,24 @@ class ResponseFutureTests(unittest.TestCase): + class _EndpointSwapOnExitLock(object): + + def __init__(self, host, new_endpoint): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._exits = 0 + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self._host.endpoint = self._new_endpoint + def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] @@ -52,7 +71,7 @@ def make_pool(self): def make_session(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value = self.make_pool() + session._get_pool_by_host_identity.return_value = self.make_pool() return session def make_response_future(self, session): @@ -66,7 +85,7 @@ def make_mock_response(self, col_names, rows): def test_result_message(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value pool.is_shutdown = False connection = Mock(spec=Connection) @@ -75,7 +94,7 @@ def test_result_message(self): rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -87,7 +106,7 @@ def test_result_message(self): def test_unknown_result_class(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -151,7 +170,7 @@ def test_heartbeat_defunct_deadlock(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()] - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool query = SimpleStatement("SELECT * FROM foo") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) @@ -252,7 +271,7 @@ def test_retry_policy_says_ignore(self): def test_retry_policy_says_retry(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) @@ -266,7 +285,7 @@ def test_retry_policy_says_retry(self): rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -285,13 +304,13 @@ def test_retry_policy_says_retry(self): # it should try again with the same host since this was # an UnavailableException - rf.session._pools.get.assert_called_with(host) + rf.session._get_pool_by_host_identity.assert_called_with(host) pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -300,7 +319,7 @@ def test_retry_with_different_host(self): rf.message.consistency_level = ConsistencyLevel.QUORUM rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) assert ConsistencyLevel.QUORUM == rf.message.consistency_level @@ -319,7 +338,7 @@ def test_retry_with_different_host(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -328,13 +347,13 @@ def test_retry_with_different_host(self): def test_all_retries_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -346,7 +365,7 @@ def test_all_retries_fail(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') result = Mock(spec=IsBootstrappingErrorMessage, info={}) rf._set_result(host, None, None, result) @@ -360,7 +379,7 @@ def test_all_retries_fail(self): def test_exponential_retry_policy_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -368,7 +387,7 @@ def test_exponential_retry_policy_fail(self): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1, retry_policy=ExponentialBackoffRetryPolicy(2)) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -384,7 +403,7 @@ def test_exponential_retry_policy_fail(self): def test_all_pools_shutdown(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value.is_shutdown = True + session._get_pool_by_host_identity.return_value.is_shutdown = True rf = ResponseFuture(session, Mock(), Mock(), 1) rf.send_request() @@ -399,7 +418,7 @@ def test_first_pool_shutdown(self): pool_shutdown.is_shutdown = True pool_ok = self.make_pool() pool_ok.is_shutdown = True - session._pools.get.side_effect = [pool_shutdown, pool_ok] + session._get_pool_by_host_identity.side_effect = [pool_shutdown, pool_ok] rf = self.make_response_future(session) rf.send_request() @@ -424,7 +443,7 @@ def test_timeout_getting_connection_from_pool(self): connection = Mock(spec=Connection) second_pool.borrow_connection.return_value = (connection, 1) - session._pools.get.side_effect = [first_pool, second_pool] + session._get_pool_by_host_identity.side_effect = [first_pool, second_pool] rf = self.make_response_future(session) rf.send_request() @@ -459,7 +478,7 @@ def test_callback(self): def test_errback(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -508,7 +527,7 @@ def test_multiple_callbacks(self): def test_multiple_errbacks(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -581,7 +600,7 @@ def test_add_callbacks(self): def test_prepared_query_not_found(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -606,7 +625,7 @@ def test_prepared_query_not_found(self): def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -655,7 +674,7 @@ def test_timeout_does_not_release_stream_id(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')] pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), orphaned_request_ids=set(), orphaned_threshold=256, in_flight=3) pool.borrow_connection.return_value = (connection, 1) @@ -675,6 +694,127 @@ def test_timeout_does_not_release_stream_id(self): assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) + def test_timeout_returns_orphan_to_original_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + old_pool = self.make_pool() + replacement_pool = self.make_pool() + connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), + orphaned_request_ids=set(), orphaned_threshold=256, in_flight=1) + old_pool.borrow_connection.return_value = (connection, 1) + session._get_pool_by_host_identity.side_effect = [old_pool, replacement_pool] + + rf = self.make_response_future(session) + rf.send_request() + connection._requests[1] = (connection._handle_options_response, + ProtocolHandler.decode_message, []) + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf._on_timeout() + + replacement_pool.return_connection.assert_not_called() + old_pool.return_connection.assert_called_once_with( + connection, stream_was_orphaned=True) + + def test_query_does_not_borrow_stale_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + + def test_query_rechecks_endpoint_after_pool_lookup_race(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.lock = self._EndpointSwapOnExitLock( + host, DefaultEndPoint('127.0.0.2')) + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + + def test_query_releases_request_id_after_post_borrow_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint('127.0.0.2') + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = old_endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection, lock=RLock(), request_ids=deque([1]), + in_flight=0) + + def borrow_connection(**kwargs): + with connection.lock: + connection.in_flight += 1 + request_id = connection.request_ids.popleft() + host.endpoint = new_endpoint + return connection, request_id + + def return_connection(returned_connection): + with returned_connection.lock: + returned_connection.in_flight -= 1 + + stale_pool.borrow_connection.side_effect = borrow_connection + stale_pool.return_connection.side_effect = return_connection + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.return_connection.assert_called_once_with(connection) + connection.send_msg.assert_not_called() + assert list(connection.request_ids) == [1] + assert connection.in_flight == 0 + assert isinstance(rf._errors[host], ConnectionException) + def test_single_host_query_plan_exhausted_after_one_retry(self): """ Test that when a specific host is provided, the query plan is properly @@ -686,7 +826,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): """ session = self.make_basic_session() pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool # Create a specific host specific_host = Mock() @@ -702,7 +842,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): rf.send_request() # Verify initial request was sent - rf.session._pools.get.assert_called_once_with(specific_host) + rf.session._get_pool_by_host_identity.assert_called_once_with(specific_host) pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])