Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 282 additions & 13 deletions cassandra/cluster.py

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,9 +1262,16 @@ def wait_for_responses(self, *msgs, **kwargs):
self.in_flight += available

for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
try:
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
except (ConnectionBusy, ConnectionShutdown):
unsent_request_ids = request_ids[i:]
with self.lock:
self.in_flight -= len(unsent_request_ids)
self.request_ids.extend(unsent_request_ids)
raise
messages_sent += available

if messages_sent == len(msgs):
Expand Down
2 changes: 2 additions & 0 deletions docs/api/cassandra/cluster.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ Clusters and Sessions

.. automethod:: set_keyspace(keyspace)

.. automethod:: wait_for_schema_agreement

.. automethod:: get_execution_profile

.. automethod:: execution_profile_clone_update
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/long/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,4 @@ def check_and_wait_for_agreement(self, session, rs, exepected):
time.sleep(1)
assert rs.response_future.is_schema_agreed == exepected
if not rs.response_future.is_schema_agreed:
session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000)
session.wait_for_schema_agreement(wait_time=1000)
2 changes: 1 addition & 1 deletion tests/integration/standard/test_udts.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_can_register_udt_before_connecting(self):
c.register_user_type("udt_test_register_before_connecting2", "user", User2)

s = c.connect(wait_for_all_pools=True)
c.control_connection.wait_for_schema_agreement()
s.wait_for_schema_agreement()

s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob')))
result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0")
Expand Down
153 changes: 153 additions & 0 deletions tests/unit/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from concurrent.futures import Future

import logging
import socket
Expand All @@ -23,6 +24,7 @@
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
from cassandra.connection import ConnectionBusy
from cassandra.pool import Host
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
Expand Down Expand Up @@ -247,11 +249,64 @@ def test_event_delay_timing(self, *_):


class SessionTest(unittest.TestCase):
class FakeTime(object):

def __init__(self):
self.clock = 0

def time(self):
return self.clock

def sleep(self, amount):
self.clock += amount

class MockPool(object):

def __init__(self, host, connection):
self.host = host
self.host_distance = HostDistance.LOCAL
self.is_shutdown = False
self.connection = connection

def _get_connection_for_routing_key(self):
return self.connection

def setUp(self):
if connection_class is None:
raise unittest.SkipTest('libev does not appear to be installed correctly')
connection_class.initialize_reactor()

def _mock_schema_response(self, schema_version):
response = Mock()
response.column_names = ["schema_version"]
response.parsed_rows = [[schema_version]]
return response

def _new_schema_agreement_session(self, schema_versions, distances=None):
hosts = []
distance_map = {}
if distances is None:
distances = [HostDistance.LOCAL] * len(schema_versions)
for index, schema_version in enumerate(schema_versions):
host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4())
host.set_up()
hosts.append(host)
distance_map[host] = distances[index]

cluster = Cluster(protocol_version=4)
for host in hosts:
cluster.metadata.add_or_return_host(host)

session = Session(cluster, hosts)
session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL))
session._pools = {}
for host, schema_version in zip(hosts, schema_versions):
connection = Mock(endpoint=host.endpoint)
connection.wait_for_response.return_value = self._mock_schema_response(schema_version)
session._pools[host] = self.MockPool(host, connection)

return session, hosts

# TODO: this suite could be expanded; for now just adding a test covering a PR
@mock_session_pools
def test_default_serial_consistency_level_ep(self, *_):
Expand Down Expand Up @@ -339,6 +394,104 @@ def test_set_keyspace_escapes_quotes(self, *_):
assert query == 'USE simple_ks', (
"Simple keyspace names should not be quoted, got: %r" % query)

@mock_session_pools
def test_wait_for_schema_agreement_queries_all_local_hosts(self, *_):
session, hosts = self._new_schema_agreement_session(["a", "a"])

assert session.wait_for_schema_agreement(wait_time=1)

