class SchemaAgreementScope(str, Enum):
    """Scope selectors for :meth:`.Session.wait_for_schema_agreement`.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value (``'rack'``, ``'dc'``, ``'cluster'``).
    """

    RACK = 'rack'
    DC = 'dc'
    CLUSTER = 'cluster'


class Session(object):
    # Only the schema-agreement members of Session are reproduced here.

    @staticmethod
    def _schema_agreement_scope_label(scope):
        """Return the human-readable scope name used in log and error text."""
        labels = {
            SchemaAgreementScope.RACK: 'local rack',
            SchemaAgreementScope.DC: 'local datacenter',
            SchemaAgreementScope.CLUSTER: 'cluster',
        }
        # Anything unrecognised is labelled like the widest scope, matching
        # the fall-through used when selecting hosts.
        return labels.get(scope, 'cluster')

    def wait_for_schema_agreement(self, wait_time: Optional[float] = None,
                                  scope: SchemaAgreementScope = SchemaAgreementScope.CLUSTER) -> bool:
        """
        Wait for connected hosts in the selected scope to report the same
        schema version from ``system.local``.

        By default, the timeout for this operation is governed by
        :attr:`~.Cluster.max_schema_agreement_wait` and
        :attr:`~.Cluster.control_connection_timeout`.

        Passing ``wait_time`` here overrides
        :attr:`~.Cluster.max_schema_agreement_wait`. If provided, ``wait_time``
        must be greater than 0.

        ``scope`` determines which connected hosts participate in the check.
        Pass :attr:`SchemaAgreementScope.RACK`, :attr:`SchemaAgreementScope.DC`,
        or :attr:`SchemaAgreementScope.CLUSTER` (or the equivalent string
        value ``'rack'``, ``'dc'`` or ``'cluster'``).
        The default is :attr:`SchemaAgreementScope.CLUSTER`. ``RACK`` narrows
        the check to connected hosts in the local rack only. ``DC`` checks
        connected hosts in the local datacenter. ``CLUSTER`` queries every
        connected host across all datacenters.

        The hosts are polled repeatedly, sleeping up to 0.2 seconds between
        rounds, until they agree or the deadline passes.

        :param wait_time: Override for
            :attr:`~.Cluster.max_schema_agreement_wait`; must be a positive
            number.
        :param scope: Restricts the check to connected hosts in the local rack,
            local datacenter, or whole connected cluster.
        :returns: ``True`` when the selected connected hosts agree on schema,
            otherwise ``False``.
        :raises ValueError: If ``wait_time`` is provided and is not greater
            than 0.
        :raises ValueError: If ``wait_time`` is not provided and
            :attr:`~.Cluster.max_schema_agreement_wait` is not greater than 0.
        :raises ValueError: If ``scope`` is not one of the schema agreement
            scope values.
        """
        # Normalise before anything else: the helpers compare with ``is``, so
        # a plain 'rack'/'dc' string would otherwise be treated as CLUSTER and
        # an unknown value would never raise the documented ValueError.
        scope = SchemaAgreementScope(scope)

        if wait_time is not None and wait_time <= 0:
            raise ValueError("wait_time must be greater than 0")

        total_timeout = wait_time if wait_time is not None else self.cluster.max_schema_agreement_wait
        if total_timeout <= 0:
            raise ValueError("total_timeout must be greater than 0")

        deadline = time.time() + total_timeout
        schema_mismatches = None
        scope_label = self._schema_agreement_scope_label(scope)

        while time.time() < deadline:
            schema_mismatches = self._get_schema_mismatches_for_scope(deadline, scope)
            if schema_mismatches is None:
                return True

            log.debug("[session] Connected hosts in the %s still disagree on schema, trying again", scope_label)
            remaining = deadline - time.time()
            if remaining > 0:
                # Never sleep past the deadline.
                time.sleep(min(0.2, remaining))

        log.warning("[session] Connected hosts in the %s are reporting a schema disagreement: %s",
                    scope_label, schema_mismatches)
        return False

    def _get_schema_mismatches_for_scope(self, deadline: float,
                                         scope: SchemaAgreementScope) -> Optional[Dict[Any, Any]]:
        """
        Run one round of schema-version queries against the hosts in ``scope``.

        :param deadline: Absolute ``time.time()`` value after which no more
            waiting is done.
        :param scope: Which connected hosts to query.
        :returns: ``None`` when every queried host answered with the same,
            non-null schema version. Otherwise a dict mapping each reported
            schema version to the endpoints that reported it, plus an
            ``'unavailable'`` entry mapping endpoints (or the scope value when
            no host matched) to the error seen for them.
        """
        hosts = self._get_schema_agreement_hosts(scope)
        mismatches = defaultdict(list)
        errors = {}
        scope_label = self._schema_agreement_scope_label(scope)

        if not hosts:
            errors[scope.value] = ConnectionException(
                "No connected hosts available in the %s" % (scope_label,)
            )
            return {'unavailable': errors}

        metadata_request_timeout = self.cluster.control_connection._metadata_request_timeout
        query = maybe_add_timeout_to_query(ControlConnection._SELECT_SCHEMA_LOCAL, metadata_request_timeout)

        # Start all host queries first, then wait for the whole batch.
        pending = []
        for host in hosts:
            try:
                pending.append((host, self._query_local_schema_version(host, query, deadline)))
            except Exception as exc:
                errors[host.endpoint] = exc

        remaining = max(0.0, deadline - time.time())
        if pending and remaining > 0:
            wait_futures([future for _, future in pending], timeout=remaining)

        for host, future in pending:
            if not future.done():
                # Report the wait that was actually applied to the batch, not
                # whatever is left of the deadline afterwards (which is ~0).
                errors[host.endpoint] = OperationTimedOut(last_host=host, timeout=remaining)
                continue

            try:
                rows = future.result()
            except Exception as exc:
                errors[host.endpoint] = exc
                continue

            row = rows.one()
            # A missing row or column is recorded under ``None`` and therefore
            # counts as a disagreement below.
            schema_version = getattr(row, "schema_version", None) if row is not None else None
            mismatches[schema_version].append(host.endpoint)

        if len(mismatches) == 1 and None not in mismatches and not errors:
            log.debug("[session] Connected hosts in the %s agree on schema", scope_label)
            return None

        if errors:
            mismatches['unavailable'] = errors
        return dict(mismatches)

    def _get_schema_agreement_hosts(self, scope: SchemaAgreementScope) -> "Tuple[Host, ...]":
        """
        Return the hosts that take part in a schema agreement check.

        A host qualifies when it has a pool in this session that is not shut
        down, its ``is_up`` flag is truthy, and its distance (as reported by
        the profile manager) is one of the distances allowed for ``scope``.
        """
        if scope is SchemaAgreementScope.RACK:
            allowed_distances = (HostDistance.LOCAL_RACK,)
        elif scope is SchemaAgreementScope.DC:
            allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL)
        else:
            allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE)

        # Iterate over a snapshot of the pool mapping.
        return tuple(
            host for host, pool in tuple(self._pools.items())
            if host.is_up
            and not pool.is_shutdown
            and self._profile_manager.distance(host) in allowed_distances)

    def _query_local_schema_version(self, host: "Host", query: str, deadline: float) -> Future:
        """
        Start an asynchronous schema-version query pinned to ``host``.

        :param host: Host the query is targeted at.
        :param query: Query text to execute.
        :param deadline: Absolute ``time.time()`` value bounding the request
            timeout.
        :returns: A :class:`concurrent.futures.Future` that resolves to a
            :class:`ResultSet` or to the exception reported for the request.
        :raises Exception: Whatever ``execute_async`` or ``add_callbacks``
            raises; it is logged at debug level and re-raised unchanged.
        """
        remaining = max(0.0, deadline - time.time())
        try:
            response_future = self.execute_async(
                query,
                timeout=self._schema_agreement_query_timeout(remaining),
                host=host,
            )
        except OperationTimedOut as timeout:
            log.debug("[session] Timed out waiting for schema version from %s: %s", host, timeout)
            raise
        except Exception as exc:
            log.debug("[session] Error querying schema version from %s: %s", host, exc)
            raise

        # execute_async returns cassandra.cluster.ResponseFuture, which has no
        # bulk-wait support. This adapter exposes the outcome through a
        # concurrent.futures.Future so the caller can wait on a whole batch
        # with concurrent.futures.wait.
        schema_version_future = Future()

        def _set_result(result, result_future=schema_version_future, response_future=response_future):
            if result_future.done():
                return
            try:
                result_future.set_result(ResultSet(response_future, result))
            except Exception as exc:
                result_future.set_exception(exc)

        def _set_exception(exc, result_future=schema_version_future):
            if result_future.done():
                return
            result_future.set_exception(exc)

        try:
            response_future.add_callbacks(_set_result, _set_exception)
        except Exception as exc:
            log.debug("[session] Error registering schema version callback from %s: %s", host, exc)
            raise

        return schema_version_future

    def _schema_agreement_query_timeout(self, remaining: float) -> float:
        """
        Return the per-request timeout for a schema-version query.

        The result is ``remaining`` capped by the control connection timeout
        (when that is not ``None``) and never negative.
        """
        control_timeout = self.cluster.control_connection._timeout
        if control_timeout is None:
            return max(0.0, remaining)
        return max(0.0, min(control_timeout, remaining))


