diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..7d5ee0f47b 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1219,15 +1219,19 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, - allow_beta_protocol_version=self.allow_beta_protocol_version) + try: + msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version) - if self._is_checksumming_enabled: - buffer = io.BytesIO() - self._segment_codec.encode(buffer, msg) - msg = buffer.getvalue() + if self._is_checksumming_enabled: + buffer = io.BytesIO() + self._segment_codec.encode(buffer, msg) + msg = buffer.getvalue() - self.push(msg) + self.push(msg) + except Exception: + self._requests.pop(request_id, None) + raise return len(msg) def wait_for_response(self, msg, timeout=None, **kwargs): @@ -1262,9 +1266,16 @@ def wait_for_responses(self, *msgs, **kwargs): self.in_flight += available for i, request_id in enumerate(request_ids): - self.send_msg(msgs[messages_sent + i], - request_id, - partial(waiter.got_response, index=messages_sent + i)) + try: + self.send_msg(msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i)) + except Exception: + unsent_request_ids = request_ids[i:] + with self.lock: + self.in_flight -= len(unsent_request_ids) + self.request_ids.extend(unsent_request_ids) + raise messages_sent += available if messages_sent == len(msgs): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2fa7c71196..5e2d57192e 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -18,14 +18,14 @@ from threading import Lock from unittest.mock import Mock, ANY, call, patch -from cassandra import OperationTimedOut +from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, ConnectionBusy, DefaultEndPoint, ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, - SupportedMessage, ProtocolHandler, ResultMessage, + SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage, RESULT_KIND_SET_KEYSPACE) from tests.util import wait_until, assertRegex @@ -363,6 +363,31 @@ def test_wait_for_responses_shutdown_includes_last_error(self): assert "already closed" in error_message assert "Bad file descriptor" in error_message + def test_wait_for_responses_releases_request_id_when_send_fails(self): + c = self.make_connection() + c._socket_writable = False + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + + with pytest.raises(ConnectionBusy): + c.wait_for_responses(Mock()) + + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests + + def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self): + c = self.make_connection() + c.push = Mock(side_effect=ConnectionException("write failed")) + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + + with pytest.raises(ConnectionException): + c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE)) + + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') class ConnectionHeartbeatTest(unittest.TestCase):