diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..f07160e385 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1816,7 +1816,19 @@ def __init__(self, connection, owner): with connection.lock: if connection.in_flight < connection.max_request_id: connection.in_flight += 1 - connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + request_id = connection.get_request_id() + try: + connection.send_msg(OptionsMessage(), request_id, self._options_callback) + except Exception as exc: + if connection.is_control_connection: + connection.in_flight -= 1 + # send_msg() registers the callback before writing to the socket, + # so a write failure must unwind that registration here. + connection._requests.pop(request_id, None) + if request_id not in connection.request_ids: + connection.request_ids.append(request_id) + self._exception = exc + self._event.set() else: self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") self._event.set() diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2fa7c71196..cf4607fbed 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -21,7 +21,7 @@ from cassandra import OperationTimedOut from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, - locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + locally_supported_compressions, ConnectionHeartbeat, HeartbeatFuture, _Frame, Timer, TimerManager, ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, @@ -463,6 +463,31 @@ def test_no_req_ids(self, *args): holder.return_connection.assert_has_calls( [call(max_connection)] * get_holders.call_count) + def test_heartbeat_future_releases_request_id_when_send_fails(self, *args): + connection = Connection(DefaultEndPoint('1.2.3.4')) + connection.push = Mock(side_effect=ConnectionException("write failed")) + owner = Mock() + initial_in_flight = connection.in_flight + initial_request_ids = len(connection.request_ids) + + # HostConnection.return_connection releases the heartbeat's in-flight slot. + def return_connection(conn): + with conn.lock: + conn.in_flight -= 1 + + owner.return_connection.side_effect = return_connection + + future = HeartbeatFuture(connection, owner) + + with pytest.raises(ConnectionException): + future.wait(0) + + owner.return_connection(connection) + + assert connection.in_flight == initial_in_flight + assert len(connection.request_ids) == initial_request_ids + assert not connection._requests + def test_unexpected_response(self, *args): request_id = 999