def wait_for_schema_agreement(self, wait_time=None, scope='dc'):
    """
    Wait for connected hosts in the selected scope to report the same
    schema version from ``system.local``.

    By default, the timeout for this operation is governed by
    :attr:`~.Cluster.max_schema_agreement_wait` and
    :attr:`~.Cluster.control_connection_timeout`.

    Passing ``wait_time`` here overrides
    :attr:`~.Cluster.max_schema_agreement_wait`. Setting ``wait_time <= 0``
    will bypass schema agreement waits and return ``True`` immediately.

    ``scope`` determines which connected hosts participate in the check.
    Accepted values are ``'rack'``, ``'dc'``, and ``'cluster'``. The
    default ``'dc'`` scope queries connected hosts in the local rack and
    local datacenter. ``'rack'`` narrows the check to connected hosts in
    the local rack only. ``'cluster'`` queries every host this session has
    a live connection pool for, across all datacenters.

    ``scope`` is validated before anything else, so an invalid value is
    reported even when waiting is disabled.

    :param wait_time: Override for
        :attr:`~.Cluster.max_schema_agreement_wait`.
    :param scope: Restricts the check to connected hosts in the local rack,
        local datacenter, or whole connected cluster.
    :returns: ``True`` when the selected connected hosts agree on schema,
        otherwise ``False``.
    :raises ValueError: If ``scope`` is not one of ``'rack'``, ``'dc'``,
        or ``'cluster'``.
    """
    # Validate first: previously the ``<= 0`` shortcut below returned True
    # before the scope was ever looked at, hiding caller mistakes.
    scope = self._normalize_schema_agreement_scope(scope)

    total_timeout = wait_time if wait_time is not None else self.cluster.max_schema_agreement_wait
    if total_timeout <= 0:
        return True

    deadline = self._time.time() + total_timeout
    schema_mismatches = None

    while self._time.time() < deadline:
        schema_mismatches = self._get_schema_mismatches_for_scope(deadline, scope)
        if schema_mismatches is None:
            return True

        log.debug("[session] Schemas mismatched in scope %r, trying again", scope)
        # Poll at most every 0.2s, but never sleep past the deadline.
        remaining = deadline - self._time.time()
        if remaining > 0:
            self._time.sleep(min(0.2, remaining))

    log.warning("Connected nodes in scope %r are reporting a schema disagreement: %s",
                scope, schema_mismatches)
    return False