class ControlConnection(object):
    # Only the deprecated public wrapper is reproduced here.

    def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
        """
        Wait for schema agreement from the control connection's metadata view.

        This method is intended for internal metadata refresh flows. External
        callers should use :meth:`.Session.wait_for_schema_agreement` instead.

        The control connection observes schema agreement from its own
        perspective, which may include hosts the session is not using, and it
        may fail when the control connection itself is transiently unhealthy.
        That can produce false positives or failures that do not reflect
        whether a session can safely proceed.

        Emits a :class:`DeprecationWarning` attributed to the caller and then
        delegates to ``_wait_for_schema_agreement`` with the same arguments.

        .. deprecated:: 3.30.0
            Use :meth:`.Session.wait_for_schema_agreement` instead.
        """
        warn("ControlConnection.wait_for_schema_agreement is deprecated and will be removed in 4.0. "
             "Use Session.wait_for_schema_agreement instead. "
             "This method is for internal metadata refresh use only.",
             DeprecationWarning, stacklevel=2)
        return self._wait_for_schema_agreement(connection=connection,
                                               preloaded_results=preloaded_results,
                                               wait_time=wait_time)
automethod:: execution_profile_clone_update diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index f892acba52..3b4dcd33d5 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -158,4 +158,4 @@ def check_and_wait_for_agreement(self, session, rs, exepected): time.sleep(1) assert rs.response_future.is_schema_agreed == exepected if not rs.response_future.is_schema_agreed: - session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) + session.wait_for_schema_agreement(wait_time=1000) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index e608a9610b..18f3dfb298 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -147,7 +147,7 @@ def test_can_register_udt_before_connecting(self): c.register_user_type("udt_test_register_before_connecting2", "user", User2) s = c.connect(wait_for_all_pools=True) - c.control_connection.wait_for_schema_agreement() + s.wait_for_schema_agreement() s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0") diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..b6f2da5372 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,14 +15,16 @@ import logging import socket +from types import SimpleNamespace from unittest.mock import patch, Mock import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion -from cassandra.cluster import 
_Scheduler, Session, Cluster, default_lbp_factory, \ +from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import ConnectionBusy from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -247,11 +249,123 @@ def test_event_delay_timing(self, *_): class SessionTest(unittest.TestCase): + class FakeTime(object): + + def __init__(self): + self.clock = 0 + + def time(self): + return self.clock + + def sleep(self, amount): + self.clock += amount + + class MockPool(object): + + def __init__(self, host, connection): + self.host = host + self.host_distance = HostDistance.LOCAL + self.is_shutdown = False + self.connection = connection + + def _get_connection_for_routing_key(self): + return self.connection + + class MockSchemaVersionFuture(object): + + def __init__(self, outcome, auto_complete=True): + self._outcome = outcome + self._auto_complete = auto_complete + self._delivered = False + self._callback_state = None + self._col_names = ("schema_version",) + self._col_types = None + self.has_more_pages = False + self._continuous_paging_session = None + + def _deliver(self): + if self._delivered or self._callback_state is None: + return + + self._delivered = True + callback, errback, callback_args, callback_kwargs, errback_args, errback_kwargs = self._callback_state + if isinstance(self._outcome, Exception): + errback(self._outcome, *errback_args, **errback_kwargs) + else: + row = SimpleNamespace(schema_version=self._outcome) + callback([row], *callback_args, **callback_kwargs) + + def add_callbacks(self, callback, errback, + callback_args=(), callback_kwargs=None, + errback_args=(), errback_kwargs=None): + self._callback_state = ( + callback, + 
errback, + callback_args, + callback_kwargs or {}, + errback_args, + errback_kwargs or {}, + ) + if self._auto_complete: + self._deliver() + return self + + def complete(self): + self._deliver() + + def result(self): + if isinstance(self._outcome, Exception): + raise self._outcome + return ResultSet(self, [SimpleNamespace(schema_version=self._outcome)]) + def setUp(self): if connection_class is None: raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + def _mock_schema_future(self, outcome): + return self.MockSchemaVersionFuture(outcome) + + def _host_query_count(self, session, target_host): + return sum(1 for call in session.execute_async.call_args_list if call.kwargs.get('host') is target_host) + + def _new_schema_agreement_session(self, schema_versions, distances=None): + hosts = [] + connections = {} + distance_map = {} + if distances is None: + distances = [HostDistance.LOCAL] * len(schema_versions) + + for index, schema_version in enumerate(schema_versions): + host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + hosts.append(host) + distance_map[host] = distances[index] + + cluster = Cluster(protocol_version=4) + for host in hosts: + cluster.metadata.add_or_return_host(host) + + session = Session(cluster, hosts) + session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL)) + session._pools = {} + for host, schema_version in zip(hosts, schema_versions): + connection = Mock(endpoint=host.endpoint) + connection.future_outcomes = [schema_version] + session._pools[host] = self.MockPool(host, connection) + connections[host] = connection + + def execute_async(query, parameters=None, trace=False, + custom_payload=None, execution_profile=None, + paging_state=None, timeout=None, host=None, execute_as=None): + connection = connections[host] + outcome = connection.future_outcomes.pop(0) if 
len(connection.future_outcomes) > 1 else connection.future_outcomes[0] + return self._mock_schema_future(outcome) + + session.execute_async = Mock(side_effect=execute_async) + + return session, hosts, connections + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -339,6 +453,104 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + @mock_session_pools + def test_wait_for_schema_agreement_default_scope_queries_all_connected_hosts(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1) + + for host in hosts: + assert self._host_query_count(session, host) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_): + session, hosts, connections = self._new_schema_agreement_session(["a", "b"]) + clock = self.FakeTime() + connections[hosts[1]].future_outcomes = ["b", "a"] + + with patch('cassandra.cluster.time', new=clock): + assert session.wait_for_schema_agreement(wait_time=1) + for host in hosts: + assert self._host_query_count(session, host) == 2 + assert clock.clock == 0.2 + + @mock_session_pools + def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_): + session, hosts, connections = self._new_schema_agreement_session(["a", "a"]) + clock = self.FakeTime() + connections[hosts[1]].future_outcomes = [ + ConnectionBusy("connection overloaded"), + "a"] + + with patch('cassandra.cluster.time', new=clock): + assert session.wait_for_schema_agreement(wait_time=1) + for host in hosts: + assert self._host_query_count(session, host) == 2 + assert clock.clock == 0.2 + + @mock_session_pools + def 
test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_): + session, hosts, _ = self._new_schema_agreement_session(["a"]) + + unconnected_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + unconnected_host.set_up() + session.cluster.metadata.add_or_return_host(unconnected_host) + + assert session.wait_for_schema_agreement(wait_time=1) + assert self._host_query_count(session, hosts[0]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_queries_hosts_in_order(self, *_): + session, hosts, _ = self._new_schema_agreement_session(["a"] * 11) + + assert session.wait_for_schema_agreement(wait_time=1) + assert [call.kwargs['host'] for call in session.execute_async.call_args_list] == list(hosts) + + @mock_session_pools + def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.RACK) + + assert self._host_query_count(session, hosts[0]) == 1 + assert self._host_query_count(session, hosts[1]) == 0 + assert self._host_query_count(session, hosts[2]) == 0 + + @mock_session_pools + def test_wait_for_schema_agreement_cluster_scope_skips_ignored_hosts(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.IGNORED, HostDistance.LOCAL]) + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.CLUSTER) + + assert self._host_query_count(session, hosts[0]) == 0 + assert self._host_query_count(session, hosts[1]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_cluster_scope_excludes_hosts_with_unknown_status(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL]) + + 
hosts[0].is_up = None + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.CLUSTER) + + assert self._host_query_count(session, hosts[0]) == 0 + assert self._host_query_count(session, hosts[1]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): + session, _, _ = self._new_schema_agreement_session(["a"]) + + with pytest.raises(ValueError): + session.wait_for_schema_agreement(wait_time=1, scope='planet') + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..fd62323f33 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,7 +15,7 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call, patch from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS @@ -210,16 +210,27 @@ def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing """ - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() # the control connection should not have slept at all assert self.time.clock == 0 + @patch('cassandra.cluster.warn') + def test_wait_for_schema_agreement_warns_about_deprecation(self, mocked_warn): + assert self.control_connection.wait_for_schema_agreement() + + mocked_warn.assert_called_once() + warning_args, warning_kwargs = mocked_warn.call_args + assert 'ControlConnection.wait_for_schema_agreement is deprecated' in str(warning_args[0]) + assert 'Use Session.wait_for_schema_agreement instead.' 
in str(warning_args[0]) + assert warning_args[1] is DeprecationWarning + assert warning_kwargs['stacklevel'] == 2 + def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): """ wait_for_schema_agreement uses preloaded results if given for shared table queries """ preloaded_results = self._matching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 # the connection should not have made any queries if given preloaded results @@ -230,7 +241,7 @@ def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_ wait_for_schema_agreement requery if schema does not match using preloaded results """ preloaded_results = self._nonmatching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 assert self.connection.wait_for_responses.call_count == 1 @@ -241,7 +252,7 @@ def test_wait_for_schema_agreement_fails(self): """ # change the schema version on one node self.connection.peer_results[1][1][2] = 'b' - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() # the control connection should have slept until it hit the limit assert self.time.clock >= self.cluster.max_schema_agreement_wait @@ -262,7 +273,7 @@ def test_wait_for_schema_agreement_skipping(self): self.connection.peer_results[1][1][3] = 'c' self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False - assert self.control_connection.wait_for_schema_agreement() + assert 
self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 def test_wait_for_schema_agreement_rpc_lookup(self): @@ -279,12 +290,12 @@ def test_wait_for_schema_agreement_rpc_lookup(self): # even though the new host has a different schema version, it's # marked as down, so the control connection shouldn't care - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 # but once we mark it up, the control connection will care host.is_up = True - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() assert self.time.clock >= self.cluster.max_schema_agreement_wait @@ -299,7 +310,7 @@ def test_wait_for_schema_agreement_none_timeout(self): status_event_refresh_window=0) cc._connection = self.connection cc._time = self.time - assert cc.wait_for_schema_agreement() + assert cc._wait_for_schema_agreement() def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() @@ -441,7 +452,8 @@ def bad_wait_for_responses(*args, **kwargs): self.control_connection.refresh_node_list_and_token_map() self.cluster.executor.submit.assert_called_with(self.control_connection._reconnect) - def test_refresh_schema_timeout(self): + @patch('cassandra.cluster.warn') + def test_refresh_schema_timeout(self, mocked_warn): def bad_wait_for_responses(*args, **kwargs): self.time.sleep(kwargs['timeout']) @@ -451,6 +463,7 @@ def bad_wait_for_responses(*args, **kwargs): self.control_connection.refresh_schema() assert self.connection.wait_for_responses.call_count == self.cluster.max_schema_agreement_wait / self.control_connection._timeout assert self.connection.wait_for_responses.call_args[1]['timeout'] == self.control_connection._timeout + mocked_warn.assert_not_called() def test_handle_topology_change(self): event = { diff --git a/tests/unit/test_session_schema_agreement.py 
"""Unit tests for :meth:`cassandra.cluster.Session.wait_for_schema_agreement`."""

from datetime import timedelta
from types import SimpleNamespace
from unittest.mock import Mock
import uuid

import pytest

import cassandra.cluster as cluster_module
from cassandra.connection import ConnectionBusy
from cassandra.cluster import ControlConnection, Session, ResultSet
from cassandra.policies import HostDistance, SimpleConvictionPolicy
from cassandra.pool import Host
from cassandra.util import maybe_add_timeout_to_query


class FakeTime:
    """Stand-in for the ``time`` module whose clock only moves on ``sleep``."""

    def __init__(self):
        self.clock = 0

    def time(self):
        return self.clock

    def sleep(self, amount):
        self.clock += amount


class MockPool:
    """Minimal pool object exposing the attributes the session code reads."""

    def __init__(self, host):
        self.host = host
        self.is_shutdown = False


class MockSchemaVersionFuture:
    """
    Fake response future that delivers a preset outcome to its callbacks.

    ``outcome`` is either an exception (passed to the errback) or a schema
    version (wrapped in a single-row result and passed to the callback). With
    ``auto_complete`` the outcome is delivered as soon as callbacks are
    registered; otherwise it is delivered on :meth:`complete`.
    """

    def __init__(self, outcome, auto_complete=True):
        self._outcome = outcome
        self._auto_complete = auto_complete
        self._delivered = False
        self._callback_state = None
        self._col_names = ("schema_version",)
        self._col_types = None
        self.has_more_pages = False
        self._continuous_paging_session = None

    def _deliver(self):
        # Deliver at most once, and only after callbacks were registered.
        if self._delivered or self._callback_state is None:
            return

        self._delivered = True
        callback, errback, callback_args, callback_kwargs, errback_args, errback_kwargs = self._callback_state
        if isinstance(self._outcome, Exception):
            errback(self._outcome, *errback_args, **errback_kwargs)
        else:
            row = SimpleNamespace(schema_version=self._outcome)
            callback([row], *callback_args, **callback_kwargs)

    def add_callbacks(self, callback, errback,
                      callback_args=(), callback_kwargs=None,
                      errback_args=(), errback_kwargs=None):
        self._callback_state = (
            callback,
            errback,
            callback_args,
            callback_kwargs or {},
            errback_args,
            errback_kwargs or {},
        )
        if self._auto_complete:
            self._deliver()
        return self

    def complete(self):
        self._deliver()

    def result(self):
        if isinstance(self._outcome, Exception):
            raise self._outcome
        return ResultSet(self, [SimpleNamespace(schema_version=self._outcome)])


def _host_query_count(session, target_host):
    """Count the ``execute_async`` calls that were pinned to ``target_host``."""
    return sum(1 for call in session.execute_async.call_args_list if call.kwargs.get("host") is target_host)


def _new_session(schema_versions, distances=None, metadata_request_timeout=timedelta(seconds=2), timeout=2.0):
    """
    Build a bare ``Session`` with one fake pool per entry in ``schema_versions``.

    Returns ``(session, hosts, connections)``. ``connections[host].future_outcomes``
    is the queue of outcomes served to successive queries against that host;
    the last entry is repeated once the queue is down to one item.
    """
    hosts = []
    connections = {}
    distance_map = {}

    if distances is None:
        distances = [HostDistance.LOCAL] * len(schema_versions)

    for index, schema_version in enumerate(schema_versions):
        host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4())
        host.set_up()
        hosts.append(host)
        distance_map[host] = distances[index]

    cluster = SimpleNamespace(
        max_schema_agreement_wait=10,
        control_connection=SimpleNamespace(
            _timeout=timeout,
            _metadata_request_timeout=metadata_request_timeout,
        ),
    )

    # Bypass Session.__init__ so no real connections are attempted.
    session = Session.__new__(Session)
    session.cluster = cluster
    session._profile_manager = SimpleNamespace(distance=lambda host: distance_map.get(host, HostDistance.LOCAL))
    session._pools = {}
    session.is_shutdown = False

    for host, schema_version in zip(hosts, schema_versions):
        connection = Mock(endpoint=host.endpoint)
        connection.future_outcomes = [schema_version]
        session._pools[host] = MockPool(host)
        connections[host] = connection

    def execute_async(query, parameters=None, trace=False,
                      custom_payload=None, execution_profile=None,
                      paging_state=None, timeout=None, host=None, execute_as=None):
        connection = connections[host]
        outcome = connection.future_outcomes.pop(0) if len(connection.future_outcomes) > 1 else connection.future_outcomes[0]
        return MockSchemaVersionFuture(outcome)

    session.execute_async = Mock(side_effect=execute_async)

    return session, hosts, connections


def test_wait_for_schema_agreement_retries_with_module_time(monkeypatch):
    session, hosts, connections = _new_session(["a", "b"])
    clock = FakeTime()
    monkeypatch.setattr(cluster_module, "time", clock)
    connections[hosts[1]].future_outcomes = ["b", "a"]

    assert session.wait_for_schema_agreement(wait_time=1)
    assert clock.clock == pytest.approx(0.2)
    for host in hosts:
        assert _host_query_count(session, host) == 2


@pytest.mark.parametrize("wait_time", [0, -1])
def test_wait_for_schema_agreement_rejects_non_positive_wait_time(wait_time):
    session, _, _ = _new_session(["a"])

    with pytest.raises(ValueError, match="wait_time must be greater than 0"):
        session.wait_for_schema_agreement(wait_time=wait_time)

    assert session.execute_async.call_count == 0


def test_wait_for_schema_agreement_returns_false_when_no_hosts_match_scope(monkeypatch):
    session, _, _ = _new_session(["a"], distances=[HostDistance.IGNORED])
    clock = FakeTime()
    monkeypatch.setattr(cluster_module, "time", clock)

    assert session.wait_for_schema_agreement(wait_time=1) is False
    assert session.execute_async.call_count == 0
    assert clock.clock == pytest.approx(1.0)


def test_wait_for_schema_agreement_uses_host_targeted_session_queries(monkeypatch):
    session, hosts, _ = _new_session(["a", "a"])
    # Freeze the module clock: against the real wall clock the 0.1 s budget
    # can run out on a slow runner before the second query is issued, which
    # makes the timeout assertions below flaky.
    monkeypatch.setattr(cluster_module, "time", FakeTime())

    assert session.wait_for_schema_agreement(wait_time=0.1)

    expected_query = maybe_add_timeout_to_query(
        ControlConnection._SELECT_SCHEMA_LOCAL,
        timedelta(seconds=2),
    )
    assert session.execute_async.call_count == 2
    assert [call.args[0] for call in session.execute_async.call_args_list] == [expected_query, expected_query]
    assert [call.kwargs["host"] for call in session.execute_async.call_args_list] == hosts
    for call in session.execute_async.call_args_list:
        # With the clock frozen the whole 0.1 s budget is still available,
        # and it is smaller than the 2.0 s control connection timeout.
        assert call.kwargs["timeout"] == pytest.approx(0.1)


def test_wait_for_schema_agreement_retries_after_host_targeted_query_error(monkeypatch):
    session, hosts, connections = _new_session(["a", "a"])
    clock = FakeTime()
    monkeypatch.setattr(cluster_module, "time", clock)
    connections[hosts[1]].future_outcomes = [ConnectionBusy("connection overloaded"), "a"]

    assert session.wait_for_schema_agreement(wait_time=1)
    assert clock.clock == pytest.approx(0.2)
    for host in hosts:
        assert _host_query_count(session, host) == 2


def test_wait_for_schema_agreement_queries_hosts_in_order_under_one_deadline(monkeypatch):
    session, hosts, _ = _new_session(["a", "a", "a"])
    clock = FakeTime()
    monkeypatch.setattr(cluster_module, "time", clock)

    def execute_async(query, parameters=None, trace=False,
                      custom_payload=None, execution_profile=None,
                      paging_state=None, timeout=None, host=None, execute_as=None):
        # Each query start consumes a little of the shared deadline.
        clock.sleep(0.01)
        return MockSchemaVersionFuture("a")

    session.execute_async = Mock(side_effect=execute_async)

    assert session.wait_for_schema_agreement(wait_time=1)
    assert [call.kwargs["host"] for call in session.execute_async.call_args_list] == hosts