From 5bfa9e1b5bc2b68adb49dd9bb9edd5f4c5f174e2 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 20:49:05 -0400 Subject: [PATCH 1/2] control-connection: tolerate schema agreement connection loss --- cassandra/cluster.py | 169 ++++++++++++++++++++++++-- tests/unit/test_control_connection.py | 73 ++++++++++- 2 files changed, 228 insertions(+), 14 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..093c9cd884 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3531,6 +3531,7 @@ def __init__(self, cluster, timeout, # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) self._connection = None + self._last_connection_endpoint = None self._timeout = timeout self._schema_event_refresh_window = schema_event_refresh_window @@ -3564,6 +3565,7 @@ def _set_new_connection(self, conn): with self._lock: old = self._connection self._connection = conn + self._last_connection_endpoint = conn.endpoint if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) @@ -4094,33 +4096,28 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai if not connection: connection = self._connection + local_address = self._schema_agreement_endpoint(connection) if preloaded_results: log.debug("[control connection] Attempting to use preloaded results for schema agreement") peers_result = preloaded_results[0] local_result = preloaded_results[1] - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, local_address) if schema_mismatches is None: return True log.debug("[control connection] Waiting for schema agreement") start = self._time.time() elapsed = 0 - cl = ConsistencyLevel.ONE schema_mismatches = None - select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection) while elapsed < total_timeout: - peers_query = QueryMessage(query=maybe_add_timeout_to_query(select_peers_query, self._metadata_request_timeout), - consistency_level=cl) - local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), - consistency_level=cl) try: remaining = total_timeout - elapsed timeout = min(self._timeout, remaining) if self._timeout is not None else remaining - peers_result, local_result = connection.wait_for_responses( - peers_query, local_query, timeout=timeout) + peers_result, local_result, local_address = self._get_schema_agreement_results( + connection, timeout) except OperationTimedOut as timeout: log.debug("[control connection] Timed out waiting for " "response during schema agreement check: %s", timeout) @@ -4131,9 +4128,13 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai log.debug("[control connection] Aborting wait for schema match due to shutdown") return None else: - raise + log.debug("[control connection] Connection lost during schema agreement check") + return False + except (ConnectionBusy, ConnectionException, NoConnectionsAvailable) as exc: + log.debug("[control connection] Unable to check schema agreement: %s", exc) + return False - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, local_address) if schema_mismatches is None: return True @@ -4142,9 +4143,150 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai elapsed = self._time.time() - start log.warning("Node %s is reporting a schema disagreement: %s", - connection.endpoint, schema_mismatches) + local_address, schema_mismatches) return False + def _get_schema_agreement_results(self, connection, timeout): + if self._is_connection_unusable(connection): + return self._get_schema_agreement_results_from_session(connection, timeout) + + try: + return self._get_schema_agreement_results_from_connection(connection, timeout) + except OperationTimedOut: + raise + except (ConnectionBusy, ConnectionException): + if self._is_shutdown: + raise + self.return_connection(connection) + log.debug("[control connection] Falling back to session connection for schema agreement") + return self._get_schema_agreement_results_from_session(connection, timeout) + + def _get_schema_agreement_results_from_connection(self, connection, timeout): + peers_query, local_query = self._schema_agreement_queries(connection) + peers_result, local_result = connection.wait_for_responses(peers_query, local_query, timeout=timeout) + return peers_result, local_result, connection.endpoint + + def _get_schema_agreement_results_from_session(self, connection, timeout): + endpoint = self._schema_agreement_endpoint(connection) + deadline = None if timeout is None else time.time() + timeout + + pool, borrowed_connection, request_id = self._borrow_schema_agreement_connection(endpoint, timeout) + peers_query, _ = self._schema_agreement_queries(borrowed_connection) + local_address = borrowed_connection.endpoint + peers_result = self._wait_for_borrowed_schema_response( + pool, borrowed_connection, request_id, peers_query, self._remaining_timeout(deadline)) + + pool, borrowed_connection, request_id = self._borrow_schema_agreement_connection( + endpoint, self._remaining_timeout(deadline)) + _, local_query = self._schema_agreement_queries(borrowed_connection) + local_result = self._wait_for_borrowed_schema_response( + pool, borrowed_connection, request_id, local_query, self._remaining_timeout(deadline)) + + return peers_result, local_result, local_address + + def _schema_agreement_queries(self, connection): + cl = ConsistencyLevel.ONE + select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection) + peers_query = QueryMessage(query=maybe_add_timeout_to_query(select_peers_query, self._metadata_request_timeout), + consistency_level=cl) + local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), + consistency_level=cl) + return peers_query, local_query + + def _remaining_timeout(self, deadline): + if deadline is None: + return None + return max(0, deadline - time.time()) + + def _borrow_schema_agreement_connection(self, endpoint, timeout): + last_error = None + for pool in self._schema_agreement_pools(endpoint): + try: + return (pool,) + pool.borrow_connection(timeout=timeout) + except (NoConnectionsAvailable, ConnectionBusy, ConnectionException) as exc: + last_error = exc + log.debug("[control connection] Unable to borrow connection for schema agreement: %s", exc) + + if last_error: + raise last_error + raise NoConnectionsAvailable("No session connection available for schema agreement on %s" % (endpoint,)) + + def _schema_agreement_pools(self, endpoint): + if endpoint is None: + return + + host = self._cluster.metadata.get_host(endpoint) + for session in tuple(getattr(self._cluster, 'sessions', ())): + pools = getattr(session, '_pools', {}) + pool = pools.get(host) if host else None + if pool is None: + for pool_host, candidate_pool in tuple(pools.items()): + if getattr(pool_host, 'endpoint', None) == endpoint: + pool = candidate_pool + break + if pool and not pool.is_shutdown: + yield pool + + def _wait_for_borrowed_schema_response(self, pool, connection, request_id, message, timeout): + event = Event() + responses = [] + + def callback(response): + responses.append(response) + event.set() + + sent = False + orphaned = False + try: + connection.send_msg(message, request_id, callback) + sent = True + if not event.wait(timeout): + orphaned = self._orphan_borrowed_request(pool, connection, request_id) + raise OperationTimedOut(timeout=timeout, in_flight=getattr(connection, 'in_flight', None)) + + response = responses[0] + if isinstance(response, Exception): + if hasattr(response, 'to_exception'): + response = response.to_exception() + raise response + return response + except Exception: + if not sent: + self._return_request_id_if_unused(connection, request_id) + raise + finally: + if not orphaned: + pool.return_connection(connection) + + def _orphan_borrowed_request(self, pool, connection, request_id): + try: + connection._requests.pop(request_id) + except KeyError: + return False + + with connection.lock: + connection.orphaned_request_ids.add(request_id) + if len(connection.orphaned_request_ids) >= connection.orphaned_threshold: + connection.orphaned_threshold_reached = True + + pool.return_connection(connection, stream_was_orphaned=True) + return True + + def _return_request_id_if_unused(self, connection, request_id): + with connection.lock: + if request_id in connection._requests or request_id in connection.orphaned_request_ids: + return + if request_id not in connection.request_ids: + connection.request_ids.append(request_id) + + def _schema_agreement_endpoint(self, connection): + return getattr(connection, 'endpoint', None) or self._last_connection_endpoint + + def _is_connection_unusable(self, connection): + return (connection is None or + getattr(connection, 'is_closed', False) or + getattr(connection, 'is_defunct', False)) + def _get_schema_mismatches(self, peers_result, local_result, local_address): peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) @@ -4265,7 +4407,8 @@ def get_connections(self): return [c] if c else [] def return_connection(self, connection): - if connection is self._connection and (connection.is_defunct or connection.is_closed): + if (connection is self._connection and + (getattr(connection, 'is_defunct', False) or getattr(connection, 'is_closed', False))): self.reconnect() diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..edeff78bc5 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -14,14 +14,16 @@ import unittest +from collections import deque from concurrent.futures import ThreadPoolExecutor +from threading import Lock from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionShutdown from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -106,6 +108,7 @@ class MockCluster(object): def __init__(self): self.metadata = MockMetadata() self.added_hosts = [] + self.sessions = [] self.scheduler = Mock(spec=_Scheduler) self.executor = Mock(spec=ThreadPoolExecutor) self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(RoundRobinPolicy()) @@ -146,6 +149,7 @@ def _node_meta_results(local_results, peer_results): class MockConnection(object): is_defunct = False + is_closed = False def __init__(self): self.endpoint = DefaultEndPoint("192.168.1.0") @@ -169,6 +173,29 @@ def __init__(self): self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results)) +class MockBorrowedConnection(MockConnection): + + is_closed = False + + def __init__(self, responses): + super(MockBorrowedConnection, self).__init__() + self.responses = deque(responses) + self.lock = Lock() + self.in_flight = 0 + self._requests = {} + self.request_ids = deque() + self.orphaned_request_ids = set() + self.orphaned_threshold = 100 + self.orphaned_threshold_reached = False + self.send_msg = Mock(side_effect=self._send_msg) + + def _send_msg(self, message, request_id, callback): + self._requests[request_id] = callback + callback(self.responses.popleft()) + self._requests.pop(request_id, None) + self.request_ids.append(request_id) + + class FakeTime(object): def __init__(self): @@ -301,6 +328,50 @@ def test_wait_for_schema_agreement_none_timeout(self): cc._time = self.time assert cc.wait_for_schema_agreement() + def test_wait_for_schema_agreement_falls_back_to_session_connection(self): + """ + If the control connection is broken, use a session connection from the + same host for schema agreement. + """ + self.connection.wait_for_responses.side_effect = ConnectionShutdown("closed") + host = self.cluster.metadata.get_host(self.connection.endpoint) + borrowed_connection = MockBorrowedConnection( + _node_meta_results(self.connection.local_results, self.connection.peer_results)) + + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.side_effect = [(borrowed_connection, 1), (borrowed_connection, 2)] + session = Mock() + session._pools = {host: pool} + self.cluster.sessions = [session] + + assert self.control_connection.wait_for_schema_agreement() + assert pool.borrow_connection.call_count == 2 + assert pool.return_connection.call_count == 2 + + def test_wait_for_schema_agreement_falls_back_when_control_connection_is_gone(self): + self.control_connection._last_connection_endpoint = self.connection.endpoint + self.control_connection._connection = None + host = self.cluster.metadata.get_host(self.connection.endpoint) + borrowed_connection = MockBorrowedConnection( + _node_meta_results(self.connection.local_results, self.connection.peer_results)) + + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.side_effect = [(borrowed_connection, 1), (borrowed_connection, 2)] + session = Mock() + session._pools = {host: pool} + self.cluster.sessions = [session] + + assert self.control_connection.wait_for_schema_agreement() + assert pool.borrow_connection.call_count == 2 + assert pool.return_connection.call_count == 2 + + def test_wait_for_schema_agreement_tolerates_missing_connection(self): + self.control_connection._connection = None + + assert not self.control_connection.wait_for_schema_agreement() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata From 9e21370366f9dcd759d449335db9d42c1f42bc25 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 22:06:53 -0400 Subject: [PATCH 2/2] token-map: rebuild on zero-token transitions --- cassandra/cluster.py | 39 ++++++++++++++++++--- cassandra/policies.py | 5 +++ tests/unit/test_control_connection.py | 50 +++++++++++++++++++++++++-- 3 files changed, 88 insertions(+), 6 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 093c9cd884..036243cc17 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3919,8 +3919,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host.dse_workloads = row.get("workloads") tokens = row.get("tokens", None) - if partitioner and tokens and self._token_meta_enabled: - token_map[host] = tokens + if partitioner and self._token_meta_enabled: + should_rebuild_token_map |= self._tokens_changed(host, tokens) + if tokens: + token_map[host] = tokens self._cluster.metadata.update_host(host, old_endpoint=endpoint) for old_host_id, old_host in self._cluster.metadata.all_hosts_items(): @@ -3965,12 +3967,41 @@ def _is_valid_peer(row): if "tokens" in row and not row.get("tokens"): log.debug( - "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." % + "Found a zero-token node (broadcast_rpc: %s, host_id: %s). " + "Keeping host for load balancing, but omitting it from the token map." % (broadcast_rpc, host_id)) - return False return True + def _tokens_changed(self, host, tokens): + current_token_map = self._cluster.metadata.token_map + if current_token_map is None: + return False + + if isinstance(current_token_map, dict): + current_tokens = current_token_map.get(host) + return set(current_tokens or ()) != set(tokens or ()) + + token_to_host_owner = getattr(current_token_map, 'token_to_host_owner', None) + token_class = getattr(current_token_map, 'token_class', None) + if token_to_host_owner is None or token_class is None: + return False + + current_tokens = set( + token for token, owner in token_to_host_owner.items() + if owner == host) + if not tokens: + return bool(current_tokens) + + try: + refreshed_tokens = set(token_class.from_string(token) for token in tokens) + except Exception: + log.debug("[control connection] Unable to compare refreshed tokens for %s", + host, exc_info=True) + return True + + return current_tokens != refreshed_tokens + def _update_location_info(self, host, datacenter, rack): if host.datacenter == datacenter and host.rack == rack: return False diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..44a1c681b8 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -464,6 +464,11 @@ class TokenAwarePolicy(LoadBalancingPolicy): If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. + + Token awareness only applies to token-owning nodes. Zero-token CQL + proxy or front-layer nodes are not replicas. When those nodes are the + intended query front layer, configure a round-robin policy such as + :class:`.RoundRobinPolicy` or :class:`.DCAwareRoundRobinPolicy` directly. """ _child_policy = None diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index edeff78bc5..9f6a7ee521 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -406,7 +406,6 @@ def refresh_and_validate_added_hosts(): [None, None, "a", "dc1", "rack1", ["1", "101", "201"], 'uuid1'], ["192.168.1.7", "10.0.0.1", "a", None, "rack1", ["1", "101", "201"], 'uuid2'], ["192.168.1.6", "10.0.0.1", "a", "dc1", None, ["1", "101", "201"], 'uuid3'], - ["192.168.1.5", "10.0.0.1", "a", "dc1", "rack1", None, 'uuid4'], ["192.168.1.4", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], None]]]) refresh_and_validate_added_hosts() @@ -420,10 +419,57 @@ def refresh_and_validate_added_hosts(): [None, 9042, None, 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", None, "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", None, ["2", "102", "202"], "uuid2"], - ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", None, "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], None]]]) refresh_and_validate_added_hosts() + def test_refresh_nodes_and_tokens_keeps_zero_token_peer_for_load_balancing(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", None, "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self.cluster.metadata.get_host_by_host_id("uuid4") + assert zero_token_host is not None + assert zero_token_host in self.cluster.metadata.all_hosts() + assert zero_token_host not in self.cluster.metadata.token_map + + def test_refresh_nodes_and_tokens_keeps_zero_token_local_host_for_load_balancing(self): + self.connection.local_results[1][0][7] = None + + self.control_connection.refresh_node_list_and_token_map() + + local_host = self.cluster.metadata.get_host_by_host_id("uuid1") + assert local_host is not None + assert local_host in self.cluster.metadata.all_hosts() + assert local_host not in self.cluster.metadata.token_map + + def test_refresh_nodes_and_tokens_rebuilds_token_map_when_existing_host_loses_tokens(self): + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self.cluster.metadata.get_host_by_host_id("uuid2") + assert zero_token_host in self.cluster.metadata.token_map + + self.connection.peer_results[1][0][5] = None + self.control_connection.refresh_node_list_and_token_map() + + assert zero_token_host in self.cluster.metadata.all_hosts() + assert zero_token_host not in self.cluster.metadata.token_map + + def test_refresh_nodes_and_tokens_rebuilds_token_map_when_existing_host_gains_tokens(self): + self.connection.peer_results[1][0][5] = None + self.control_connection.refresh_node_list_and_token_map() + + token_host = self.cluster.metadata.get_host_by_host_id("uuid2") + assert token_host in self.cluster.metadata.all_hosts() + assert token_host not in self.cluster.metadata.token_map + + self.connection.peer_results[1][0][5] = ["1", "101", "201"] + self.control_connection.refresh_node_list_and_token_map() + + assert token_host in self.cluster.metadata.token_map + def test_change_ip(self): """ Tests node IPs are updated while the nodes themselves are not