From 81cec2b0d7216f602cac90896677dad171e6425d Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 23 Apr 2026 14:28:12 -0400 Subject: [PATCH 01/14] cluster: use dynamic whitelist policy for proxy access --- cassandra/cluster.py | 260 ++++++++++-- cassandra/datastax/insights/serializers.py | 8 + cassandra/policies.py | 71 ++++ cassandra/pool.py | 15 +- docs/api/cassandra/policies.rst | 3 + .../standard/test_client_routes.py | 77 +++- tests/unit/advanced/test_insights.py | 10 + tests/unit/test_cluster.py | 372 +++++++++++++++++- tests/unit/test_control_connection.py | 89 +++++ tests/unit/test_host_connection_pool.py | 48 ++- tests/unit/test_policies.py | 39 +- tests/unit/test_shard_aware.py | 3 + 12 files changed, 927 insertions(+), 68 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..5ef7053806 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -548,6 +548,10 @@ def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) + def on_control_connection_host(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_control_connection_host(host) + @property def default(self): """ @@ -1673,6 +1677,7 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5): self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) + profile.load_balancing_policy.on_control_connection_host(self.get_control_connection_host()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) @@ -1746,6 +1751,17 @@ def connect(self, keyspace=None, wait_for_all_pools=False): established or attempted. Default is `False`, which means it will return when the first successful connection is established. Remaining pools are added asynchronously. 
""" + self._ensure_core_connections_setup() + + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + + return session + + def _ensure_core_connections_setup(self): with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") @@ -1777,14 +1793,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False): ) self._is_setup = True - session = self._new_session(keyspace) - if wait_for_all_pools: - wait_futures(session._initial_connect_futures) - - self._set_default_dbaas_consistency(session) - - return session - def _set_default_dbaas_consistency(self, session): if session.cluster.metadata.dbaas: for profile in self.profile_manager.profiles.values(): @@ -1805,14 +1813,18 @@ def get_all_pools(self): pools.extend(s.get_pools()) return pools + def _get_shard_aware_pools(self): + return [pool for pool in self.get_all_pools() if pool.host.sharding_info is not None] + def is_shard_aware(self): - return bool(self.get_all_pools()[0].host.sharding_info) + return bool(self._get_shard_aware_pools()) def shard_aware_stats(self): - if self.is_shard_aware(): + shard_aware_pools = self._get_shard_aware_pools() + if shard_aware_pools: return {str(pool.host.endpoint): {'shards_count': pool.host.sharding_info.shards_count, 'connected': len(pool._connections.keys())} - for pool in self.get_all_pools()} + for pool in shard_aware_pools} def shutdown(self): """ @@ -1857,11 +1869,91 @@ def _new_session(self, keyspace): self.sessions.add(session) return session + def _default_control_connection_endpoint_targets_host( + self, host, endpoint, attempts=3): + for _ in range(attempts): + connected_host_id = self._get_host_id_for_endpoint(endpoint) + if connected_host_id != host.host_id: + return False + return True + + def _get_host_id_for_endpoint(self, endpoint): + connection = None + try: + connection = self.connection_factory(endpoint) + response = 
connection.wait_for_response( + QueryMessage( + query="SELECT host_id FROM system.local WHERE key='local'", + consistency_level=ConsistencyLevel.ONE), + timeout=self.connect_timeout) + rows = dict_factory(response.column_names, response.parsed_rows) + if not rows: + return None + return rows[0].get("host_id") + except Exception: + log.debug( + "Failed verifying control connection endpoint %s", endpoint, + exc_info=True) + return None + finally: + if connection: + connection.close() + + def _get_control_connection_host_endpoint(self, control_host, connection_endpoint): + if connection_endpoint is not None and ( + connection_endpoint == control_host.endpoint or + not isinstance(connection_endpoint, DefaultEndPoint)): + return connection_endpoint + + host_endpoint = control_host.endpoint + if host_endpoint is not None and not isinstance(host_endpoint, DefaultEndPoint): + return host_endpoint + + if connection_endpoint is not None and self._default_control_connection_endpoint_targets_host( + control_host, connection_endpoint): + return connection_endpoint + + if connection_endpoint is not None: + return connection_endpoint + + return host_endpoint + def _session_register_user_types(self, session): for keyspace, type_map in self._user_types.items(): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) + def _update_host_endpoint(self, host, endpoint): + if host.endpoint == endpoint: + return + + was_up = host.is_up + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + + if was_up: + host.set_down() + self.profile_manager.on_down(host) + for session in tuple(self.sessions): + session.remove_pool(host) + if was_up: + for listener in self.listeners: + listener.on_down(host) + + old_endpoint = host.endpoint + host.endpoint = endpoint + self.metadata.update_host(host, old_endpoint) + if was_up: + host.set_up() + self.profile_manager.on_up(host) + for session in tuple(self.sessions): 
+ session.add_or_renew_pool(host, is_host_addition=False) + for listener in self.listeners: + listener.on_up(host) + else: + self._start_reconnector(host, is_host_addition=False) + def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) @@ -1954,12 +2046,16 @@ def on_up(self, host): futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback_futures = [] for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + callback_futures.append(future) + + for future in callback_futures: + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: @@ -2030,8 +2126,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in tuple(self.sessions): - pool_states = session.get_pool_state() - pool_state = pool_states.get(host) + pool_state = session.get_pool_state_for_host(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: @@ -2220,7 +2315,18 @@ def get_control_connection_host(self): """ connection = self.control_connection._connection endpoint = connection.endpoint if connection else None - return self.metadata.get_host(endpoint) if endpoint else None + if not endpoint: + return None + + host = self.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self.control_connection._current_host_id + if host_id is None: + return None + + return self.metadata.get_host_by_host_id(host_id) def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ @@ -2924,8 +3030,11 @@ def 
_on_analytics_master_result(self, response, master_future, query_future): delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided if delimiter_index > 0: addr = addr[:delimiter_index] - targeted_query = HostTargetingStatement(query_future.query, addr) - query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + if query_future._host is not None: + query_future.query_plan = iter([query_future._host]) + else: + targeted_query = HostTargetingStatement(query_future.query, addr) + query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) except Exception: log.debug("Failed querying analytics master (request might not be routed optimally). " "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) @@ -3245,15 +3354,17 @@ def run_add_or_renew_pool(): new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), endpoint=host) - self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down(host, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created - self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down( + host, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) @@ -3271,7 +3382,7 @@ def callback(pool, errors): set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace 
changed during connect: %s", errors_returned) - self.cluster.on_down(host, is_host_addition) + self._handle_pool_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False @@ -3410,9 +3521,21 @@ def submit(self, fn, *args, **kwargs): def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + def get_pool_state_for_host(self, host): + return self.get_pool_state().get(host) + def get_pools(self): return self._pools.values() + def _signal_connection_failure(self, host, connection_exc): + return host.signal_connection_failure(connection_exc) + + def _handle_pool_down(self, host, is_host_addition, expect_host_to_be_down=False): + self.cluster.on_down(host, is_host_addition, expect_host_to_be_down) + + def is_shard_aware_disabled(self): + return self.cluster.shard_aware_options.disable + def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." 
% (attr_name,)) @@ -3545,6 +3668,7 @@ def __init__(self, cluster, timeout, self._reconnection_handler = None self._reconnection_lock = RLock() + self._current_host_id = None self._event_schedule_times = {} @@ -3569,6 +3693,9 @@ def _set_new_connection(self, conn): log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() + for session in tuple(getattr(self._cluster, "sessions", ())): + session.update_created_pools() + def _try_connect_to_hosts(self): errors = {} @@ -3770,6 +3897,11 @@ def shutdown(self): if self._connection: self._connection.close() self._connection = None + self._current_host_id = None + try: + self._cluster.profile_manager.on_control_connection_host(None) + except ReferenceError: + pass def refresh_schema(self, force=False, **kwargs): try: @@ -3849,6 +3981,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids = set() found_endpoints = set() + local_host_id = None if local_result.parsed_rows: local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] @@ -3857,6 +3990,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, partitioner = local_row.get("partitioner") tokens = local_row.get("tokens", None) + local_host_id = local_row.get("host_id") peers_result.insert(0, local_row) @@ -3887,17 +4021,17 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) - if host and host.endpoint != endpoint: - log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) - - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, 
old_endpoint)
-                    self._cluster.on_up(host)
+            if host:
+                target_endpoint = endpoint
+                if host_id == local_host_id:
+                    target_endpoint = self._cluster._get_control_connection_host_endpoint(host, connection.endpoint)
+                    if target_endpoint is None:
+                        target_endpoint = endpoint
+
+                if host.endpoint != target_endpoint:
+                    log.debug("[control connection] Updating host endpoint from %s to %s for (%s)",
+                              host.endpoint, target_endpoint, host_id)
+                    self._cluster._update_host_endpoint(host, target_endpoint)
 
             if host is None:
                 log.debug("[control connection] Found new host to connect to: %s", endpoint)
@@ -3928,6 +4062,19 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
                     self._cluster.metadata.remove_host_by_host_id(old_host_id, old_host.endpoint)
 
         log.debug("[control connection] Finished fetching ring info")
+        current_host = None
+        if local_host_id in found_host_ids:
+            current_host = self._cluster.metadata.get_host_by_host_id(local_host_id)
+            if current_host is not None:
+                self._maybe_rebind_control_connection_host_endpoint(current_host, connection.endpoint)
+                current_host = self._cluster.metadata.get_host_by_host_id(local_host_id)
+
+        previous_current_host_id = self._current_host_id
+        self._current_host_id = local_host_id if current_host is not None else None
+        self._cluster.profile_manager.on_control_connection_host(current_host)
+        if connection is self._connection and self._current_host_id != previous_current_host_id:
+            for session in tuple(getattr(self._cluster, "sessions", ())):
+                session.update_created_pools()
 
         if partitioner and should_rebuild_token_map:
             log.debug("[control connection] Rebuilding token map due to topology changes")
             self._cluster.metadata.rebuild_token_map(partitioner, token_map)
@@ -3981,6 +4128,15 @@ def _update_location_info(self, host, datacenter, rack):
             self._cluster.profile_manager.on_up(host)
         return True
 
+    def _maybe_rebind_control_connection_host_endpoint(self, host, connection_endpoint):
+        target_endpoint = 
self._cluster._get_control_connection_host_endpoint(host, connection_endpoint) + if target_endpoint is None or target_endpoint == host.endpoint: + return + + log.debug("[control connection] Rebasing current host %s from %s to %s", + host.host_id, host.endpoint, target_endpoint) + self._cluster._update_host_endpoint(host, target_endpoint) + def _delay_for_event_type(self, event_type, delay_window): # this serves to order processing correlated events (received within the window) # the window and randomization still have the desired effect of skew across client instances @@ -4222,25 +4378,50 @@ def _signal_error(self): # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.endpoint) + host = self._try_get_cluster_host() # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + is_down = self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) - return + if is_down: + return # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() + def _try_get_cluster_host(self): + conn = self._connection + endpoint = conn.endpoint if conn else None + if not endpoint: + return None + + host = self._cluster.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self._current_host_id + if host_id is None: + return None + + return self._cluster.metadata.get_host_by_host_id(host_id) + def on_up(self, host): pass - def on_down(self, host): - + def _is_current_host(self, host): conn = self._connection - if conn and conn.endpoint == host.endpoint and \ + if conn is None or host is None: + return False + + if conn.endpoint == host.endpoint: + return True + + return self._current_host_id is not None and 
getattr(host, 'host_id', None) == self._current_host_id + + def on_down(self, host): + if self._is_current_host(host) and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) @@ -4252,8 +4433,7 @@ def on_add(self, host, refresh_nodes=True): self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): - c = self._connection - if c and c.endpoint == host.endpoint: + if self._is_current_host(host): log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) # refresh will be done on reconnect self.reconnect() diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py index 289c165e8a..270b5360a3 100644 --- a/cassandra/datastax/insights/serializers.py +++ b/cassandra/datastax/insights/serializers.py @@ -37,6 +37,7 @@ def initialize_registry(insights_registry): DCAwareRoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, HostFilterPolicy, ConstantReconnectionPolicy, ExponentialReconnectionPolicy, @@ -80,6 +81,13 @@ def whitelist_round_robin_policy_insights_serializer(policy): 'options': {'allowed_hosts': policy._allowed_hosts} } + @insights_registry.register_serializer_for(DynamicWhiteListRoundRobinPolicy) + def dynamic_whitelist_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'allowed_host_ids': tuple(str(host_id) for host_id in policy._allowed_host_ids)} + } + @insights_registry.register_serializer_for(HostFilterPolicy) def host_filter_policy_insights_serializer(policy): return { diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..268a260812 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -166,6 +166,16 @@ def check_supported(self): """ pass + def 
on_control_connection_host(self, host): + """ + Called when the control connection resolves the metadata host behind + the endpoint it is currently using. + + Policies that maintain a dynamic host allowlist can override this to + update their internal view of the cluster. + """ + pass + class RoundRobinPolicy(LoadBalancingPolicy): """ @@ -540,6 +550,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + class WhiteListRoundRobinPolicy(RoundRobinPolicy): """ @@ -594,6 +607,58 @@ def on_add(self, host): RoundRobinPolicy.on_add(self, host) +class DynamicWhiteListRoundRobinPolicy(RoundRobinPolicy): + """ + A :class:`.RoundRobinPolicy` variant whose allowlist is updated from the + control connection. + + This is intended for proxy deployments where the driver can only reach the + host currently behind the control connection endpoint. The policy keeps + every other discovered node at :attr:`~.HostDistance.IGNORED` until the + control connection resolves a different host. 
+ """ + + def __init__(self): + self._allowed_host_ids = frozenset(()) + self._cluster = None + RoundRobinPolicy.__init__(self) + + def _host_is_allowed(self, host): + return getattr(host, "host_id", None) in self._allowed_host_ids + + def _refresh_live_hosts(self, hosts): + self._live_hosts = frozenset( + host for host in hosts + if self._host_is_allowed(host) and host.is_up is not False + ) + + def populate(self, cluster, hosts): + self._cluster = cluster + self._refresh_live_hosts(hosts) + if len(self._live_hosts) > 1: + self._position = randint(0, len(self._live_hosts) - 1) + else: + self._position = 0 + + def distance(self, host): + return HostDistance.LOCAL if self._host_is_allowed(host) else HostDistance.IGNORED + + def on_up(self, host): + if self._host_is_allowed(host): + RoundRobinPolicy.on_up(self, host) + + def on_add(self, host): + if self._host_is_allowed(host): + RoundRobinPolicy.on_add(self, host) + + def on_control_connection_host(self, host): + with self._hosts_lock: + allowed_host_id = getattr(host, "host_id", None) + self._allowed_host_ids = frozenset((allowed_host_id,)) if allowed_host_id is not None else frozenset(()) + if self._cluster is not None: + self._refresh_live_hosts(self._cluster.metadata.all_hosts()) + + class HostFilterPolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` subclass configured with a child policy, @@ -654,6 +719,9 @@ def on_add(self, host, *args, **kwargs): def on_remove(self, host, *args, **kwargs): return self._child_policy.on_remove(host, *args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + @property def predicate(self): """ @@ -1322,6 +1390,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + class DefaultLoadBalancingPolicy(WrapperPolicy): """ 
diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..097824be06 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session): if self._keyspace: first_connection.set_keyspace_blocking(self._keyspace) - if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: + if first_connection.features.sharding_info and not self._session.is_shard_aware_disabled(): self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) self.tablets_routing_v1 = first_connection.features.tablets_routing_v1 @@ -451,7 +451,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table raise NoConnectionsAvailable() shard_id = None - if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key: + if not self._session.is_shard_aware_disabled() and self.host.sharding_info and routing_key: t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key) shard_id = None @@ -554,7 +554,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.signal_connection_failure(connection.last_error) + is_down = self._session._signal_connection_failure(self.host, connection.last_error) connection.signaled_error = True if self.shutdown_on_error and not is_down: @@ -562,7 +562,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if is_down: self.shutdown() - self._session.cluster.on_down(self.host, is_host_addition=False) + self._session._handle_pool_down(self.host, is_host_addition=False) else: connection.close() with self._lock: @@ -603,7 +603,7 @@ def _replace(self, connection): try: if connection.features.shard_id 
in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if self.host.sharding_info and not self._session.is_shard_aware_disabled(): self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) else: @@ -678,7 +678,8 @@ def disable_advanced_shard_aware(self, secs): def _get_shard_aware_endpoint(self): if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ - self._session.cluster.shard_aware_options.disable_shardaware_port: + self._session.cluster.shard_aware_options.disable_shardaware_port or \ + self._session.is_shard_aware_disabled(): return None endpoint = None @@ -920,5 +921,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 84d5575a40..2a24af8f9f 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -26,6 +26,9 @@ Load Balancing .. autoclass:: WhiteListRoundRobinPolicy :members: +.. autoclass:: DynamicWhiteListRoundRobinPolicy + :members: + .. 
autoclass:: TokenAwarePolicy :members: diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..a9efcd65b4 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -40,8 +40,8 @@ from cassandra.cluster import Cluster from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy -from cassandra.connection import ClientRoutesEndPoint -from cassandra.policies import RoundRobinPolicy +from cassandra.connection import ClientRoutesEndPoint, ConnectionException +from cassandra.policies import DynamicWhiteListRoundRobinPolicy, RoundRobinPolicy from tests.integration import ( TestCluster, get_cluster, @@ -54,6 +54,28 @@ log = logging.getLogger(__name__) + +class ProxyOnlyReachableConnection(Cluster.connection_class): + """ + Simulates a private-link client that can reach only the proxy endpoint. + + The CCM node addresses are reachable from the local test runner, which means + the existing client-routes tests cannot reproduce bugs that only appear when + direct node IPs are private. This connection class rejects those direct node + addresses while still allowing the NLB address. 
+ """ + + @classmethod + def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs): + address, _ = endpoint.resolve() + if address.startswith("127.0.0."): + raise ConnectionException( + "Simulated private node address %s is unreachable from the client" % address, + endpoint=endpoint, + ) + return super().factory(endpoint, timeout, host_conn=host_conn, *args, **kwargs) + + class TcpProxy: """ A simple TCP proxy that forwards connections from a local listen port @@ -535,6 +557,57 @@ def teardown_module(): else: os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + +class TestProxyConnectivityWithoutClientRoutes(unittest.TestCase): + """ + Reproducer for connecting through a generic proxy when node addresses are + not reachable from the client. + + The initial control connection can reach the cluster through the proxy, but + the driver later tries to open pools to the discovered node addresses + directly. In a proxy-only environment that makes connect/query fail. + """ + + @classmethod + def setUpClass(cls): + cls.node_addrs = { + 1: "127.0.0.1", + 2: "127.0.0.2", + 3: "127.0.0.3", + } + cls.proxy_node_id = 1 + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + @classmethod + def tearDownClass(cls): + cls.nlb.stop() + + def _make_proxy_cluster(self): + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.node_port(self.proxy_node_id), + connection_class=ProxyOnlyReachableConnection, + load_balancing_policy=DynamicWhiteListRoundRobinPolicy(), + ) + + def test_dynamic_whitelist_session_succeeds_when_only_proxy_is_reachable(self): + cluster = self._make_proxy_cluster() + self.addCleanup(cluster.shutdown) + + session = cluster.connect() + row = session.execute( + "SELECT release_version FROM system.local WHERE key='local'" + ).one() + + self.assertIsNotNone(row) + pool_state = session.get_pool_state() + self.assertEqual(len(pool_state), 1) + + session_host = next(iter(pool_state)) + 
self.assertEqual(session_host.endpoint.address, NLBEmulator.LISTEN_HOST) + self.assertEqual(session_host.endpoint.port, self.nlb.node_port(self.proxy_node_id)) + @skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', scylla_version="2026.1.0") class TestGetHostPortMapping(unittest.TestCase): diff --git a/tests/unit/advanced/test_insights.py b/tests/unit/advanced/test_insights.py index ec9b918866..55ea05e8f4 100644 --- a/tests/unit/advanced/test_insights.py +++ b/tests/unit/advanced/test_insights.py @@ -17,6 +17,7 @@ import logging import sys +import uuid from unittest.mock import sentinel from cassandra import ConsistencyLevel @@ -37,6 +38,7 @@ DCAwareRoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, HostFilterPolicy, ConstantReconnectionPolicy, ExponentialReconnectionPolicy, @@ -203,6 +205,14 @@ def test_whitelist_round_robin_policy(self): 'options': {'allowed_hosts': ('127.0.0.3',)}, 'type': 'WhiteListRoundRobinPolicy'} + def test_dynamic_whitelist_round_robin_policy(self): + policy = DynamicWhiteListRoundRobinPolicy() + host_id = uuid.uuid4() + policy._allowed_host_ids = (host_id,) + assert insights_registry.serialize(policy) == {'namespace': 'cassandra.policies', + 'options': {'allowed_host_ids': (str(host_id),)}, + 'type': 'DynamicWhiteListRoundRobinPolicy'} + def test_host_filter_policy(self): def my_predicate(s): return False diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..16e28322d6 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,7 +15,7 @@ import logging import socket - +from concurrent.futures import Future from unittest.mock import patch, Mock import uuid @@ -23,8 +23,10 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, 
default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import DefaultEndPoint, EndPoint from cassandra.pool import Host -from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy +from cassandra.policies import RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, \ + DynamicWhiteListRoundRobinPolicy, HostStateListener, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools from tests import connection_class @@ -33,6 +35,57 @@ log = logging.getLogger(__name__) + +class _HostAwareProxyEndPoint(EndPoint): + def __init__(self, address, affinity_key, port=9042): + self._address = address + self._affinity_key = affinity_key + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, _HostAwareProxyEndPoint) and \ + self.address == other.address and self.port == other.port and \ + self._affinity_key == other._affinity_key + + def __hash__(self): + return hash((self.address, self.port, self._affinity_key)) + + def __lt__(self, other): + if not isinstance(other, _HostAwareProxyEndPoint): + return NotImplemented + return (self.address, self.port, str(self._affinity_key)) < \ + (other.address, other.port, str(other._affinity_key)) + + +class _RecordingHostStateListener(HostStateListener): + + def __init__(self): + self.events = [] + + def on_up(self, host): + self.events.append(("up", host.address)) + + def on_down(self, host): + self.events.append(("down", host.address)) + + def on_add(self, host): + self.events.append(("add", host.address)) + + def on_remove(self, host): + self.events.append(("remove", host.address)) + + class ExceptionTypeTest(unittest.TestCase): 
def test_exception_types(self): @@ -229,6 +282,134 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_get_control_connection_host_falls_back_to_host_id(self): + cluster = Cluster(contact_points=['127.0.0.1']) + host = Host(DefaultEndPoint('192.168.1.10'), SimpleConvictionPolicy, host_id=uuid.uuid4()) + + metadata = Mock() + metadata.get_host.return_value = None + metadata.get_host_by_host_id.return_value = host + cluster.metadata = metadata + + connection = Mock(endpoint=DefaultEndPoint('127.254.254.101', 9042)) + cluster.control_connection = Mock(_connection=connection, _current_host_id=host.host_id) + + assert cluster.get_control_connection_host() is host + metadata.get_host.assert_called_once_with(connection.endpoint) + metadata.get_host_by_host_id.assert_called_once_with(host.host_id) + + def test_update_host_endpoint_recreates_session_pools(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + remove_future = Future() + remove_future.set_result(None) + add_future = Future() + add_future.set_result(True) + + session = Mock() + session.remove_pool.return_value = remove_future + session.add_or_renew_pool.return_value = add_future + cluster.sessions.add(session) + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert host.endpoint == new_endpoint + assert cluster.metadata.get_host(new_endpoint) is host + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + + def test_update_host_endpoint_restarts_reconnector_for_down_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), 
protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_down() + cluster.metadata.add_or_return_host(host) + + previous_reconnector = Mock() + host.get_and_set_reconnection_handler(previous_reconnector) + + session = Mock() + cluster.sessions.add(session) + cluster._start_reconnector = Mock() + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert host.endpoint == new_endpoint + assert cluster.metadata.get_host(new_endpoint) is host + previous_reconnector.cancel.assert_called_once_with() + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_not_called() + cluster._start_reconnector.assert_called_once_with(host, is_host_addition=False) + + def test_update_host_endpoint_notifies_listeners_for_live_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + session = Mock() + cluster.sessions.add(session) + + listener = _RecordingHostStateListener() + cluster.register_listener(listener) + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert listener.events == [("down", "127.0.0.1"), ("up", "127.0.0.2")] + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + + def test_is_shard_aware_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + 
endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = Mock(return_value=[control_pool, shard_pool]) + + assert cluster.is_shard_aware() is True + + def test_shard_aware_stats_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = Mock(return_value=[shard_pool, control_pool]) + + assert cluster.shard_aware_stats() == { + "127.0.0.1:9042": {"shards_count": 8, "connected": 2} + } + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -252,6 +433,16 @@ def setUp(self): raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + @staticmethod + def _completed_future(result): + future = Future() + future.set_result(result) + return future + + @staticmethod + def _proxy_endpoint(address, affinity_key, port=9042): + return _HostAwareProxyEndPoint(address, affinity_key, port) + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -339,6 +530,163 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + def test_get_control_connection_host_endpoint_reuses_matching_default_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), 
SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(source_host.host_id,)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + + assert endpoint == connection_endpoint + assert connection_factory.call_count == 3 + assert verification_connection.close.call_count == 3 + + def test_get_control_connection_host_endpoint_prefers_host_aware_metadata_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_endpoint = self._proxy_endpoint("proxy.control.example", host_id) + source_host = Host(source_endpoint, SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + endpoint = cluster._get_control_connection_host_endpoint( + source_host, DefaultEndPoint("127.254.254.101")) + + assert endpoint == source_endpoint + + def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verification_mismatches(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(uuid.uuid4(),)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + 
+ assert endpoint == connection_endpoint + assert connection_factory.call_count == 1 + assert verification_connection.close.call_count == 1 + + def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verification_fails(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.side_effect = RuntimeError("verification failed") + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + + assert endpoint == connection_endpoint + assert connection_factory.call_count == 1 + assert verification_connection.close.call_count == 1 + + def test_analytics_master_lookup_keeps_explicit_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + target_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + target_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = Session(cluster, [target_host]) + + master_future = Mock() + master_future.result.return_value = [({'location': '127.0.0.99:8182'},)] + + query_future = Mock() + query_future._host = target_host + query_future.query = SimpleStatement("g.V()") + query_future._load_balancer = Mock() + query_future.send_request = Mock() + query_future.query_plan = iter(()) + + with patch.object(session, "submit", + side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))): + session._on_analytics_master_result(None, master_future, 
query_future) + + assert list(query_future.query_plan) == [target_host] + query_future._load_balancer.make_query_plan.assert_not_called() + query_future.send_request.assert_called_once_with() + + @mock_session_pools + def test_session_preserves_down_event_discounting_after_endpoint_update(self, *_): + class _DeterministicHashEndPoint(EndPoint): + def __init__(self, address, hash_value, port=9042): + self._address = address + self._hash_value = hash_value + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, _DeterministicHashEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return self._hash_value + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + + session = Session(cluster, [host]) + cluster.sessions.add(session) + + pool = Mock() + pool.get_state.return_value = {"open_count": 1} + + host.endpoint = _DeterministicHashEndPoint("127.0.0.1", 1) + session._pools = {host: pool} + host.endpoint = _DeterministicHashEndPoint("127.0.0.2", 2) + + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(host, is_host_addition=False) + + assert host.is_up is True + cluster.on_down_potentially_blocking.assert_not_called() + assert session.get_pool_state_for_host(host) == {"open_count": 1} + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): @@ -595,6 +943,26 @@ def test_no_profiles_same_name(self): with pytest.raises(ValueError): cluster.add_execution_profile('two', ExecutionProfile()) + def 
test_add_execution_profile_seeds_current_control_host(self): + cluster = Cluster(protocol_version=4) + self.addCleanup(cluster.shutdown) + + hosts = [ + Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + ] + for host in hosts: + host.set_up() + cluster.metadata.add_or_return_host(host) + + cluster.control_connection._connection = Mock(endpoint=hosts[1].endpoint) + cluster.control_connection._current_host_id = hosts[1].host_id + + profile = ExecutionProfile(load_balancing_policy=DynamicWhiteListRoundRobinPolicy()) + cluster.add_execution_profile('proxy', profile) + + assert list(profile.load_balancing_policy.make_query_plan()) == [hosts[1]] + def test_warning_on_no_lbp_with_contact_points_legacy_mode(self): """ Test that users are warned when they instantiate a Cluster object in diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..99ae4282a6 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -102,6 +102,7 @@ class MockCluster(object): down_host = None contact_points = [] is_shutdown = False + sessions = () def __init__(self): self.metadata = MockMetadata() @@ -118,6 +119,14 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, self.added_hosts.append(host) return host, True + def _update_host_endpoint(self, host, endpoint): + old_endpoint = host.endpoint + host.endpoint = endpoint + self.metadata.update_host(host, old_endpoint) + + def _get_control_connection_host_endpoint(self, host, connection_endpoint): + return connection_endpoint + def remove_host(self, host): pass @@ -301,6 +310,60 @@ def test_wait_for_schema_agreement_none_timeout(self): cc._time = self.time assert cc.wait_for_schema_agreement() + def test_on_down_reconnects_when_current_host_matches_by_host_id(self): + self.control_connection._connection.endpoint 
= DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + + self.control_connection.on_down(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + + def test_on_remove_reconnects_when_current_host_matches_by_host_id(self): + self.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + self.control_connection.refresh_node_list_and_token_map = Mock() + + self.control_connection.on_remove(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + self.control_connection.refresh_node_list_and_token_map.assert_not_called() + + def test_signal_error_marks_current_host_down_when_current_host_matches_by_host_id(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock() + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_not_called() + + def test_signal_error_reconnects_when_current_host_conviction_is_deferred(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock(return_value=False) + 
self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata @@ -319,6 +382,32 @@ def test_refresh_nodes_and_tokens(self): assert self.connection.wait_for_responses.call_count == 1 + def test_refresh_nodes_and_tokens_rebinds_current_host_to_control_endpoint(self): + proxy_endpoint = DefaultEndPoint("127.254.254.101") + self.connection.endpoint = proxy_endpoint + self.connection.original_endpoint = proxy_endpoint + self.cluster.profile_manager.on_control_connection_host = Mock() + + self.control_connection.refresh_node_list_and_token_map() + + current_host = self.cluster.metadata.get_host_by_host_id("uuid1") + assert current_host.endpoint == proxy_endpoint + assert self.cluster.metadata.get_host(proxy_endpoint) is current_host + assert self.control_connection._current_host_id == "uuid1" + self.cluster.profile_manager.on_control_connection_host.assert_called_once_with(current_host) + + def test_refresh_nodes_and_tokens_skips_intermediate_endpoint_for_current_host(self): + proxy_endpoint = DefaultEndPoint("127.254.254.101") + self.connection.endpoint = proxy_endpoint + self.connection.original_endpoint = proxy_endpoint + self.control_connection.refresh_node_list_and_token_map() + + self.cluster._update_host_endpoint = Mock(wraps=self.cluster._update_host_endpoint) + + self.control_connection.refresh_node_list_and_token_map() + + assert self.cluster._update_host_endpoint.call_args_list == [] + def test_refresh_nodes_and_tokens_with_invalid_peers(self): def refresh_and_validate_added_hosts(): self.connection.wait_for_responses = Mock(return_value=_node_meta_results( diff --git a/tests/unit/test_host_connection_pool.py 
b/tests/unit/test_host_connection_pool.py index f92bb53785..b2109565ea 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -21,7 +21,7 @@ import unittest from threading import Thread, Event, Lock -from unittest.mock import Mock, NonCallableMagicMock, MagicMock +from unittest.mock import Mock, NonCallableMagicMock, MagicMock, patch from cassandra.cluster import Session, ShardAwareOptions from cassandra.connection import Connection @@ -42,6 +42,8 @@ class _PoolTests(unittest.TestCase): def make_session(self): session = NonCallableMagicMock(spec=Session, keyspace='foobarkeyspace', _trash=[]) + session._signal_connection_failure.return_value = False + session.is_shard_aware_disabled.return_value = False return session def test_borrow_and_return(self): @@ -143,8 +145,7 @@ def test_return_defunct_connection(self): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - session.cluster.signal_connection_failure.return_value = False - host.signal_connection_failure.return_value = False + session._signal_connection_failure.return_value = False pool.return_connection(conn) # the connection should be closed a new creation scheduled @@ -165,19 +166,14 @@ def test_return_defunct_connection_on_down_host(self): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - session.cluster.signal_connection_failure.return_value = True - host.signal_connection_failure.return_value = True + session._signal_connection_failure.return_value = True pool.return_connection(conn) - # the connection should be closed a new creation scheduled + # the connection should be closed and the pool should delegate down + # handling back to the session. 
assert conn.close.call_args - if self.PoolImpl is HostConnection: - # on shard aware implementation we use submit function regardless - assert host.signal_connection_failure.call_args - assert session.submit.called - else: - assert not session.submit.called - assert session.cluster.signal_connection_failure.call_args + session._signal_connection_failure.assert_called_once_with(host, conn.last_error) + session._handle_pool_down.assert_called_once_with(host, is_host_addition=False) assert pool.is_shutdown def test_return_closed_connection(self): @@ -192,8 +188,7 @@ def test_return_closed_connection(self): pool.borrow_connection(timeout=0.01) conn.is_closed = True - session.cluster.signal_connection_failure.return_value = False - host.signal_connection_failure.return_value = False + session._signal_connection_failure.return_value = False pool.return_connection(conn) # a new creation should be scheduled @@ -231,6 +226,29 @@ class HostConnectionTests(_PoolTests): PoolImpl = HostConnection uses_single_connection = True + def test_session_level_shard_aware_disable_skips_fanout(self): + host = Mock(spec=Host, address='ip1') + host.sharding_info = None + session = self.make_session() + session.is_shard_aware_disabled.return_value = True + + connection = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100) + connection.features = ProtocolFeatures( + shard_id=0, + sharding_info=_ShardingInfo( + shard_id=0, shards_count=4, partitioner="", + sharding_algorithm="", sharding_ignore_msb=0, + shard_aware_port=19042, shard_aware_port_ssl=""), + tablets_routing_v1=False) + session.cluster.connection_factory.return_value = connection + + with patch.object(HostConnection, "_open_connections_for_all_shards") as open_shards: + pool = HostConnection(host, HostDistance.LOCAL, session) + + open_shards.assert_not_called() + assert pool.host.sharding_info is None + def test_fast_shutdown(self): class MockSession(MagicMock): is_shutdown = False 
diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..073612178b 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -27,7 +27,8 @@ from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata -from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, @@ -1421,6 +1422,42 @@ def test_hosts_with_hostname(self): assert policy.distance(host) == HostDistance.LOCAL + +class DynamicWhiteListRoundRobinPolicyTest(unittest.TestCase): + + def test_control_connection_host_updates_allowed_host(self): + hosts = [ + Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + ] + for host in hosts: + host.set_up() + + cluster = Mock() + cluster.metadata.all_hosts.return_value = hosts + + policy = DynamicWhiteListRoundRobinPolicy() + policy.populate(cluster, hosts) + + assert list(policy.make_query_plan()) == [] + assert policy.distance(hosts[0]) == HostDistance.IGNORED + + policy.on_control_connection_host(hosts[1]) + + assert list(policy.make_query_plan()) == [hosts[1]] + assert policy.distance(hosts[0]) == HostDistance.IGNORED + assert policy.distance(hosts[1]) == HostDistance.LOCAL + + policy.on_down(hosts[1]) + assert list(policy.make_query_plan()) == [] + + policy.on_up(hosts[1]) + assert list(policy.make_query_plan()) == [hosts[1]] + + policy.on_control_connection_host(hosts[2]) + assert 
list(policy.make_query_plan()) == [hosts[2]] + def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts) diff --git a/tests/unit/test_shard_aware.py b/tests/unit/test_shard_aware.py index 4b4c2c138d..35bd24676f 100644 --- a/tests/unit/test_shard_aware.py +++ b/tests/unit/test_shard_aware.py @@ -53,6 +53,9 @@ def submit(self, fn, *args, **kwargs): self.futures += [f] return f + def is_shard_aware_disabled(self): + return self.cluster.shard_aware_options.disable + def mock_connection_factory(self, *args, **kwargs): connection = MagicMock() connection.is_shutdown = False From 0a4f514939e06d9068d9e1cf24b691183fca6725 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 13:38:38 -0400 Subject: [PATCH 02/14] control-connection: refresh pools on host change --- cassandra/cluster.py | 2 +- tests/unit/test_control_connection.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5ef7053806..aa030796cd 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4073,7 +4073,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._current_host_id = local_host_id if current_host is not None else None self._cluster.profile_manager.on_control_connection_host(current_host) if connection is self._connection and self._current_host_id != previous_current_host_id: - for session in tuple(getattr(self, "sessions", ())): + for session in tuple(getattr(self._cluster, "sessions", ())): session.update_created_pools() if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 99ae4282a6..0e4b3aa182 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -396,6 +396,16 @@ def 
test_refresh_nodes_and_tokens_rebinds_current_host_to_control_endpoint(self) assert self.control_connection._current_host_id == "uuid1" self.cluster.profile_manager.on_control_connection_host.assert_called_once_with(current_host) + def test_refresh_nodes_and_tokens_updates_sessions_when_current_host_changes(self): + session = Mock() + self.cluster.sessions = (session,) + self.control_connection._current_host_id = "uuid2" + + self.control_connection.refresh_node_list_and_token_map() + + assert self.control_connection._current_host_id == "uuid1" + session.update_created_pools.assert_called_once_with() + def test_refresh_nodes_and_tokens_skips_intermediate_endpoint_for_current_host(self): proxy_endpoint = DefaultEndPoint("127.254.254.101") self.connection.endpoint = proxy_endpoint From e6451be1198dc166e263e7ffd0b822a05e671e01 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 14:50:25 -0400 Subject: [PATCH 03/14] Fix proxy host refresh regressions --- cassandra/cluster.py | 53 ++++++++++++++++++++++----- cassandra/policies.py | 18 +++++++-- tests/unit/test_cluster.py | 31 ++++++++++++++++ tests/unit/test_control_connection.py | 34 +++++++++++++++++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index aa030796cd..59229e6398 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1945,12 +1945,29 @@ def _update_host_endpoint(self, host, endpoint): host.endpoint = endpoint self.metadata.update_host(host, old_endpoint) if was_up: - host.set_up() self.profile_manager.on_up(host) + futures_lock = Lock() + futures_results = [] + futures = set() + callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback_futures = [] for session in tuple(self.sessions): - session.add_or_renew_pool(host, is_host_addition=False) - for listener in self.listeners: - listener.on_up(host) + future = session.add_or_renew_pool(host, is_host_addition=False) + if 
future is not None: + futures.add(future) + callback_futures.append(future) + + if callback_futures: + with host.lock: + host._currently_handling_node_up = True + for future in callback_futures: + future.add_done_callback(callback) + else: + host.set_up() + for listener in self.listeners: + listener.on_up(host) + for session in tuple(self.sessions): + session.update_created_pools() else: self._start_reconnector(host, is_host_addition=False) @@ -3669,6 +3686,7 @@ def __init__(self, cluster, timeout, self._reconnection_handler = None self._reconnection_lock = RLock() self._current_host_id = None + self._pending_control_connection_hosts = {} self._event_schedule_times = {} @@ -3691,8 +3709,14 @@ def _set_new_connection(self, conn): if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) + self._pending_control_connection_hosts.pop(id(old), None) old.close() + current_host_info = self._pending_control_connection_hosts.pop(id(conn), None) + if current_host_info is not None: + current_host, current_host_id = current_host_info + self._set_current_control_connection_host(current_host, current_host_id, update_sessions=False) + for session in tuple(getattr(self._cluster, "sessions", ())): session.update_created_pools() @@ -3825,6 +3849,7 @@ def _try_connect(self, endpoint): self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: + self._pending_control_connection_hosts.pop(id(connection), None) connection.close() raise @@ -3898,10 +3923,11 @@ def shutdown(self): self._connection.close() self._connection = None self._current_host_id = None + self._pending_control_connection_hosts.clear() try: self._cluster.profile_manager.on_control_connection_host(None) except ReferenceError: - pass + pass # our weak reference to the Cluster is no good def refresh_schema(self, force=False, **kwargs): try: @@ 
-4069,15 +4095,22 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._maybe_rebind_control_connection_host_endpoint(current_host, connection.endpoint) current_host = self._cluster.metadata.get_host_by_host_id(local_host_id) + current_host_id = local_host_id if current_host is not None else None + if connection is self._connection: + self._set_current_control_connection_host(current_host, current_host_id) + else: + self._pending_control_connection_hosts[id(connection)] = (current_host, current_host_id) + if partitioner and should_rebuild_token_map: + log.debug("[control connection] Rebuilding token map due to topology changes") + self._cluster.metadata.rebuild_token_map(partitioner, token_map) + + def _set_current_control_connection_host(self, current_host, current_host_id, update_sessions=True): previous_current_host_id = self._current_host_id - self._current_host_id = local_host_id if current_host is not None else None + self._current_host_id = current_host_id self._cluster.profile_manager.on_control_connection_host(current_host) - if connection is self._connection and self._current_host_id != previous_current_host_id: + if update_sessions and self._current_host_id != previous_current_host_id: for session in tuple(getattr(self._cluster, "sessions", ())): session.update_created_pools() - if partitioner and should_rebuild_token_map: - log.debug("[control connection] Rebuilding token map due to topology changes") - self._cluster.metadata.rebuild_token_map(partitioner, token_map) @staticmethod def _is_valid_peer(row): diff --git a/cassandra/policies.py b/cassandra/policies.py index 268a260812..2c70397d88 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -551,7 +551,11 @@ def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) def on_control_connection_host(self, host): - return self._child_policy.on_control_connection_host(host) + on_control_connection_host = getattr( + 
self._child_policy, 'on_control_connection_host', None) + if on_control_connection_host is not None: + return on_control_connection_host(host) + return None class WhiteListRoundRobinPolicy(RoundRobinPolicy): @@ -720,7 +724,11 @@ def on_remove(self, host, *args, **kwargs): return self._child_policy.on_remove(host, *args, **kwargs) def on_control_connection_host(self, host): - return self._child_policy.on_control_connection_host(host) + on_control_connection_host = getattr( + self._child_policy, 'on_control_connection_host', None) + if on_control_connection_host is not None: + return on_control_connection_host(host) + return None @property def predicate(self): @@ -1391,7 +1399,11 @@ def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) def on_control_connection_host(self, host): - return self._child_policy.on_control_connection_host(host) + on_control_connection_host = getattr( + self._child_policy, 'on_control_connection_host', None) + if on_control_connection_host is not None: + return on_control_connection_host(host) + return None class DefaultLoadBalancingPolicy(WrapperPolicy): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 16e28322d6..7c663153ff 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -16,6 +16,7 @@ import logging import socket from concurrent.futures import Future +from functools import total_ordering from unittest.mock import patch, Mock import uuid @@ -36,6 +37,7 @@ log = logging.getLogger(__name__) +@total_ordering class _HostAwareProxyEndPoint(EndPoint): def __init__(self, address, affinity_key, port=9042): self._address = address @@ -358,6 +360,9 @@ def test_update_host_endpoint_notifies_listeners_for_live_host(self): cluster.metadata.add_or_return_host(host) session = Mock() + add_future = Future() + add_future.set_result(True) + session.add_or_renew_pool.return_value = add_future cluster.sessions.add(session) listener = _RecordingHostStateListener() @@ 
-370,6 +375,31 @@ def test_update_host_endpoint_notifies_listeners_for_live_host(self): session.remove_pool.assert_called_once_with(host) session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + def test_update_host_endpoint_restarts_reconnector_when_replacement_pool_fails(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + add_future = Future() + add_future.set_result(False) + + session = Mock() + session.add_or_renew_pool.return_value = add_future + cluster.sessions.add(session) + cluster._start_reconnector = Mock() + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert host.endpoint == new_endpoint + assert cluster.metadata.get_host(new_endpoint) is host + assert host.is_up is False + session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + cluster._start_reconnector.assert_called_once_with(host, is_host_addition=False) + def test_is_shard_aware_ignores_non_shard_aware_pools(self): cluster = Cluster(contact_points=['127.0.0.1']) @@ -636,6 +666,7 @@ def test_analytics_master_lookup_keeps_explicit_host(self): @mock_session_pools def test_session_preserves_down_event_discounting_after_endpoint_update(self, *_): + @total_ordering class _DeterministicHashEndPoint(EndPoint): def __init__(self, address, hash_value, port=9042): self._address = address diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 0e4b3aa182..667fe88f07 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -23,6 +23,7 @@ from cassandra.pool import Host from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory from cassandra.policies import 
(SimpleConvictionPolicy, RoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) PEER_IP = "foobar" @@ -406,6 +407,39 @@ def test_refresh_nodes_and_tokens_updates_sessions_when_current_host_changes(sel assert self.control_connection._current_host_id == "uuid1" session.update_created_pools.assert_called_once_with() + def test_candidate_refresh_keeps_dynamic_whitelist_on_active_connection_until_adopted(self): + policy = DynamicWhiteListRoundRobinPolicy() + self.cluster.profile_manager = ProfileManager() + self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(policy) + policy.populate(self.cluster, self.cluster.metadata.all_hosts()) + + active_host = self.cluster.metadata.get_host_by_host_id("uuid2") + policy.on_control_connection_host(active_host) + + active_connection = MockConnection() + active_connection.endpoint = active_host.endpoint + active_connection.original_endpoint = active_host.endpoint + active_connection.close = Mock() + self.control_connection._connection = active_connection + self.control_connection._current_host_id = active_host.host_id + + session = Mock() + self.cluster.sessions = (session,) + + candidate_connection = MockConnection() + self.control_connection._refresh_node_list_and_token_map(candidate_connection) + + assert self.control_connection._current_host_id == active_host.host_id + assert list(policy.make_query_plan()) == [active_host] + session.update_created_pools.assert_not_called() + + self.control_connection._set_new_connection(candidate_connection) + + candidate_host = self.cluster.metadata.get_host_by_host_id("uuid1") + assert self.control_connection._current_host_id == candidate_host.host_id + assert list(policy.make_query_plan()) == [candidate_host] + session.update_created_pools.assert_called_once_with() + def test_refresh_nodes_and_tokens_skips_intermediate_endpoint_for_current_host(self): proxy_endpoint = DefaultEndPoint("127.254.254.101") 
self.connection.endpoint = proxy_endpoint From 867b29e307240b2318656d0f4cd6aa0e6d8edcde Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 4 May 2026 16:08:44 -0400 Subject: [PATCH 04/14] control-connection: mark reconnected host up --- cassandra/cluster.py | 5 +++ tests/unit/test_control_connection.py | 63 ++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 59229e6398..2f601f2ca3 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4108,6 +4108,11 @@ def _set_current_control_connection_host(self, current_host, current_host_id, up previous_current_host_id = self._current_host_id self._current_host_id = current_host_id self._cluster.profile_manager.on_control_connection_host(current_host) + if current_host is not None: + if current_host.is_up is None: + current_host.set_up() + elif current_host.is_up is False: + self._cluster.on_up(current_host) if update_sessions and self._current_host_id != previous_current_host_id: for session in tuple(getattr(self._cluster, "sessions", ())): session.update_created_pools() diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 667fe88f07..b79640d1de 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,11 +15,11 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call, patch from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host from cassandra.connection import EndPoint, 
DefaultEndPoint, DefaultEndPointFactory from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, @@ -440,6 +440,65 @@ def test_candidate_refresh_keeps_dynamic_whitelist_on_active_connection_until_ad assert list(policy.make_query_plan()) == [candidate_host] session.update_created_pools.assert_called_once_with() + def test_initial_dynamic_whitelist_control_host_down_event_is_handled(self): + policy = DynamicWhiteListRoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + self.addCleanup(cluster.shutdown) + + connection = MockConnection() + connection.endpoint = DefaultEndPoint("127.254.254.101") + connection.original_endpoint = connection.endpoint + connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=connection.endpoint): + cluster.control_connection._refresh_node_list_and_token_map(connection) + cluster.control_connection._set_new_connection(connection) + cluster._populate_hosts() + + current_host = cluster.metadata.get_host_by_host_id("uuid1") + assert current_host.is_up is True + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(current_host, is_host_addition=False) + + cluster.on_down_potentially_blocking.assert_called_once_with( + current_host, False) + + def test_dynamic_whitelist_reconnected_control_host_is_marked_up(self): + policy = DynamicWhiteListRoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + self.addCleanup(cluster.shutdown) + + connection = MockConnection() + connection.endpoint = DefaultEndPoint("127.254.254.101") + connection.original_endpoint = connection.endpoint + connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=connection.endpoint): + cluster.control_connection._refresh_node_list_and_token_map(connection) + cluster.control_connection._set_new_connection(connection) + cluster._populate_hosts() + + current_host = 
cluster.metadata.get_host_by_host_id("uuid1") + current_host.set_down() + policy.on_down(current_host) + assert list(policy.make_query_plan()) == [] + + new_connection = MockConnection() + new_connection.endpoint = connection.endpoint + new_connection.original_endpoint = new_connection.endpoint + new_connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=new_connection.endpoint): + cluster.control_connection._refresh_node_list_and_token_map(new_connection) + cluster.control_connection._set_new_connection(new_connection) + + assert current_host.is_up is True + assert list(policy.make_query_plan()) == [current_host] + def test_refresh_nodes_and_tokens_skips_intermediate_endpoint_for_current_host(self): proxy_endpoint = DefaultEndPoint("127.254.254.101") self.connection.endpoint = proxy_endpoint From 7a74a2bfc34489163151c5cf6c57ec5761440747 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 06:35:10 -0400 Subject: [PATCH 05/14] control-connection: preserve proxy endpoint ownership --- cassandra/cluster.py | 80 +++++++++++++++-- tests/unit/test_cluster.py | 41 ++++++++- tests/unit/test_control_connection.py | 118 +++++++++++++++++++++++++- 3 files changed, 228 insertions(+), 11 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 2f601f2ca3..f3abca2930 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -550,7 +550,10 @@ def on_remove(self, host): def on_control_connection_host(self, host): for p in self.profiles.values(): - p.load_balancing_policy.on_control_connection_host(host) + on_control_connection_host = getattr( + p.load_balancing_policy, 'on_control_connection_host', None) + if on_control_connection_host is not None: + on_control_connection_host(host) @property def default(self): @@ -1677,7 +1680,10 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5): self.profile_manager.profiles[name] = profile 
profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) - profile.load_balancing_policy.on_control_connection_host(self.get_control_connection_host()) + on_control_connection_host = getattr( + profile.load_balancing_policy, 'on_control_connection_host', None) + if on_control_connection_host is not None: + on_control_connection_host(self.get_control_connection_host()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) @@ -3714,8 +3720,14 @@ def _set_new_connection(self, conn): current_host_info = self._pending_control_connection_hosts.pop(id(conn), None) if current_host_info is not None: - current_host, current_host_id = current_host_info + current_host, current_host_id, refreshed_endpoints = current_host_info + displaced_host_info = self._stage_displaced_control_host_endpoint( + current_host_id, refreshed_endpoints) + if current_host is not None: + self._maybe_rebind_control_connection_host_endpoint(current_host, conn.endpoint) + current_host = self._cluster.metadata.get_host_by_host_id(current_host_id) self._set_current_control_connection_host(current_host, current_host_id, update_sessions=False) + self._finish_displaced_control_host_endpoint(displaced_host_info) for session in tuple(getattr(self._cluster, "sessions", ())): session.update_created_pools() @@ -4006,8 +4018,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids = set() found_endpoints = set() + refreshed_endpoints = {} local_host_id = None + refreshing_active_connection = connection is self._connection if local_result.parsed_rows: local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] @@ -4041,6 +4055,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids.add(host_id) found_endpoints.add(endpoint) + 
refreshed_endpoints[host_id] = endpoint host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") @@ -4050,9 +4065,14 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host: target_endpoint = endpoint if host_id == local_host_id: - target_endpoint = self._cluster._get_control_connection_host_endpoint(host, connection.endpoint) - if target_endpoint is None: - target_endpoint = endpoint + if refreshing_active_connection: + target_endpoint = self._cluster._get_control_connection_host_endpoint(host, connection.endpoint) + if target_endpoint is None: + target_endpoint = endpoint + else: + target_endpoint = host.endpoint + elif not refreshing_active_connection and host_id == self._current_host_id: + target_endpoint = host.endpoint if host.endpoint != target_endpoint: log.debug("[control connection] Updating host endpoint from %s to %s for (%s)", @@ -4091,7 +4111,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, current_host = None if local_host_id in found_host_ids: current_host = self._cluster.metadata.get_host_by_host_id(local_host_id) - if current_host is not None: + if current_host is not None and refreshing_active_connection: self._maybe_rebind_control_connection_host_endpoint(current_host, connection.endpoint) current_host = self._cluster.metadata.get_host_by_host_id(local_host_id) @@ -4099,11 +4119,55 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if connection is self._connection: self._set_current_control_connection_host(current_host, current_host_id) else: - self._pending_control_connection_hosts[id(connection)] = (current_host, current_host_id) + self._pending_control_connection_hosts[id(connection)] = ( + current_host, current_host_id, refreshed_endpoints) if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") 
self._cluster.metadata.rebuild_token_map(partitioner, token_map) + def _stage_displaced_control_host_endpoint(self, current_host_id, refreshed_endpoints): + if current_host_id is None: + return None + + previous_host_id = self._current_host_id + if previous_host_id is None or previous_host_id == current_host_id: + return None + + previous_host = self._cluster.metadata.get_host_by_host_id(previous_host_id) + previous_endpoint = refreshed_endpoints.get(previous_host_id) + if previous_host is None or previous_endpoint is None or previous_host.endpoint == previous_endpoint: + return None + + was_up = previous_host.is_up + reconnector = previous_host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + + if was_up: + previous_host.set_down() + self._cluster.profile_manager.on_down(previous_host) + for session in tuple(getattr(self._cluster, "sessions", ())): + session.remove_pool(previous_host) + if was_up: + for listener in self._cluster.listeners: + listener.on_down(previous_host) + + old_endpoint = previous_host.endpoint + previous_host.endpoint = previous_endpoint + self._cluster.metadata.update_host(previous_host, old_endpoint) + return previous_host, was_up + + def _finish_displaced_control_host_endpoint(self, displaced_host_info): + if displaced_host_info is None: + return + + displaced_host, was_up = displaced_host_info + if was_up: + self._cluster.profile_manager.on_up(displaced_host) + displaced_host.set_up() + for listener in self._cluster.listeners: + listener.on_up(displaced_host) + def _set_current_control_connection_host(self, current_host, current_host_id, update_sessions=True): previous_current_host_id = self._current_host_id self._current_host_id = current_host_id diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 7c663153ff..aa04194f29 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -27,7 +27,7 @@ from cassandra.connection import DefaultEndPoint, EndPoint from 
cassandra.pool import Host from cassandra.policies import RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, \ - DynamicWhiteListRoundRobinPolicy, HostStateListener, SimpleConvictionPolicy + DynamicWhiteListRoundRobinPolicy, HostDistance, HostStateListener, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools from tests import connection_class @@ -88,6 +88,33 @@ def on_remove(self, host): self.events.append(("remove", host.address)) +class _DuckTypedLoadBalancingPolicy(object): + + def populate(self, cluster, hosts): + self.hosts = tuple(hosts) + + def check_supported(self): + pass + + def distance(self, host): + return HostDistance.IGNORED + + def make_query_plan(self, working_keyspace=None, query=None): + return iter(getattr(self, "hosts", ())) + + def on_up(self, host): + pass + + def on_down(self, host): + pass + + def on_add(self, host): + pass + + def on_remove(self, host): + pass + + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -263,6 +290,16 @@ def test_compression_type_validation(self): with pytest.raises(TypeError): Cluster(compression=123) + def test_shutdown_before_connect_tolerates_policy_without_control_host_hook(self): + policy = _DuckTypedLoadBalancingPolicy() + cluster = Cluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=policy) + }, + protocol_version=4) + + cluster.shutdown() + def test_connection_factory_passes_compression_kwarg(self): endpoint = Mock(address='127.0.0.1') scenarios = [ @@ -642,7 +679,7 @@ def test_analytics_master_lookup_keeps_explicit_host(self): target_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) target_host.set_up() - with patch.object(Session, "_add_or_renew_pool_for_distance", + with patch.object(Session, "add_or_renew_pool", return_value=self._completed_future(True)): session = 
Session(cluster, [target_host]) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index b79640d1de..71de383b4e 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -14,7 +14,7 @@ import unittest -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from unittest.mock import Mock, ANY, call, patch from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType @@ -440,6 +440,122 @@ def test_candidate_refresh_keeps_dynamic_whitelist_on_active_connection_until_ad assert list(policy.make_query_plan()) == [candidate_host] session.update_created_pools.assert_called_once_with() + def test_candidate_refresh_does_not_renew_pools_until_connection_is_adopted(self): + policy = DynamicWhiteListRoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + self.addCleanup(cluster.shutdown) + + active_connection = MockConnection() + active_connection.endpoint = DefaultEndPoint("127.254.254.101") + active_connection.original_endpoint = active_connection.endpoint + active_connection.close = Mock() + cluster.control_connection._connection = active_connection + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=active_connection.endpoint): + cluster.control_connection._refresh_node_list_and_token_map(active_connection) + cluster._populate_hosts() + + active_host = cluster.metadata.get_host_by_host_id("uuid1") + assert active_host.endpoint == active_connection.endpoint + assert list(policy.make_query_plan()) == [active_host] + + completed = Future() + completed.set_result(True) + session = Mock() + session.remove_pool.return_value = None + session.add_or_renew_pool.return_value = completed + session.update_created_pools.return_value = set() + cluster.sessions.add(session) + + candidate_connection = MockConnection() + candidate_connection.endpoint = 
DefaultEndPoint("127.254.254.102") + candidate_connection.original_endpoint = candidate_connection.endpoint + candidate_connection.local_results = [ + ["rpc_address", "schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["192.168.1.1", "a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["1", "101", "201"], "uuid2"]] + ] + candidate_connection.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.0", "10.0.0.0", "a", "dc1", "rack1", ["0", "100", "200"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] + ] + candidate_connection.wait_for_responses = Mock( + return_value=_node_meta_results(candidate_connection.local_results, + candidate_connection.peer_results)) + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=candidate_connection.endpoint): + cluster.control_connection._refresh_node_list_and_token_map(candidate_connection) + + assert cluster.control_connection._current_host_id == "uuid1" + assert list(policy.make_query_plan()) == [active_host] + session.remove_pool.assert_not_called() + session.add_or_renew_pool.assert_not_called() + session.update_created_pools.assert_not_called() + + def test_adopting_candidate_on_same_proxy_endpoint_restores_previous_host_endpoint(self): + policy = DynamicWhiteListRoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + self.addCleanup(cluster.shutdown) + + proxy_endpoint = DefaultEndPoint("127.254.254.101") + + active_connection = MockConnection() + active_connection.endpoint = proxy_endpoint + active_connection.original_endpoint = proxy_endpoint + active_connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(active_connection) + 
cluster.control_connection._set_new_connection(active_connection) + cluster._populate_hosts() + + previous_host = cluster.metadata.get_host_by_host_id("uuid1") + assert previous_host.endpoint == proxy_endpoint + + removed_hosts = [] + session = Mock() + session.remove_pool.side_effect = lambda host: removed_hosts.append( + (host.host_id, host.endpoint)) + session.update_created_pools.return_value = set() + cluster.sessions.add(session) + + candidate_connection = MockConnection() + candidate_connection.endpoint = proxy_endpoint + candidate_connection.original_endpoint = proxy_endpoint + candidate_connection.close = Mock() + candidate_connection.local_results = [ + ["rpc_address", "schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["192.168.1.1", "a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["1", "101", "201"], "uuid2"]] + ] + candidate_connection.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.0", "10.0.0.0", "a", "dc1", "rack1", ["0", "100", "200"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] + ] + candidate_connection.wait_for_responses = Mock( + return_value=_node_meta_results(candidate_connection.local_results, + candidate_connection.peer_results)) + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(candidate_connection) + cluster.control_connection._set_new_connection(candidate_connection) + + current_host = cluster.metadata.get_host_by_host_id("uuid2") + previous_host = cluster.metadata.get_host_by_host_id("uuid1") + + assert current_host.endpoint == proxy_endpoint + assert previous_host.endpoint == DefaultEndPoint("192.168.1.0") + assert cluster.metadata.get_host(proxy_endpoint) is current_host + assert removed_hosts[:2] == [ + ("uuid1", 
proxy_endpoint), + ("uuid2", DefaultEndPoint("192.168.1.1")), + ] + session.update_created_pools.assert_called_once_with() + def test_initial_dynamic_whitelist_control_host_down_event_is_handled(self): policy = DynamicWhiteListRoundRobinPolicy() cluster = Cluster(load_balancing_policy=policy, protocol_version=4) From 3dc0028a1b7c2fe870a98022b9646043799e6a8e Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 07:12:53 -0400 Subject: [PATCH 06/14] control-connection: honor endpoint verification mismatch --- cassandra/cluster.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index f3abca2930..a38b506f0e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1879,6 +1879,8 @@ def _default_control_connection_endpoint_targets_host( self, host, endpoint, attempts=3): for _ in range(attempts): connected_host_id = self._get_host_id_for_endpoint(endpoint) + if connected_host_id is None: + return None if connected_host_id != host.host_id: return False return True @@ -1915,9 +1917,14 @@ def _get_control_connection_host_endpoint(self, control_host, connection_endpoin if host_endpoint is not None and not isinstance(host_endpoint, DefaultEndPoint): return host_endpoint - if connection_endpoint is not None and self._default_control_connection_endpoint_targets_host( - control_host, connection_endpoint): + targets_host = None + if connection_endpoint is not None: + targets_host = self._default_control_connection_endpoint_targets_host( + control_host, connection_endpoint) + if targets_host: return connection_endpoint + if targets_host is False and host_endpoint is not None: + return host_endpoint if connection_endpoint is not None: return connection_endpoint From e714b4ab290c8a3e78ae51fdb4a84613549963ac Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 08:47:35 -0400 Subject: [PATCH 07/14] client-routes: preserve proxy route state --- cassandra/client_routes.py | 
2 +- cassandra/cluster.py | 11 ++++--- cassandra/connection.py | 13 ++++---- tests/unit/test_client_routes.py | 47 +++++++++++++++++++++++++++ tests/unit/test_cluster.py | 4 +-- tests/unit/test_control_connection.py | 27 ++++++++++++++- 6 files changed, 89 insertions(+), 15 deletions(-) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index 80b2477a6d..d8cea8f9ba 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float, return routes = self._query_routes_for_change_event(connection, timeout, pairs) - self._routes.merge(routes, affected_host_ids=set(host_uuids)) + self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs}) def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, connection_ids: Set[str]) -> List[_Route]: diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a38b506f0e..2213a956c9 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2146,7 +2146,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): Intended for internal use only. 
""" if self.is_shutdown: - return + return False with host.lock: was_up = host.is_up @@ -2160,14 +2160,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if pool_state: connected |= pool_state['open_count'] > 0 if connected: - return + return False host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): - return + return False log.warning("Host %s has been marked down", host) self.on_down_potentially_blocking(host, is_host_addition) + return True def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2259,8 +2260,8 @@ def on_remove(self, host): def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: - self.on_down(host, is_host_addition, expect_host_to_be_down) - return is_down + return self.on_down(host, is_host_addition, expect_host_to_be_down) + return False def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): """ diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..d81efad4e7 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -468,21 +468,22 @@ def resolve(self) -> Tuple[str, int]: def __eq__(self, other): return (isinstance(other, ClientRoutesEndPoint) and self._host_id == other._host_id and - self._original_address == other._original_address) + self._original_address == other._original_address and + self._original_port == other._original_port) def __hash__(self): - return hash((self._host_id, self._original_address)) + return hash((self._host_id, self._original_address, self._original_port)) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return ((self._host_id, self._original_address, self._original_port) < + (other._host_id, other._original_address, other._original_port)) def 
__str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) def __repr__(self): - return "<%s: host_id=%s, original_addr=%s>" % ( - self.__class__.__name__, self._host_id, self._original_address) + return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address, self._original_port) class _Frame(object): diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 0aa82fc76a..953c01cca6 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -233,6 +233,35 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query): self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query): + """Routes for unrelated connection_ids in mixed events should not be removed.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + conn_id = str(self.conn_id) + changed_host = uuid.uuid4() + unrelated_host = uuid.uuid4() + + handler._routes.update([ + _Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042), + _Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id, "unrelated-conn-id"], + host_ids=[str(changed_host), str(unrelated_host)], + ) + + self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com") + self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com") + @patch.object(_ClientRoutesHandler, 
'_query_all_routes_for_connections') def test_handle_change_updates_when_no_host_ids(self, mock_query): """When no host_ids are provided, routes should be fully replaced.""" @@ -388,6 +417,24 @@ def test_resolve_host_missing_port_raises(self): with self.assertRaises(ValueError): self.handler.resolve_host(host_id) + def test_endpoint_identity_includes_original_port(self): + host_id = uuid.uuid4() + first = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + second = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9142, + ) + + self.assertNotEqual(first, second) + self.assertEqual(len({first, second}), 2) + class TestClientRoutesEndPointFactory(unittest.TestCase): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index aa04194f29..3c0959e8be 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -632,7 +632,7 @@ def test_get_control_connection_host_endpoint_prefers_host_aware_metadata_endpoi assert endpoint == source_endpoint - def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verification_mismatches(self): + def test_get_control_connection_host_endpoint_uses_metadata_endpoint_when_verification_mismatches(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) self.addCleanup(cluster.shutdown) @@ -649,7 +649,7 @@ def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verifi return_value=verification_connection) as connection_factory: endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) - assert endpoint == connection_endpoint + assert endpoint == source_host.endpoint assert connection_factory.call_count == 1 assert verification_connection.close.call_count == 1 diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 71de383b4e..801517a015 100644 
--- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import uuid from concurrent.futures import Future, ThreadPoolExecutor from unittest.mock import Mock, ANY, call, patch @@ -21,7 +22,7 @@ from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionException from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, DynamicWhiteListRoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -365,6 +366,30 @@ def test_signal_error_reconnects_when_current_host_conviction_is_deferred(self): host, error, is_host_addition=False) self.control_connection.reconnect.assert_called_once_with() + def test_signal_error_reconnects_when_host_down_signal_is_discounted(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + session = Mock() + session.get_pool_state_for_host.return_value = {"open_count": 1} + cluster.sessions.add(session) + + connection_error = ConnectionException("control connection failed", endpoint=host.endpoint) + cluster.control_connection._connection = Mock( + endpoint=host.endpoint, + is_defunct=True, + last_error=connection_error) + cluster.control_connection.reconnect = Mock() + + cluster.control_connection._signal_error() + + assert host.is_up is True + cluster.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): 
self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata From 0ef854783c0b7f30c0ef4f16b487595b51f67596 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 09:37:16 -0400 Subject: [PATCH 08/14] control-connection: restart displaced host reconnector --- cassandra/cluster.py | 2 ++ tests/unit/test_control_connection.py | 51 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 2213a956c9..943f8f0503 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4175,6 +4175,8 @@ def _finish_displaced_control_host_endpoint(self, displaced_host_info): displaced_host.set_up() for listener in self._cluster.listeners: listener.on_up(displaced_host) + else: + self._cluster._start_reconnector(displaced_host, is_host_addition=False) def _set_current_control_connection_host(self, current_host, current_host_id, update_sessions=True): previous_current_host_id = self._current_host_id diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 801517a015..dea4c24f23 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -581,6 +581,57 @@ def test_adopting_candidate_on_same_proxy_endpoint_restores_previous_host_endpoi ] session.update_created_pools.assert_called_once_with() + def test_adopting_candidate_restarts_displaced_down_host_reconnector(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + proxy_endpoint = DefaultEndPoint("127.254.254.101") + + active_connection = MockConnection() + active_connection.endpoint = proxy_endpoint + active_connection.original_endpoint = proxy_endpoint + active_connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(active_connection) + 
cluster.control_connection._set_new_connection(active_connection) + cluster._populate_hosts() + + previous_host = cluster.metadata.get_host_by_host_id("uuid1") + previous_host.set_down() + previous_reconnector = Mock() + previous_host.get_and_set_reconnection_handler(previous_reconnector) + + cluster._start_reconnector = Mock() + + candidate_connection = MockConnection() + candidate_connection.endpoint = proxy_endpoint + candidate_connection.original_endpoint = proxy_endpoint + candidate_connection.close = Mock() + candidate_connection.local_results = [ + ["rpc_address", "schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["192.168.1.1", "a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["1", "101", "201"], "uuid2"]] + ] + candidate_connection.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.0", "10.0.0.0", "a", "dc1", "rack1", ["0", "100", "200"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] + ] + candidate_connection.wait_for_responses = Mock( + return_value=_node_meta_results(candidate_connection.local_results, + candidate_connection.peer_results)) + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(candidate_connection) + cluster.control_connection._set_new_connection(candidate_connection) + + assert previous_host.endpoint == DefaultEndPoint("192.168.1.0") + previous_reconnector.cancel.assert_called_once_with() + cluster._start_reconnector.assert_called_once_with( + previous_host, is_host_addition=False) + def test_initial_dynamic_whitelist_control_host_down_event_is_handled(self): policy = DynamicWhiteListRoundRobinPolicy() cluster = Cluster(load_balancing_policy=policy, protocol_version=4) From 156b9571465b809c13992272176548b4e8935f6d Mon Sep 17 
00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 11:02:54 -0400 Subject: [PATCH 09/14] control-connection: handle displaced host renewal failure --- cassandra/cluster.py | 11 ++--- tests/unit/test_control_connection.py | 65 +++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 943f8f0503..6d5fd0a438 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1957,6 +1957,9 @@ def _update_host_endpoint(self, host, endpoint): old_endpoint = host.endpoint host.endpoint = endpoint self.metadata.update_host(host, old_endpoint) + self._finish_host_endpoint_update(host, was_up) + + def _finish_host_endpoint_update(self, host, was_up): if was_up: self.profile_manager.on_up(host) futures_lock = Lock() @@ -4170,13 +4173,7 @@ def _finish_displaced_control_host_endpoint(self, displaced_host_info): return displaced_host, was_up = displaced_host_info - if was_up: - self._cluster.profile_manager.on_up(displaced_host) - displaced_host.set_up() - for listener in self._cluster.listeners: - listener.on_up(displaced_host) - else: - self._cluster._start_reconnector(displaced_host, is_host_addition=False) + self._cluster._finish_host_endpoint_update(displaced_host, was_up) def _set_current_control_connection_host(self, current_host, current_host_id, update_sessions=True): previous_current_host_id = self._current_host_id diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index dea4c24f23..ff66d42bb9 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -632,6 +632,71 @@ def test_adopting_candidate_restarts_displaced_down_host_reconnector(self): cluster._start_reconnector.assert_called_once_with( previous_host, is_host_addition=False) + def test_adopting_candidate_restarts_displaced_up_host_reconnector_when_pool_renewal_fails(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), 
protocol_version=4) + self.addCleanup(cluster.shutdown) + + proxy_endpoint = DefaultEndPoint("127.254.254.101") + + active_connection = MockConnection() + active_connection.endpoint = proxy_endpoint + active_connection.original_endpoint = proxy_endpoint + active_connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(active_connection) + cluster.control_connection._set_new_connection(active_connection) + cluster._populate_hosts() + + previous_host = cluster.metadata.get_host_by_host_id("uuid1") + assert previous_host.is_up is True + + successful_pool_renewal = Future() + successful_pool_renewal.set_result(True) + failed_pool_renewal = Future() + failed_pool_renewal.set_result(False) + + def renew_pool(host, is_host_addition): + if host.host_id == "uuid1": + return failed_pool_renewal + return successful_pool_renewal + + session = Mock() + session.add_or_renew_pool.side_effect = renew_pool + session.update_created_pools.return_value = set() + cluster.sessions.add(session) + cluster._start_reconnector = Mock() + + candidate_connection = MockConnection() + candidate_connection.endpoint = proxy_endpoint + candidate_connection.original_endpoint = proxy_endpoint + candidate_connection.close = Mock() + candidate_connection.local_results = [ + ["rpc_address", "schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["192.168.1.1", "a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["1", "101", "201"], "uuid2"]] + ] + candidate_connection.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.0", "10.0.0.0", "a", "dc1", "rack1", ["0", "100", "200"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] + ] + candidate_connection.wait_for_responses = Mock( + 
return_value=_node_meta_results(candidate_connection.local_results, + candidate_connection.peer_results)) + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(candidate_connection) + cluster.control_connection._set_new_connection(candidate_connection) + + assert previous_host.endpoint == DefaultEndPoint("192.168.1.0") + assert previous_host.is_up is False + session.add_or_renew_pool.assert_any_call( + previous_host, is_host_addition=False) + cluster._start_reconnector.assert_called_once_with( + previous_host, is_host_addition=False) + def test_initial_dynamic_whitelist_control_host_down_event_is_handled(self): policy = DynamicWhiteListRoundRobinPolicy() cluster = Cluster(load_balancing_policy=policy, protocol_version=4) From 85d75b248a1659d61e8fdd0592cfdbd762214928 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 11:39:23 -0400 Subject: [PATCH 10/14] control-connection: mark adopted host up through on_up --- cassandra/cluster.py | 4 +- tests/unit/test_control_connection.py | 64 +++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 6d5fd0a438..2ea2dcb27f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4180,9 +4180,7 @@ def _set_current_control_connection_host(self, current_host, current_host_id, up self._current_host_id = current_host_id self._cluster.profile_manager.on_control_connection_host(current_host) if current_host is not None: - if current_host.is_up is None: - current_host.set_up() - elif current_host.is_up is False: + if current_host.is_up is not True: self._cluster.on_up(current_host) if update_sessions and self._current_host_id != previous_current_host_id: for session in tuple(getattr(self._cluster, "sessions", ())): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 
ff66d42bb9..76fc05694b 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -519,6 +519,70 @@ def test_candidate_refresh_does_not_renew_pools_until_connection_is_adopted(self session.add_or_renew_pool.assert_not_called() session.update_created_pools.assert_not_called() + def test_adopting_unknown_state_control_host_uses_on_up_flow(self): + policy = DynamicWhiteListRoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + self.addCleanup(cluster.shutdown) + + proxy_endpoint = DefaultEndPoint("127.254.254.101") + + active_connection = MockConnection() + active_connection.endpoint = proxy_endpoint + active_connection.original_endpoint = proxy_endpoint + active_connection.close = Mock() + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(active_connection) + cluster.control_connection._set_new_connection(active_connection) + cluster._populate_hosts() + + candidate_host = cluster.metadata.get_host_by_host_id("uuid2") + assert candidate_host.is_up is None + + completed = Future() + completed.set_result(True) + + def add_pool(host, is_host_addition): + if host.host_id == "uuid2": + return completed + return None + + session = Mock() + session.add_or_renew_pool.side_effect = add_pool + session.update_created_pools.return_value = set() + cluster.sessions.add(session) + + listener = Mock() + cluster.register_listener(listener) + + candidate_connection = MockConnection() + candidate_connection.endpoint = proxy_endpoint + candidate_connection.original_endpoint = proxy_endpoint + candidate_connection.close = Mock() + candidate_connection.local_results = [ + ["rpc_address", "schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["192.168.1.1", "a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["1", "101", "201"], 
"uuid2"]] + ] + candidate_connection.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.0", "10.0.0.0", "a", "dc1", "rack1", ["0", "100", "200"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] + ] + candidate_connection.wait_for_responses = Mock( + return_value=_node_meta_results(candidate_connection.local_results, + candidate_connection.peer_results)) + + with patch.object(cluster, "_get_control_connection_host_endpoint", + return_value=proxy_endpoint): + cluster.control_connection._refresh_node_list_and_token_map(candidate_connection) + cluster.control_connection._set_new_connection(candidate_connection) + + assert candidate_host.is_up is True + session.add_or_renew_pool.assert_any_call( + candidate_host, is_host_addition=False) + listener.on_up.assert_any_call(candidate_host) + def test_adopting_candidate_on_same_proxy_endpoint_restores_previous_host_endpoint(self): policy = DynamicWhiteListRoundRobinPolicy() cluster = Cluster(load_balancing_policy=policy, protocol_version=4) From 00a97915d5007889bae2e08f6e005e584bbcb288 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 12:09:31 -0400 Subject: [PATCH 11/14] client-routes: handle endpoints without original port --- cassandra/connection.py | 7 +++++-- tests/unit/test_client_routes.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index d81efad4e7..c129bfb3a5 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -474,9 +474,12 @@ def __eq__(self, other): def __hash__(self): return hash((self._host_id, self._original_address, self._original_port)) + def _comparison_key(self): + return (self._host_id, self._original_address, + self._original_port is None, self._original_port) + def __lt__(self, other): - return ((self._host_id, self._original_address, self._original_port) < - 
(other._host_id, other._original_address, other._original_port)) + return self._comparison_key() < other._comparison_key() def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 953c01cca6..b5d2daa902 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -435,6 +435,24 @@ def test_endpoint_identity_includes_original_port(self): self.assertNotEqual(first, second) self.assertEqual(len({first, second}), 2) + def test_endpoint_ordering_handles_missing_original_port(self): + host_id = uuid.uuid4() + without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=None, + ) + with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertCountEqual( + sorted([without_port, with_port]), [without_port, with_port]) + class TestClientRoutesEndPointFactory(unittest.TestCase): From 2e7d784399eda66defc2114d915b0252372a2cc3 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 12:43:58 -0400 Subject: [PATCH 12/14] client-routes: preserve preferred proxy routes --- cassandra/client_routes.py | 20 +++++------ tests/unit/test_client_routes.py | 57 ++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index d8cea8f9ba..e447e37df2 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -322,27 +322,25 @@ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: f def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: """ - Query specific routes affected by a CLIENT_ROUTES_CHANGE event. 
+ Query current routes for hosts affected by a CLIENT_ROUTES_CHANGE event. - Takes a list of (connection_id, host_id) pairs that represent the exact - routes affected by an operation. This provides precise updates without - fetching unrelated routes. - - If the pairs list is empty or None, falls back to a complete refresh - of all routes for safety. + The in-memory route store keeps a single preferred route per host. When + any configured connection_id changes for a host, fetch all configured + connection_ids for that host so the existing preferred route can be + retained if it is still present. :param connection: Connection to execute query on :param timeout: Query timeout in seconds - :param route_pairs: List of (connection_id, host_id) tuples + :param route_pairs: List of affected (connection_id, host_id) tuples :return: List of _Route """ unique_pairs = list(dict.fromkeys(route_pairs)) - conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + conn_ids = sorted(self._connection_ids) host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) - log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " - "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + log.debug("[client routes] Querying routes from CLIENT_ROUTES_CHANGE " + "for host_ids (first 5 of %d): %s", len(host_ids), host_ids[:5]) conn_ph = ', '.join('?' for _ in conn_ids) host_ph = ', '.join('?' 
for _ in host_ids) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index b5d2daa902..bca430c628 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -262,6 +262,63 @@ def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_ self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com") self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com") + def test_handle_change_preserves_preferred_route_for_same_host(self): + conn_a = str(uuid.uuid4()) + conn_b = str(uuid.uuid4()) + host_id = uuid.uuid4() + config = ClientRoutesConfig([ + ClientRouteProxy(conn_a), + ClientRouteProxy(conn_b), + ]) + handler = _ClientRoutesHandler(config) + handler._routes.update([ + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ]) + + table_routes = [ + _Route(connection_id=conn_a, host_id=host_id, + address="changed.example.com", port=9042), + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ] + + def wait_for_response(query_msg, timeout): + conn_placeholders = query_msg.query.split( + "connection_id IN (", 1)[1].split(")", 1)[0].count("?") + conn_ids = { + param.decode("utf-8") + for param in query_msg.query_params[:conn_placeholders] + } + host_ids = { + uuid.UUID(bytes=param) + for param in query_msg.query_params[conn_placeholders:] + } + rows = [ + (route.connection_id, route.host_id, route.address, + route.port, route.port) + for route in table_routes + if route.connection_id in conn_ids and route.host_id in host_ids + ] + return Mock( + column_names=["connection_id", "host_id", "address", "port", "tls_port"], + parsed_rows=rows, + ) + + mock_conn = Mock() + mock_conn.wait_for_response.side_effect = wait_for_response + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_a], + 
host_ids=[str(host_id)], + ) + + route = handler._routes.get_by_host_id(host_id) + self.assertEqual(route.connection_id, conn_b) + self.assertEqual(route.address, "current.example.com") + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') def test_handle_change_updates_when_no_host_ids(self, mock_query): """When no host_ids are provided, routes should be fully replaced.""" From 2d636dea61cf57282fb287fc92fa33bed2b39407 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 14:20:28 -0400 Subject: [PATCH 13/14] cluster: avoid duplicate pool renewal during host up --- cassandra/cluster.py | 4 +++- .../standard/test_client_routes.py | 15 +++++++------ tests/unit/test_cluster.py | 22 +++++++++++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 2ea2dcb27f..fe58466eed 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3460,7 +3460,9 @@ def update_created_pools(self): # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. 
- if distance != HostDistance.IGNORED and host.is_up in (True, None): + if (distance != HostDistance.IGNORED and + host.is_up in (True, None) and + not getattr(host, '_currently_handling_node_up', False)): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index a9efcd65b4..fd3a40ffdd 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -618,12 +618,14 @@ class TestGetHostPortMapping(unittest.TestCase): @classmethod def setUpClass(cls): + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] + cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + proxies=[ClientRouteProxy(connection_id, "127.0.0.1") + for connection_id in cls.connection_ids])) cls.session = cls.cluster.connect() - cls.host_ids = [uuid.uuid4() for _ in range(3)] - cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] cls.expected = [] for idx, host_id in enumerate(cls.host_ids): @@ -712,8 +714,8 @@ def test_get_routes_for_change_event_all_pairs(self): self._sort_routes(expected) self.assertEqual(got, expected) - def test_get_routes_for_change_event_single_pair(self): - """Querying a single (connection_id, host_id) pair returns one route.""" + def test_get_routes_for_change_event_single_host(self): + """Querying a single changed host returns all configured routes for it.""" cc = self.cluster.control_connection target_conn_id = self.connection_ids[0] target_host_id = self.host_ids[0] @@ -723,8 +725,7 @@ def test_get_routes_for_change_event_single_pair(self): got = self._routes_to_dicts(routes) self._sort_routes(got) filtered = [r for r in self.expected - if r['connection_id'] == target_conn_id - and r['host_id'] == target_host_id] + if 
r['host_id'] == target_host_id] expected = self._expected_dicts(filtered) self._sort_routes(expected) self.assertEqual(got, expected) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 3c0959e8be..17c4748e24 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -701,6 +701,28 @@ def test_analytics_master_lookup_keeps_explicit_host(self): query_future._load_balancer.make_query_plan.assert_not_called() query_future.send_request.assert_called_once_with() + def test_update_created_pools_skips_host_with_node_up_in_progress(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.profile_manager.on_up(host) + + completed = Future() + completed.set_result(True) + + with patch.object(Session, "add_or_renew_pool", return_value=completed) as add_or_renew_pool: + session = Session(cluster, [host]) + add_or_renew_pool.reset_mock() + + session._pools = {} + host._currently_handling_node_up = True + + assert session.update_created_pools() == set() + add_or_renew_pool.assert_not_called() + @mock_session_pools def test_session_preserves_down_event_discounting_after_endpoint_update(self, *_): @total_ordering From 589a752e63a393d9118f27a618c27d14ed5450cc Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 15:39:20 -0400 Subject: [PATCH 14/14] cluster: notify listeners on up without pools --- cassandra/cluster.py | 4 ++++ tests/unit/test_cluster.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index fe58466eed..cd8069f791 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2104,6 +2104,10 @@ def on_up(self, host): with host.lock: host.set_up() host._currently_handling_node_up = 
False + for listener in self.listeners: + listener.on_up(host) + for session in tuple(self.sessions): + session.update_created_pools() # for testing purposes return futures diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 17c4748e24..d21cb1b937 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -412,6 +412,22 @@ def test_update_host_endpoint_notifies_listeners_for_live_host(self): session.remove_pool.assert_called_once_with(host) session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + def test_on_up_without_pool_futures_notifies_listeners(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_down() + cluster.metadata.add_or_return_host(host) + + listener = _RecordingHostStateListener() + cluster.register_listener(listener) + + cluster.on_up(host) + + assert host.is_up is True + assert listener.events == [("up", "127.0.0.1")] + def test_update_host_endpoint_restarts_reconnector_when_replacement_pool_fails(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) self.addCleanup(cluster.shutdown)