remaining = deadline - self._time.time() + if remaining <= 0: + for future, host in futures.items(): + future.cancel() + errors[host.endpoint] = "Timed out before querying host" + for host in hosts[offset + parallelism:]: + errors[host.endpoint] = "Timed out before querying host" + break + + done, not_done = wait_futures(tuple(futures), timeout=remaining) + for future in not_done: + future.cancel() + host = futures[future] + errors[host.endpoint] = "Timed out before querying host" + + for future in done: + host = futures[future] + schema_version, error = future.result() + if error is not None: + errors[host.endpoint] = error + continue + versions[schema_version].add(host.endpoint) + + if not_done: + for host in hosts[offset + parallelism:]: + errors[host.endpoint] = "Timed out before querying host" + break + + if len(versions) == 1 and None not in versions and not errors: + log.debug("[session] Local schemas match") + return None + + mismatches = dict((version, list(nodes)) for version, nodes in versions.items()) + if errors: + mismatches['unavailable'] = dict((endpoint, str(error)) for endpoint, error in errors.items()) + return mismatches + + def _get_schema_agreement_hosts(self, scope): + allowed_distances = { + 'rack': (HostDistance.LOCAL_RACK,), + 'dc': (HostDistance.LOCAL_RACK, HostDistance.LOCAL), + } + return tuple( + host for host, pool in tuple(self._pools.items()) + if host.is_up is not False + and not pool.is_shutdown + and (scope == 'cluster' or self._profile_manager.distance(host) in allowed_distances[scope])) + + def _normalize_schema_agreement_scope(self, scope): + normalized_scope = str(scope).strip().lower().replace('_', '').replace(' ', '') + normalized_scope = { + 'wholecluster': 'cluster', + 'datacenter': 'dc', + }.get(normalized_scope, normalized_scope) + + if normalized_scope not in ('rack', 'dc', 'cluster'): + raise ValueError("Invalid schema agreement scope: %s" % (scope,)) + + return normalized_scope + + def 
_query_local_schema_version(self, host, query, deadline): + remaining = deadline - self._time.time() + if remaining <= 0: + return None, "Timed out before querying host" + + pool = self._pools.get(host) + if not pool or pool.is_shutdown: + return None, "No active connection pool" + + try: + connection = pool._get_connection_for_routing_key() + query_timeout = self._schema_agreement_query_timeout(remaining) + local_result = connection.wait_for_response(query, timeout=query_timeout) + except OperationTimedOut as timeout: + log.debug("[session] Timed out waiting for schema version from %s: %s", host, timeout) + return None, timeout + except (ConnectionException, NoConnectionsAvailable, ConnectionBusy) as exc: + log.debug("[session] Error querying schema version from %s: %s", host, exc) + return None, exc + + rows = dict_factory(local_result.column_names, local_result.parsed_rows) + return (rows[0].get("schema_version") if rows else None), None + + def _schema_agreement_query_timeout(self, remaining): + control_timeout = self.cluster.control_connection._timeout + return min(control_timeout, remaining) if control_timeout is not None else remaining + def user_type_registered(self, keyspace, user_type, klass): """ Called by the parent Cluster instance when the user registers a new @@ -3786,9 +3960,9 @@ def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_w if self._cluster.is_shutdown: return False - agreed = self.wait_for_schema_agreement(connection, - preloaded_results=preloaded_results, - wait_time=schema_agreement_wait) + agreed = self._wait_for_schema_agreement(connection, + preloaded_results=preloaded_results, + wait_time=schema_agreement_wait) if not self._schema_meta_enabled and not force: log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") @@ -4079,6 +4253,13 @@ def _handle_schema_change(self, event): self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) def 
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
    """
    Deprecated public alias for ``_wait_for_schema_agreement``.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged;
    new code should call ``Session.wait_for_schema_agreement`` instead.

    :param connection: Forwarded to ``_wait_for_schema_agreement``.
    :param preloaded_results: Forwarded to ``_wait_for_schema_agreement``.
    :param wait_time: Forwarded to ``_wait_for_schema_agreement``.
    :returns: Whatever ``_wait_for_schema_agreement`` returns.
    """
    # stacklevel=2 attributes the warning to the code calling this method
    # rather than to this line, so it names the call site that needs
    # migrating and is subject to the caller's warning filters.
    # NOTE(review): assumes ``warn`` is ``warnings.warn`` -- confirm against
    # the module imports.
    warn("ControlConnection.wait_for_schema_agreement is deprecated and will be removed in 4.0. "
         "Use Session.wait_for_schema_agreement instead.", DeprecationWarning, stacklevel=2)
    return self._wait_for_schema_agreement(connection=connection,
                                           preloaded_results=preloaded_results,
                                           wait_time=wait_time)
remaining = total_timeout - elapsed + if remaining > 0: + self._time.sleep(min(0.2, remaining)) + elapsed = self._time.time() - start + continue + except ConnectionException as exc: + if isinstance(exc, ConnectionShutdown) and self._is_shutdown: log.debug("[control connection] Aborting wait for schema match due to shutdown") return None - else: + + elapsed = self._time.time() - start + if schema_mismatches is not None: + log.debug("[control connection] Error during schema agreement check after mismatch: %s", + exc) + setattr(exc, _SCHEMA_AGREEMENT_MISMATCHES_ATTR, schema_mismatches) raise + fallback_wait = total_timeout - elapsed + fallback = self._wait_for_schema_agreement_through_session(fallback_wait) + if fallback is not None: + return fallback + raise + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) if schema_mismatches is None: return True @@ -4145,6 +4359,31 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai connection.endpoint, schema_mismatches) return False + def _wait_for_schema_agreement_through_session(self, wait_time): + if wait_time <= 0: + return None + + try: + sessions = tuple(self._cluster.sessions) + except (AttributeError, ReferenceError): + return None + + deadline = self._time.time() + wait_time + for session in sessions: + if not getattr(session, 'is_shutdown', False): + remaining = deadline - self._time.time() + if remaining <= 0: + return None + log.debug("[control connection] Falling back to session schema agreement check") + try: + fallback = session.wait_for_schema_agreement(wait_time=remaining) + except Exception: + log.debug("[control connection] Session schema agreement fallback failed", exc_info=True) + continue + if fallback is not None: + return fallback + return None + def _get_schema_mismatches(self, peers_result, local_result, local_address): peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) @@ -4360,12 +4599,42 @@ 
def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs):
    """
    Refresh schema metadata after a schema change and complete the future.

    The refresh is first attempted on ``connection`` through the control
    connection. If that raises:

    * when the exception carries the schema-mismatch marker attribute (the
      control connection had already observed a disagreement), the session
      fallback is skipped and an asynchronous refresh is scheduled;
    * otherwise agreement is checked through the session, followed by a
      refresh with the agreement wait disabled. If either step fails, an
      asynchronous refresh is scheduled instead.

    ``response_future.is_schema_agreed`` records the outcome, and
    ``response_future._set_final_result(None)`` is always called last.

    :param control_conn: Control connection used to refresh the schema.
    :param response_future: Future of the schema-changing request.
    :param connection: Connection to run the first refresh attempt on.
    :param kwargs: Schema change event fields, forwarded to the refresh calls.
    """
    def schedule_async_refresh():
        # Hand the refresh over to the session executor to be retried later.
        response_future.session.submit(control_conn.refresh_schema, **kwargs)

    try:
        log.debug("Refreshing schema in response to schema change. %s", kwargs)
        try:
            response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs)
        except Exception as exc:
            log.exception("Exception refreshing schema in response to schema change:")
            response_future.is_schema_agreed = False
            schema_mismatches = getattr(exc, _SCHEMA_AGREEMENT_MISMATCHES_ATTR, _NOT_SET)
        else:
            return

        if schema_mismatches is not _NOT_SET:
            log.debug("Skipping session schema agreement fallback after control connection "
                      "reported a schema disagreement: %s",
                      schema_mismatches)
            schedule_async_refresh()
            return

        log.debug("Falling back to session schema agreement check")
        try:
            response_future.is_schema_agreed = response_future.session.wait_for_schema_agreement()
        except Exception:
            log.exception("Exception waiting for schema agreement through session:")
            response_future.is_schema_agreed = False

        if not response_future.is_schema_agreed:
            schedule_async_refresh()
            return

        # Agreement was just established through the session, so skip the
        # control connection's own agreement wait for this refresh.
        try:
            refreshed = control_conn.refresh_schema(schema_agreement_wait=0, **kwargs)
        except Exception:
            log.exception("Exception refreshing schema after session agreement fallback:")
            refreshed = False

        if not refreshed:
            response_future.is_schema_agreed = False
            schedule_async_refresh()
    finally:
        response_future._set_final_result(None)
automethod:: execution_profile_clone_update diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index f892acba52..3b4dcd33d5 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -158,4 +158,4 @@ def check_and_wait_for_agreement(self, session, rs, exepected): time.sleep(1) assert rs.response_future.is_schema_agreed == exepected if not rs.response_future.is_schema_agreed: - session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) + session.wait_for_schema_agreement(wait_time=1000) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index e608a9610b..18f3dfb298 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -147,7 +147,7 @@ def test_can_register_udt_before_connecting(self): c.register_user_type("udt_test_register_before_connecting2", "user", User2) s = c.connect(wait_for_all_pools=True) - c.control_connection.wait_for_schema_agreement() + s.wait_for_schema_agreement() s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0") diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..f0f7aeb288 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import unittest +from concurrent.futures import Future import logging import socket @@ -23,6 +24,7 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import ConnectionBusy from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -247,11 +249,64 @@ def test_event_delay_timing(self, *_): class SessionTest(unittest.TestCase): + class FakeTime(object): + + def __init__(self): + self.clock = 0 + + def time(self): + return self.clock + + def sleep(self, amount): + self.clock += amount + + class MockPool(object): + + def __init__(self, host, connection): + self.host = host + self.host_distance = HostDistance.LOCAL + self.is_shutdown = False + self.connection = connection + + def _get_connection_for_routing_key(self): + return self.connection + def setUp(self): if connection_class is None: raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + def _mock_schema_response(self, schema_version): + response = Mock() + response.column_names = ["schema_version"] + response.parsed_rows = [[schema_version]] + return response + + def _new_schema_agreement_session(self, schema_versions, distances=None): + hosts = [] + distance_map = {} + if distances is None: + distances = [HostDistance.LOCAL] * len(schema_versions) + for index, schema_version in enumerate(schema_versions): + host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + hosts.append(host) + distance_map[host] = distances[index] + + cluster = 
Cluster(protocol_version=4) + for host in hosts: + cluster.metadata.add_or_return_host(host) + + session = Session(cluster, hosts) + session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL)) + session._pools = {} + for host, schema_version in zip(hosts, schema_versions): + connection = Mock(endpoint=host.endpoint) + connection.wait_for_response.return_value = self._mock_schema_response(schema_version) + session._pools[host] = self.MockPool(host, connection) + + return session, hosts + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -339,6 +394,104 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + @mock_session_pools + def test_wait_for_schema_agreement_queries_all_local_hosts(self, *_): + session, hosts = self._new_schema_agreement_session(["a", "a"]) + + assert session.wait_for_schema_agreement(wait_time=1) + + for host in hosts: + connection = session._pools[host].connection + connection.wait_for_response.assert_called_once() + + @mock_session_pools + def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_): + session, hosts = self._new_schema_agreement_session(["a", "b"]) + session._time = self.FakeTime() + second_connection = session._pools[hosts[1]].connection + second_connection.wait_for_response.side_effect = [ + self._mock_schema_response("b"), + self._mock_schema_response("a")] + + assert session.wait_for_schema_agreement(wait_time=1) + assert second_connection.wait_for_response.call_count == 2 + assert session._time.clock == 0.2 + + @mock_session_pools + def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_): + session, hosts = self._new_schema_agreement_session(["a", "a"]) + session._time = self.FakeTime() + busy_connection = 
session._pools[hosts[1]].connection + busy_connection.wait_for_response.side_effect = [ + ConnectionBusy("connection overloaded"), + self._mock_schema_response("a")] + + assert session.wait_for_schema_agreement(wait_time=1) + assert busy_connection.wait_for_response.call_count == 2 + assert session._time.clock == 0.2 + + @mock_session_pools + def test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_): + session, hosts = self._new_schema_agreement_session(["a"]) + session._time = self.FakeTime() + + unconnected_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + unconnected_host.set_up() + session.cluster.metadata.add_or_return_host(unconnected_host) + + assert session.wait_for_schema_agreement(wait_time=1) + session._pools[hosts[0]].connection.wait_for_response.assert_called() + + @mock_session_pools + @patch('cassandra.cluster.wait_futures') + def test_wait_for_schema_agreement_limits_parallel_queries_to_default(self, mocked_wait_futures, *_): + session, _ = self._new_schema_agreement_session(["a"] * 11) + batch_sizes = [] + + def submit(fn, host, query, deadline): + future = Future() + future.set_result(fn(host, query, deadline)) + return future + + def fake_wait(futures, timeout=None, return_when=None): + batch_sizes.append(len(futures)) + return set(futures), set() + + session.submit = Mock(side_effect=submit) + mocked_wait_futures.side_effect = fake_wait + + assert session.wait_for_schema_agreement(wait_time=1) + assert batch_sizes == [10, 1] + + @mock_session_pools + def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_): + session, hosts = self._new_schema_agreement_session( + ["a", "a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1, scope='rack') + + session._pools[hosts[0]].connection.wait_for_response.assert_called_once() + 
session._pools[hosts[1]].connection.wait_for_response.assert_not_called() + session._pools[hosts[2]].connection.wait_for_response.assert_not_called() + + @mock_session_pools + def test_wait_for_schema_agreement_cluster_scope_queries_all_connected_hosts(self, *_): + session, hosts = self._new_schema_agreement_session( + ["a", "a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1, scope='cluster') + + for host in hosts: + session._pools[host].connection.wait_for_response.assert_called_once() + + @mock_session_pools + def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): + session, _ = self._new_schema_agreement_session(["a"]) + + with pytest.raises(ValueError): + session.wait_for_schema_agreement(wait_time=1, scope='planet') class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2fa7c71196..22ba514fdc 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,8 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionBusy, ConnectionException, ConnectionShutdown, DefaultEndPoint, + ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler, ResultMessage, @@ -363,6 +364,19 @@ def test_wait_for_responses_shutdown_includes_last_error(self): assert "already closed" in error_message assert "Bad file descriptor" in error_message + def test_wait_for_responses_releases_request_id_when_send_fails(self): + c = self.make_connection() + 
c._socket_writable = False + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + + with pytest.raises(ConnectionBusy): + c.wait_for_responses(Mock()) + + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests + @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') class ConnectionHeartbeatTest(unittest.TestCase): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..a95131437f 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,13 +15,16 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call, patch from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import (ControlConnection, _Scheduler, ProfileManager, + EXEC_PROFILE_DEFAULT, ExecutionProfile, + refresh_schema_and_set_result) from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import (EndPoint, DefaultEndPoint, DefaultEndPointFactory, + ConnectionException, ConnectionShutdown, ConnectionBusy) from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -210,16 +213,25 @@ def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing """ - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() # the control connection should not have slept at all assert self.time.clock == 0 + @patch('cassandra.cluster.warn') + def 
test_wait_for_schema_agreement_warns_about_deprecation(self, mocked_warn): + assert self.control_connection.wait_for_schema_agreement() + + assert len(mocked_warn.mock_calls) == 1 + warning_args = tuple(mocked_warn.mock_calls[0])[1] + assert 'ControlConnection.wait_for_schema_agreement is deprecated' in str(warning_args[0]) + assert warning_args[1] is DeprecationWarning + def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): """ wait_for_schema_agreement uses preloaded results if given for shared table queries """ preloaded_results = self._matching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 # the connection should not have made any queries if given preloaded results @@ -230,7 +242,7 @@ def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_ wait_for_schema_agreement requery if schema does not match using preloaded results """ preloaded_results = self._nonmatching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 assert self.connection.wait_for_responses.call_count == 1 @@ -241,10 +253,146 @@ def test_wait_for_schema_agreement_fails(self): """ # change the schema version on one node self.connection.peer_results[1][1][2] = 'b' - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() # the control connection should have slept until it hit the limit assert self.time.clock >= self.cluster.max_schema_agreement_wait + def 
def test_wait_for_schema_agreement_falls_back_to_session_when_connection_closes(self):
    """
    A ConnectionShutdown before any mismatch was seen is answered by the
    session fallback, which receives the full configured wait.
    """
    session = Mock(is_shutdown=False)
    session.wait_for_schema_agreement.return_value = True
    self.cluster.sessions = [session]
    self.connection.wait_for_responses.side_effect = ConnectionShutdown("closed")

    # Exercise the implementation directly, like the other tests in this
    # suite; the deprecated public wrapper has its own dedicated test and
    # would only add a DeprecationWarning here.
    assert self.control_connection._wait_for_schema_agreement()
    session.wait_for_schema_agreement.assert_called_once_with(wait_time=self.cluster.max_schema_agreement_wait)
