diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8da9df6a55..40daee5cdb 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1916,8 +1916,9 @@ def on_up(self, host): log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: - if host._currently_handling_node_up: - log.debug("Another thread is already handling up status of node %s", host) + if (host._currently_handling_node_up or + getattr(host, "_currently_handling_node_addition", False)): + log.debug("Another thread is already handling up/add status of node %s", host) return if host.is_up: @@ -1958,8 +1959,10 @@ def on_up(self, host): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + + for future in tuple(futures): + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: @@ -2050,69 +2053,98 @@ def on_add(self, host, refresh_nodes=True): log.debug("Handling new host %r and notifying listeners", host) - self.profile_manager.on_add(host) - self.control_connection.on_add(host, refresh_nodes) + # Keep refresh-time pool rebuilds from racing this host's pool creation. + with host.lock: + if getattr(host, "_currently_handling_node_addition", False): + log.debug("Another thread is already handling add status of node %s", host) + return + host._currently_handling_node_addition = True - distance = self.profile_manager.distance(host) - if distance != HostDistance.IGNORED: - self._prepare_all_queries(host) - log.debug("Done preparing queries for new host %r", host) + have_future = False + add_aborted = False + futures = set() + try: + self.profile_manager.on_add(host) + self.control_connection.on_add(host, refresh_nodes) - if distance == HostDistance.IGNORED: - log.debug("Not adding connection pool for new host %r because the " - "load balancing policy has marked it as IGNORED", host) - self._finalize_add(host, set_up=False) - return + distance = self.profile_manager.distance(host) + if distance != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing queries for new host %r", host) - futures_lock = Lock() - futures_results = [] - futures = set() + if distance == HostDistance.IGNORED: + log.debug("Not adding connection pool for new host %r because the " + "load balancing policy has marked it as IGNORED", host) + self._finalize_add(host, set_up=False) + return - def future_completed(future): - with futures_lock: - futures.discard(future) + futures_lock = Lock() + futures_results = [] - try: - futures_results.append(future.result()) - except Exception as exc: - futures_results.append(exc) + def future_completed(future): + with futures_lock: + futures.discard(future) - if futures: - return + if add_aborted: + return - log.debug('All futures have completed for added host %s', host) + try: + futures_results.append(future.result()) + except Exception as exc: + futures_results.append(exc) - for exc in [f for f in futures_results if isinstance(f, Exception)]: - log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) - return + if futures: + return - if not all(futures_results): - log.warning("Connection pool could not be created, not marking node %s up", host) - return + log.debug('All futures have completed for added host %s', host) - self._finalize_add(host) + for exc in [f for f in futures_results if isinstance(f, Exception)]: + log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + with host.lock: + host._currently_handling_node_addition = False + return - have_future = False - for session in tuple(self.sessions): - future = session.add_or_renew_pool(host, is_host_addition=True) - if future is not None: - have_future = True - futures.add(future) + if not all(futures_results): + log.warning("Connection pool could not be created, not marking node %s up", host) + with host.lock: + host._currently_handling_node_addition = False + return + + self._finalize_add(host) + + for session in tuple(self.sessions): + future = session.add_or_renew_pool(host, is_host_addition=True) + if future is not None: + have_future = True + futures.add(future) + + for future in tuple(futures): future.add_done_callback(future_completed) - if not have_future: - self._finalize_add(host) + if not have_future: + self._finalize_add(host) + except Exception: + add_aborted = True + for future in tuple(futures): + future.cancel() + with host.lock: + host._currently_handling_node_addition = False + raise def _finalize_add(self, host, set_up=True): - if set_up: - host.set_up() + try: + if set_up: + host.set_up() - for listener in self.listeners: - listener.on_add(host) + for listener in self.listeners: + listener.on_add(host) - # see if there are any pools to add or remove now that the host is marked up - for session in tuple(self.sessions): - session.update_created_pools() + # see if there are any pools to add or remove now that the host is marked up + for session in tuple(self.sessions): + session.update_created_pools() + finally: + with host.lock: + host._currently_handling_node_addition = False def on_remove(self, host): if self.is_shutdown: @@ -2137,7 +2169,8 @@ def signal_connection_failure(self, host, connection_exc, is_host_addition, expe self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): + def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None, + is_zero_token=None): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. @@ -2147,8 +2180,16 @@ def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_no """ with self.metadata._hosts_lock: if endpoint in self.metadata._host_id_by_endpoint: - return self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]], False - host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id)) + host = self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]] + if is_zero_token is not None: + host.is_zero_token = is_zero_token + return host, False + host = Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id) + if is_zero_token is not None: + host.is_zero_token = is_zero_token + host, new = self.metadata.add_or_return_host(host) + if not new and is_zero_token is not None: + host.is_zero_token = is_zero_token if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) @@ -3315,7 +3356,10 @@ def update_created_pools(self): # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. - if distance != HostDistance.IGNORED and host.is_up in (True, None): + # on_up() and on_add() already own pool creation for hosts in flight. + if (distance != HostDistance.IGNORED and host.is_up in (True, None) and + not host._currently_handling_node_up and + not getattr(host, "_currently_handling_node_addition", False)): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed @@ -3864,6 +3908,8 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None + zero_token_status_changed = False + promoted_zero_token_hosts = [] for row in peers_result: if not self._is_valid_peer(row): continue @@ -3884,10 +3930,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") + tokens = row.get("tokens", None) + has_token_status = "tokens" in row + is_zero_token = has_token_status and not tokens + token_status = is_zero_token if has_token_status else None if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: + if has_token_status: + status_changed = self._update_zero_token_info(host, is_zero_token) + zero_token_status_changed |= status_changed + should_rebuild_token_map |= status_changed + if status_changed and not is_zero_token and host.is_up is not True: + promoted_zero_token_hosts.append(host) log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: @@ -3901,11 +3957,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) - host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) + host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, + refresh_nodes=False, host_id=host_id, + is_zero_token=token_status) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + if has_token_status: + status_changed = self._update_zero_token_info(host, is_zero_token) + zero_token_status_changed |= status_changed + should_rebuild_token_map |= status_changed + if status_changed and not is_zero_token and host.is_up is not True: + promoted_zero_token_hosts.append(host) + host.host_id = host_id host.broadcast_address = _NodeInfo.get_broadcast_address(row) host.broadcast_port = _NodeInfo.get_broadcast_port(row) @@ -3916,7 +3981,6 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host.dse_workload = row.get("workload") host.dse_workloads = row.get("workloads") - tokens = row.get("tokens", None) if partitioner and tokens and self._token_meta_enabled: token_map[host] = tokens self._cluster.metadata.update_host(host, old_endpoint=endpoint) @@ -3932,6 +3996,22 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) + for host in promoted_zero_token_hosts: + self._cluster.on_up(host) + + if zero_token_status_changed: + for session in tuple(getattr(self._cluster, "sessions", ())): + session.update_created_pools() + + @staticmethod + def _update_zero_token_info(host, is_zero_token): + is_zero_token = bool(is_zero_token) + if host.is_zero_token == is_zero_token: + return False + + host.is_zero_token = is_zero_token + return True + @staticmethod def _is_valid_peer(row): broadcast_rpc = _NodeInfo.get_broadcast_rpc_address(row) @@ -3963,9 +4043,8 @@ def _is_valid_peer(row): if "tokens" in row and not row.get("tokens"): log.debug( - "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." % + "Found a zero-token node - tokens are empty (broadcast_rpc: %s, host_id: %s). Adding host without tokens." % (broadcast_rpc, host_id)) - return False return True diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..72f0332a1f 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -121,10 +121,39 @@ class LoadBalancingPolicy(HostStateListener): """ _hosts_lock = None + _ignore_zero_token_hosts = True def __init__(self): self._hosts_lock = Lock() + def _is_ignored_zero_token_host(self, host): + if getattr(host, 'is_zero_token', False) is not True: + return False + + child_policy = getattr(self, '_child_policy', None) + if child_policy is not None: + # Preserve child opt-outs through wrapper layers. + child_filter = getattr(child_policy, '_is_ignored_zero_token_host', None) + if child_filter is not None: + return bool(child_filter(host)) + return bool(getattr(child_policy, '_ignore_zero_token_hosts', self._ignore_zero_token_hosts)) + + return bool(self._ignore_zero_token_hosts) + + def _filter_zero_token_hosts(self, hosts): + return tuple(h for h in hosts if not self._is_ignored_zero_token_host(h)) + + def _is_in_flight_host(self, host): + return ((getattr(host, '_currently_handling_node_up', False) is True or + getattr(host, '_currently_handling_node_addition', False) is True) and + getattr(host, 'is_up', None) is not True) + + def _is_ignored_query_plan_host(self, host): + return self._is_ignored_zero_token_host(host) or self._is_in_flight_host(host) + + def _filter_query_plan_hosts(self, hosts): + return tuple(h for h in hosts if not self._is_ignored_query_plan_host(h)) + def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in @@ -178,10 +207,13 @@ class RoundRobinPolicy(LoadBalancingPolicy): def populate(self, cluster, hosts): self._live_hosts = frozenset(hosts) - if len(hosts) > 1: - self._position = randint(0, len(hosts) - 1) + live_hosts = self._filter_zero_token_hosts(hosts) + if len(live_hosts) > 1: + self._position = randint(0, len(live_hosts) - 1) def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return HostDistance.LOCAL def make_query_plan(self, working_keyspace=None, query=None): @@ -190,7 +222,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - hosts = self._live_hosts + hosts = self._filter_query_plan_hosts(self._live_hosts) length = len(hosts) if length: pos %= length @@ -257,6 +289,9 @@ def populate(self, cluster, hosts): self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL @@ -264,7 +299,7 @@ def distance(self, host): if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED else: - dc_hosts = self._dc_live_hosts.get(dc) + dc_hosts = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) if not dc_hosts: return HostDistance.IGNORED @@ -279,7 +314,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_live = self._dc_live_hosts.get(self.local_dc, ()) + local_live = self._filter_query_plan_hosts(self._dc_live_hosts.get(self.local_dc, ())) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -287,7 +322,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] for dc in other_dcs: - remote_live = self._dc_live_hosts.get(dc, ()) + remote_live = self._filter_query_plan_hosts(self._dc_live_hosts.get(dc, ())) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -372,6 +407,9 @@ def populate(self, cluster, hosts): self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + rack = self._rack(host) dc = self._dc(host) if rack == self.local_rack and dc == self.local_dc: @@ -383,7 +421,7 @@ def distance(self, host): if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED - dc_hosts = self._dc_live_hosts.get(dc, ()) + dc_hosts = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) if not dc_hosts: return HostDistance.IGNORED if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc: @@ -395,14 +433,15 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + local_rack_live = self._filter_query_plan_hosts(self._live_hosts.get((self.local_dc, self.local_rack), ())) pos = (pos % len(local_rack_live)) if local_rack_live else 0 # Slice the cyclic iterator to start from pos and include the next len(local_live) elements # This ensures we get exactly one full cycle starting from pos for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): yield host - local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] + local_live = [host for host in self._filter_query_plan_hosts(self._dc_live_hosts.get(self.local_dc, ())) + if host.rack != self.local_rack] pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -410,6 +449,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy for dc, remote_live in self._dc_live_hosts.copy().items(): if dc != self.local_dc: + remote_live = self._filter_query_plan_hosts(remote_live) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -491,6 +531,9 @@ def check_supported(self): (self.__class__.__name__, self._cluster_metadata.partitioner)) def distance(self, *args, **kwargs): + host = args[0] if args else kwargs.get('host') + if host is not None and self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return self._child_policy.distance(*args, **kwargs) def make_query_plan(self, working_keyspace=None, query=None): @@ -499,7 +542,8 @@ def make_query_plan(self, working_keyspace=None, query=None): child = self._child_policy if query is None or query.routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): - yield host + if not self._is_ignored_query_plan_host(host): + yield host return replicas = [] @@ -520,13 +564,15 @@ def make_query_plan(self, working_keyspace=None, query=None): def yield_in_order(hosts): for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: for replica in hosts: - if replica.is_up and child.distance(replica) == distance: + if (not self._is_ignored_query_plan_host(replica) and + replica.is_up and child.distance(replica) == distance): yield replica # yield replicas: local_rack, local, remote yield from yield_in_order(replicas) # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) + if host not in replicas and not self._is_ignored_query_plan_host(host)]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) @@ -554,6 +600,8 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): attempts are made to private IP addresses remotely """ + _ignore_zero_token_hosts = False + def __init__(self, hosts): """ The `hosts` parameter should be a sequence of hosts to permit @@ -674,6 +722,9 @@ def distance(self, host): :attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy otherwise. """ + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + if self.predicate(host): return self._child_policy.distance(host) else: @@ -699,7 +750,7 @@ def make_query_plan(self, working_keyspace=None, query=None): working_keyspace=working_keyspace, query=query ) for host in child_qp: - if self.predicate(host): + if not self._is_ignored_query_plan_host(host) and self.predicate(host): yield host def check_supported(self): @@ -1305,6 +1356,9 @@ def __init__(self, child_policy): self._child_policy = child_policy def distance(self, *args, **kwargs): + host = args[0] if args else kwargs.get('host') + if host is not None and self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return self._child_policy.distance(*args, **kwargs) def populate(self, cluster, hosts): @@ -1347,14 +1401,15 @@ def make_query_plan(self, working_keyspace=None, query=None): target_host = self._cluster_metadata.get_host(addr) child = self._child_policy - if target_host and target_host.is_up: + if target_host and target_host.is_up and not self._is_ignored_query_plan_host(target_host): yield target_host for h in child.make_query_plan(keyspace, query): - if h != target_host: + if h != target_host and not self._is_ignored_query_plan_host(h): yield h else: for h in child.make_query_plan(keyspace, query): - yield h + if not self._is_ignored_query_plan_host(h): + yield h # TODO for backward compatibility, remove in next major diff --git a/cassandra/pool.py b/cassandra/pool.py index 2da657256f..d19844917e 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -127,6 +127,11 @@ class Host(object): up or down. """ + is_zero_token = False + """ + :const:`True` if the node has no tokens in the system topology tables. + """ + release_version = None """ release_version as queried from the control connection system tables @@ -179,6 +184,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) + self.is_zero_token = False self.lock = RLock() @property @@ -927,4 +933,3 @@ def open_count(self): def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 49208ac53e..39df2a8bc4 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest - +from concurrent.futures import Future import logging import socket +import unittest +import uuid from unittest.mock import patch, Mock -import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion @@ -101,6 +101,134 @@ def test_tuple_for_contact_points(self): assert cp.address == '127.0.0.3' assert cp.port == 9999 + def test_on_add_clears_in_progress_flag_when_later_session_add_fails(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + successful_session = Mock() + successful_session.add_or_renew_pool.return_value = Future() + successful_session.update_created_pools.return_value = set() + failing_session = Mock() + failing_session.add_or_renew_pool.side_effect = RuntimeError("pool add failed") + cluster.sessions = [successful_session, failing_session] + + try: + with pytest.raises(RuntimeError): + cluster.on_add(host, refresh_nodes=False) + + assert not host._currently_handling_node_addition + + with pytest.raises(RuntimeError): + cluster.on_add(host, refresh_nodes=False) + + assert successful_session.add_or_renew_pool.call_count == 2 + finally: + cluster.shutdown() + + def test_on_add_waits_for_all_session_pool_futures_before_marking_host_up(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + completed_future = Future() + completed_future.set_result(True) + pending_future = Future() + first_session = Mock() + first_session.add_or_renew_pool.return_value = completed_future + second_session = Mock() + second_session.add_or_renew_pool.return_value = pending_future + listener = Mock() + cluster.sessions = [first_session, second_session] + cluster.register_listener(listener) + + try: + cluster.on_add(host, refresh_nodes=False) + + assert host.is_up is not True + listener.on_add.assert_not_called() + first_session.update_created_pools.assert_not_called() + second_session.update_created_pools.assert_not_called() + + pending_future.set_result(True) + + assert host.is_up is True + listener.on_add.assert_called_once_with(host) + first_session.update_created_pools.assert_called_once_with() + second_session.update_created_pools.assert_called_once_with() + finally: + cluster.shutdown() + + def test_on_add_excludes_host_from_query_plan_until_pool_futures_complete(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, datacenter="dc1", rack="rack1", host_id=uuid.uuid4()) + pending_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pending_future + session.update_created_pools.return_value = set() + cluster.sessions = [session] + + try: + cluster.on_add(host, refresh_nodes=False) + + load_balancer = cluster.profile_manager.default.load_balancing_policy + assert host not in list(load_balancer.make_query_plan()) + + pending_future.set_result(True) + + assert list(load_balancer.make_query_plan()) == [host] + finally: + cluster.shutdown() + + def test_on_up_waits_for_all_session_pool_futures_before_marking_host_up(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + completed_future = Future() + completed_future.set_result(True) + pending_future = Future() + first_session = Mock() + first_session.add_or_renew_pool.return_value = completed_future + second_session = Mock() + second_session.add_or_renew_pool.return_value = pending_future + listener = Mock() + cluster.sessions = [first_session, second_session] + cluster.register_listener(listener) + + try: + cluster.on_up(host) + + assert host.is_up is not True + listener.on_up.assert_not_called() + first_session.update_created_pools.assert_not_called() + second_session.update_created_pools.assert_not_called() + + pending_future.set_result(True) + + assert host.is_up is True + listener.on_up.assert_called_once_with(host) + first_session.update_created_pools.assert_called_once_with() + second_session.update_created_pools.assert_called_once_with() + finally: + cluster.shutdown() + + def test_on_up_excludes_host_from_query_plan_until_pool_futures_complete(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, datacenter="dc1", rack="rack1", host_id=uuid.uuid4()) + host.set_down() + pending_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pending_future + session.update_created_pools.return_value = set() + cluster.sessions = [session] + + try: + cluster.on_up(host) + + load_balancer = cluster.profile_manager.default.load_balancing_policy + assert host not in list(load_balancer.make_query_plan()) + + pending_future.set_result(True) + + assert list(load_balancer.make_query_plan()) == [host] + finally: + cluster.shutdown() + def test_invalid_contact_point_types(self): with pytest.raises(ValueError): Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index d759e12332..c5a92138f8 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -80,8 +80,8 @@ def add_or_return_host(self, host): def update_host(self, host, old_endpoint): host, created = self.add_or_return_host(host) - self._host_id_by_endpoint[host.endpoint] = host.host_id self._host_id_by_endpoint.pop(old_endpoint, False) + self._host_id_by_endpoint[host.endpoint] = host.host_id def all_hosts_items(self): return list(self.hosts.items()) @@ -112,8 +112,10 @@ def __init__(self): self.endpoint_factory = DefaultEndPointFactory().configure(self) self.ssl_options = None - def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None): + def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None, + is_zero_token=False): host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack, host_id=host_id) + host.is_zero_token = is_zero_token host, _ = self.metadata.add_or_return_host(host) self.added_hosts.append(host) return host, True @@ -206,6 +208,16 @@ def setUp(self): self.control_connection._connection = self.connection self.control_connection._time = self.time + def _assert_zero_token_host_without_token_map_entry(self, endpoint, host_id): + zero_token_host = self.cluster.metadata.get_host(endpoint) + assert zero_token_host is not None + assert zero_token_host.host_id == host_id + assert zero_token_host.datacenter == "dc1" + assert zero_token_host.rack == "rack1" + assert zero_token_host.is_zero_token + assert zero_token_host not in self.cluster.metadata.token_map + return zero_token_host + def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing @@ -321,7 +333,6 @@ def refresh_and_validate_added_hosts(): [None, None, "a", "dc1", "rack1", ["1", "101", "201"], 'uuid1'], ["192.168.1.7", "10.0.0.1", "a", None, "rack1", ["1", "101", "201"], 'uuid2'], ["192.168.1.6", "10.0.0.1", "a", "dc1", None, ["1", "101", "201"], 'uuid3'], - ["192.168.1.5", "10.0.0.1", "a", "dc1", "rack1", None, 'uuid4'], ["192.168.1.4", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], None]]]) refresh_and_validate_added_hosts() @@ -335,7 +346,6 @@ def refresh_and_validate_added_hosts(): [None, 9042, None, 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", None, "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", None, ["2", "102", "202"], "uuid2"], - ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", None, "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], None]]]) refresh_and_validate_added_hosts() @@ -411,6 +421,61 @@ def test_refresh_nodes_and_tokens_add_host(self): assert self.cluster.added_hosts[0].rack == "rack1" assert self.cluster.added_hosts[0].host_id == "uuid4" + def test_refresh_nodes_and_tokens_adds_zero_token_host_without_token_map_entry(self): + # Zero-token nodes are valid topology members, but they do not own token ranges. + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", None, "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + assert [] == self.cluster.metadata.removed_hosts + + def test_refresh_nodes_and_tokens_adds_empty_token_host_without_token_map_entry(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", [], "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + + def test_refresh_nodes_and_tokens_keeps_zero_token_local_host_without_token_map_entry(self): + self.connection.local_results[1][0][7] = None + + self.control_connection.refresh_node_list_and_token_map() + + self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.0"), "uuid1") + assert [] == self.cluster.added_hosts + assert [] == self.cluster.metadata.removed_hosts + + def test_refresh_nodes_and_tokens_updates_zero_token_status_when_tokens_change(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", None, "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + + self.connection.peer_results[1][-1][5] = ["3", "103", "203"] + self.control_connection.refresh_node_list_and_token_map() + + assert not zero_token_host.is_zero_token + assert zero_token_host in self.cluster.metadata.token_map + assert self.cluster.metadata.token_map[zero_token_host] == ["3", "103", "203"] + def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] self.control_connection.refresh_node_list_and_token_map() @@ -589,6 +654,29 @@ def test_refresh_nodes_and_tokens_add_host_detects_port(self): assert self.cluster.added_hosts[0].datacenter == "dc1" assert self.cluster.added_hosts[0].rack == "rack1" + def test_refresh_nodes_and_tokens_adds_zero_token_host_from_peers_v2_without_token_map_entry(self): + del self.connection.peer_results[:] + self.connection.peer_results.extend(self.connection.peer_results_v2) + self.connection.peer_results[1].append( + ["192.168.1.3", 555, "10.0.0.3", 666, "a", "dc1", "rack1", None, "uuid4"] + ) + self.connection.wait_for_responses = Mock(return_value=_node_meta_results( + self.connection.local_results, self.connection.peer_results)) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3", 555), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + assert zero_token_host.endpoint.port == 555 + assert zero_token_host.broadcast_rpc_address == "192.168.1.3" + assert zero_token_host.broadcast_rpc_port == 555 + assert zero_token_host.broadcast_address == "10.0.0.3" + assert zero_token_host.broadcast_port == 666 + assert [] == self.cluster.metadata.removed_hosts + def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): del self.connection.peer_results[:] self.connection.peer_results.extend(self.connection.peer_results_v2) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..37577fbe2a 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -33,13 +33,22 @@ RetryPolicy, WriteType, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, - IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy) + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy, + DefaultLoadBalancingPolicy) from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint from cassandra.pool import Host from cassandra.query import Statement from cassandra.tablets import Tablets, Tablet +def make_host(address, datacenter="dc1", rack="rack1", is_zero_token=False): + host = Host(DefaultEndPoint(address), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_location_info(datacenter, rack) + host.set_up() + host.is_zero_token = is_zero_token + return host + + class LoadBalancingPolicyTest(unittest.TestCase): def test_non_implemented(self): """ @@ -187,6 +196,26 @@ def test_no_live_nodes(self): qplan = list(policy.make_query_plan()) assert qplan == [] + +@pytest.mark.parametrize("policy", [ + RoundRobinPolicy(), + DCAwareRoundRobinPolicy("dc1"), + RackAwareRoundRobinPolicy("dc1", "rack1"), +]) +def test_zero_token_hosts_ignored_by_round_robin_policies(policy): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + policy.populate(Mock(), [host, zero_token_host]) + + assert list(policy.make_query_plan()) == [host] + assert policy.distance(zero_token_host) == HostDistance.IGNORED + + zero_token_host.is_zero_token = False + + assert set(policy.make_query_plan()) == {host, zero_token_host} + assert policy.distance(zero_token_host) != HostDistance.IGNORED + @pytest.mark.parametrize("policy_specialization, constructor_args", [(DCAwareRoundRobinPolicy, ("dc1", )), (RackAwareRoundRobinPolicy, ("dc1", "rack1"))]) class TestRackOrDCAwareRoundRobinPolicy: @@ -850,6 +879,61 @@ def test_statement_keyspace(self): assert replicas + hosts[:2] == qplan cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + def test_ignores_zero_token_hosts(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.get_replicas.return_value = [zero_token_host, host] + + child_policy = Mock() + child_policy.make_query_plan.return_value = [zero_token_host, host] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, [host, zero_token_host]) + + query = Statement(routing_key=b"routing_key", keyspace="keyspace_name") + + assert list(policy.make_query_plan(None, query)) == [host] + + def test_ignores_zero_token_hosts_with_legacy_child_policy(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + class LegacyChildPolicy(object): + def populate(self, cluster, hosts): + pass + + def distance(self, host): + return HostDistance.LOCAL + + def make_query_plan(self, working_keyspace=None, query=None): + return [zero_token_host, host] + + def on_up(self, host): + pass + + def on_down(self, host): + pass + + def on_add(self, host): + pass + + def on_remove(self, host): + pass + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + + policy = TokenAwarePolicy(LegacyChildPolicy()) + policy.populate(cluster, [host, zero_token_host]) + + assert list(policy.make_query_plan()) == [host] + def test_shuffles_if_given_keyspace_and_routing_key(self): """ Test to validate the hosts are shuffled when `shuffle_replicas` is truthy @@ -1432,6 +1516,14 @@ def test_hosts_with_socket_hostname(self): assert policy.distance(host) == HostDistance.LOCAL + def test_ignores_zero_token_status(self): + policy = WhiteListRoundRobinPolicy(["127.0.0.1"]) + host = make_host("127.0.0.1", is_zero_token=True) + policy.populate(None, [host]) + + assert list(policy.make_query_plan()) == [host] + assert policy.distance(host) == HostDistance.LOCAL + class AddressTranslatorTest(unittest.TestCase): @@ -1567,6 +1659,12 @@ def test_accepted_filter_defers_to_child_policy(self): # second call of _child_policy with count() side effect assert self.hfp.distance(self.accepted_host) == distances[1] + def test_zero_token_host_is_ignored_before_child_policy(self): + host = make_host("acceptme", is_zero_token=True) + + assert self.hfp.distance(host) == HostDistance.IGNORED + assert self.hfp._child_policy.distance.call_count == 0 + class HostFilterPolicyPopulateTest(unittest.TestCase): @@ -1618,6 +1716,30 @@ def test_query_plan_deferred_to_child(self): ) assert qp == hfp._child_policy.make_query_plan.return_value + def test_zero_token_hosts_are_filtered(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy', make_query_plan=Mock(return_value=[zero_token_host, host])), + predicate=lambda host: True + ) + + assert list(hfp.make_query_plan()) == [host] + + def test_default_policy_filters_zero_token_target_and_child_hosts(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata.get_host.return_value = zero_token_host + + child_policy = Mock(name='child_policy', make_query_plan=Mock(return_value=[zero_token_host, host])) + policy = DefaultLoadBalancingPolicy(child_policy) + policy.populate(cluster, [host, zero_token_host]) + query = Mock(target_host=zero_token_host.address, keyspace=None) + + assert list(policy.make_query_plan(query=query)) == [host] + def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)]