for host in hosts:
connection = session._pools[host].connection
connection.wait_for_response.assert_called_once()

@mock_session_pools
def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_):
session, hosts = self._new_schema_agreement_session(["a", "b"])
session._time = self.FakeTime()
second_connection = session._pools[hosts[1]].connection
second_connection.wait_for_response.side_effect = [
self._mock_schema_response("b"),
self._mock_schema_response("a")]

assert session.wait_for_schema_agreement(wait_time=1)
assert second_connection.wait_for_response.call_count == 2
assert session._time.clock == 0.2

@mock_session_pools
def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_):
session, hosts = self._new_schema_agreement_session(["a", "a"])
session._time = self.FakeTime()
busy_connection = session._pools[hosts[1]].connection
busy_connection.wait_for_response.side_effect = [
ConnectionBusy("connection overloaded"),
self._mock_schema_response("a")]

assert session.wait_for_schema_agreement(wait_time=1)
assert busy_connection.wait_for_response.call_count == 2
assert session._time.clock == 0.2

@mock_session_pools
def test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_):
session, hosts = self._new_schema_agreement_session(["a"])
session._time = self.FakeTime()

unconnected_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4())
unconnected_host.set_up()
session.cluster.metadata.add_or_return_host(unconnected_host)

assert session.wait_for_schema_agreement(wait_time=1)
session._pools[hosts[0]].connection.wait_for_response.assert_called()

@mock_session_pools
@patch('cassandra.cluster.wait_futures')
def test_wait_for_schema_agreement_limits_parallel_queries_to_default(self, mocked_wait_futures, *_):
session, _ = self._new_schema_agreement_session(["a"] * 11)
batch_sizes = []

def submit(fn, host, query, deadline):
future = Future()
future.set_result(fn(host, query, deadline))
return future

def fake_wait(futures, timeout=None, return_when=None):
batch_sizes.append(len(futures))
return set(futures), set()

session.submit = Mock(side_effect=submit)
mocked_wait_futures.side_effect = fake_wait

assert session.wait_for_schema_agreement(wait_time=1)
assert batch_sizes == [10, 1]

@mock_session_pools
def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_):
session, hosts = self._new_schema_agreement_session(
["a", "a", "a"],
distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE])

assert session.wait_for_schema_agreement(wait_time=1, scope='rack')

session._pools[hosts[0]].connection.wait_for_response.assert_called_once()
session._pools[hosts[1]].connection.wait_for_response.assert_not_called()
session._pools[hosts[2]].connection.wait_for_response.assert_not_called()

@mock_session_pools
def test_wait_for_schema_agreement_cluster_scope_queries_all_connected_hosts(self, *_):
session, hosts = self._new_schema_agreement_session(
["a", "a", "a"],
distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE])

assert session.wait_for_schema_agreement(wait_time=1, scope='cluster')

for host in hosts:
session._pools[host].connection.wait_for_response.assert_called_once()

@mock_session_pools
def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_):
session, _ = self._new_schema_agreement_session(["a"])

with pytest.raises(ValueError):
session.wait_for_schema_agreement(wait_time=1, scope='planet')
class ProtocolVersionTests(unittest.TestCase):

def test_protocol_downgrade_test(self):
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
ConnectionBusy, ConnectionException, ConnectionShutdown, DefaultEndPoint,
ShardAwarePortGenerator)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ProtocolHandler, ResultMessage,
Expand Down Expand Up @@ -363,6 +364,19 @@ def test_wait_for_responses_shutdown_includes_last_error(self):
assert "already closed" in error_message
assert "Bad file descriptor" in error_message

def test_wait_for_responses_releases_request_id_when_send_fails(self):
c = self.make_connection()
c._socket_writable = False
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)

with pytest.raises(ConnectionBusy):
c.wait_for_responses(Mock())

assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests


@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
class ConnectionHeartbeatTest(unittest.TestCase):
Expand Down
Loading
Loading