def test_wait_for_schema_agreement_subtracts_elapsed_time_before_session_fallback(self):
    """
    Time spent on the control connection is deducted from the budget handed
    to the session fallback, so the total wait equals the configured maximum.
    """
    session = Mock(is_shutdown=False)

    def wait_for_responses(*args, **kwargs):
        # Spend 3 seconds of the test clock before the connection fails.
        self.time.sleep(3)
        raise ConnectionShutdown("closed")

    def wait_for_schema_agreement(wait_time=None):
        # Use up exactly the budget the fallback was given.
        self.time.sleep(wait_time)
        return False

    self.cluster.sessions = [session]
    self.connection.wait_for_responses.side_effect = wait_for_responses
    session.wait_for_schema_agreement.side_effect = wait_for_schema_agreement

    # Call the implementation directly, like the other tests in this suite,
    # instead of going through the deprecated public wrapper.
    assert not self.control_connection._wait_for_schema_agreement()
    assert self.time.clock == self.cluster.max_schema_agreement_wait
matching_peer_rows))] + + assert self.control_connection.wait_for_schema_agreement() + session.wait_for_schema_agreement.assert_not_called() + assert self.connection.wait_for_responses.call_count == 3 + + def test_wait_for_schema_agreement_raises_connection_error_after_mismatch(self): + peer_columns = self.connection.peer_results[0] + mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]] + mismatching_peer_rows[1][2] = 'b' + self.connection.wait_for_responses.side_effect = [ + _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows)), + ConnectionShutdown("closed")] + + with self.assertRaises(ConnectionShutdown): + self.control_connection.wait_for_schema_agreement() + + def test_schema_change_refresh_does_not_session_fallback_after_mismatch_then_connection_error(self): + session = Mock(is_shutdown=False) + session.wait_for_schema_agreement.return_value = True + self.cluster.sessions = [session] + self.cluster.metadata.refresh = Mock() + + peer_columns = self.connection.peer_results[0] + mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]] + mismatching_peer_rows[1][2] = 'b' + self.connection.wait_for_responses.side_effect = [ + _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows)), + ConnectionShutdown("closed")] + + response_future = Mock() + response_future.session = session + event = {'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED, + 'keyspace': "keyspace1", "table": "table1"} + + refresh_schema_and_set_result(self.control_connection, response_future, self.connection, **event) + + session.wait_for_schema_agreement.assert_not_called() + self.cluster.metadata.refresh.assert_not_called() + assert not response_future.is_schema_agreed + response_future._set_final_result.assert_called_once_with(None) + + def test_wait_for_schema_agreement_does_not_exceed_configured_wait_with_session_fallback(self): + session = 
Mock(is_shutdown=False) + + def wait_for_schema_agreement(wait_time=None): + self.time.sleep(wait_time) + return False + + session.wait_for_schema_agreement.side_effect = wait_for_schema_agreement + self.cluster.sessions = [session] + self.connection.peer_results[1][1][2] = 'b' + + assert not self.control_connection.wait_for_schema_agreement() + assert self.time.clock < self.cluster.max_schema_agreement_wait + 0.2 + def test_wait_for_schema_agreement_skipping(self): """ If rpc_address or schema_version isn't set, the host should be skipped @@ -262,7 +410,7 @@ def test_wait_for_schema_agreement_skipping(self): self.connection.peer_results[1][1][3] = 'c' self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 def test_wait_for_schema_agreement_rpc_lookup(self): @@ -279,12 +427,12 @@ def test_wait_for_schema_agreement_rpc_lookup(self): # even though the new host has a different schema version, it's # marked as down, so the control connection shouldn't care - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 # but once we mark it up, the control connection will care host.is_up = True - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() assert self.time.clock >= self.cluster.max_schema_agreement_wait @@ -299,7 +447,7 @@ def test_wait_for_schema_agreement_none_timeout(self): status_event_refresh_window=0) cc._connection = self.connection cc._time = self.time - assert cc.wait_for_schema_agreement() + assert cc._wait_for_schema_agreement() def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 
dd7fa75045..5df70f5630 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -19,7 +19,7 @@ from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion, refresh_schema_and_set_result from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, @@ -123,6 +123,63 @@ def test_schema_change_result(self): rf._set_result(None, connection, None, result) session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results) + def test_schema_change_refresh_falls_back_to_session_agreement(self): + session = self.make_session() + session.wait_for_schema_agreement.return_value = True + control_conn = Mock() + control_conn._refresh_schema.side_effect = Exception("closed") + control_conn.refresh_schema.return_value = True + rf = self.make_response_future(session) + connection = Mock() + event_results = {'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED, + 'keyspace': "keyspace1", "table": "table1"} + + refresh_schema_and_set_result(control_conn, rf, connection, **event_results) + + control_conn._refresh_schema.assert_called_once_with(connection, **event_results) + session.wait_for_schema_agreement.assert_called_once_with() + control_conn.refresh_schema.assert_called_once_with(schema_agreement_wait=0, **event_results) + assert rf.is_schema_agreed + assert not rf.result() + + def test_schema_change_refresh_marks_disagreed_when_fallback_refresh_fails(self): + session = self.make_session() + session.wait_for_schema_agreement.return_value = True + control_conn = Mock() + 
control_conn._refresh_schema.side_effect = Exception("closed") + control_conn.refresh_schema.return_value = False + rf = self.make_response_future(session) + connection = Mock() + event_results = {'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED, + 'keyspace': "keyspace1", "table": "table1"} + + refresh_schema_and_set_result(control_conn, rf, connection, **event_results) + + control_conn._refresh_schema.assert_called_once_with(connection, **event_results) + session.wait_for_schema_agreement.assert_called_once_with() + control_conn.refresh_schema.assert_called_once_with(schema_agreement_wait=0, **event_results) + session.submit.assert_called_once_with(control_conn.refresh_schema, **event_results) + assert not rf.is_schema_agreed + assert not rf.result() + + def test_schema_change_refresh_does_not_fall_back_after_schema_disagreement(self): + session = self.make_session() + session.wait_for_schema_agreement.return_value = True + control_conn = Mock() + control_conn._refresh_schema.return_value = False + rf = self.make_response_future(session) + connection = Mock() + event_results = {'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED, + 'keyspace': "keyspace1", "table": "table1"} + + refresh_schema_and_set_result(control_conn, rf, connection, **event_results) + + control_conn._refresh_schema.assert_called_once_with(connection, **event_results) + session.wait_for_schema_agreement.assert_not_called() + control_conn.refresh_schema.assert_not_called() + assert not rf.is_schema_agreed + assert not rf.result() + def test_other_result_message_kind(self): session = self.make_session() rf = self.make_response_future(session)