Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,19 @@ def __init__(self, connection, owner):
with connection.lock:
if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
request_id = connection.get_request_id()
try:
connection.send_msg(OptionsMessage(), request_id, self._options_callback)
except Exception as exc:
if connection.is_control_connection:
connection.in_flight -= 1
# send_msg() registers the callback before writing to the socket,
# so a write failure must unwind that registration here.
connection._requests.pop(request_id, None)
if request_id not in connection.request_ids:
Comment on lines +1823 to +1828
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I just don't understand the second commit. Why do you treat CC differently here? Why do we care about socket write error - won't it result in connection being closed anyway?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CC is special because ControlConnection.return_connection() does not decrement in_flight, while HostConnection.return_connection() does. The write failure still matters because send_msg() has already registered the callback and reserved the request id before push(), so that bookkeeping has to be unwound explicitly.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mention return_connection but I don't see where it is called :(

Copy link
Copy Markdown
Collaborator Author

@dkropachev dkropachev May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the later owner.return_connection(connection) in ConnectionHeartbeat.run:

def run(self):
self._shutdown_event.wait(self._interval)
while not self._shutdown_event.is_set():
start_time = time.time()
futures = []
failed_connections = []
try:
for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
for connection in connections:
self._raise_if_stopped()
if not (connection.is_defunct or connection.is_closed):
if connection.is_idle:
try:
futures.append(HeartbeatFuture(connection, owner))
except Exception as e:
log.warning("Failed sending heartbeat message on connection (%s) to %s",
id(connection), connection.endpoint)
failed_connections.append((connection, owner, e))
else:
connection.reset_idle()
else:
log.debug("Cannot send heartbeat message on connection (%s) to %s",
id(connection), connection.endpoint)
# make sure the owner sees this defunt/closed connection
owner.return_connection(connection)
self._raise_if_stopped()
# Wait max `self._timeout` seconds for all HeartbeatFutures to complete
timeout = self._timeout
start_time = time.time()
for f in futures:
self._raise_if_stopped()
connection = f.connection
try:
f.wait(timeout)
# TODO: move this, along with connection locks in pool, down into Connection
with connection.lock:
connection.in_flight -= 1
connection.reset_idle()
except Exception as e:
log.warning("Heartbeat failed for connection (%s) to %s",
id(connection), connection.endpoint)
failed_connections.append((f.connection, f.owner, e))
timeout = self._timeout - (time.time() - start_time)
for connection, owner, exc in failed_connections:
self._raise_if_stopped()
if not connection.is_control_connection:
# Only HostConnection supports shutdown_on_error
owner.shutdown_on_error = True
connection.defunct(exc)
owner.return_connection(connection)

for connection, owner, exc in failed_connections:
self._raise_if_stopped()
if not connection.is_control_connection:
# Only HostConnection supports shutdown_on_error
owner.shutdown_on_error = True
connection.defunct(exc)
owner.return_connection(connection)

That block only unwinds the callback/request-id registration that send_msg() already did:

def send_msg(self, msg, request_id, cb,
             encoder=ProtocolHandler.encode_message,
             decoder=ProtocolHandler.decode_message,
             result_metadata=None):
    """
    Register a response callback for `request_id`, encode `msg` and push
    the resulting bytes onto the connection.

    Raises `ConnectionShutdown` when the connection is defunct or closed
    (the last recorded error, if any, is appended to the message) and
    `ConnectionBusy` when the socket is not currently writable.

    Returns the number of bytes handed to `push()`.
    """
    # Choose the shutdown message first; the defunct state takes precedence
    # over the closed state.
    shutdown_template = None
    if self.is_defunct:
        shutdown_template = "Connection to %s is defunct"
    elif self.is_closed:
        shutdown_template = "Connection to %s is closed"

    if shutdown_template is not None:
        reason = shutdown_template % self.endpoint
        if self.last_error:
            reason += ": %s" % (self.last_error,)
        raise ConnectionShutdown(reason)

    if not self._socket_writable:
        raise ConnectionBusy("Connection %s is overloaded" % self.endpoint)

    # The decoder is stored next to the callback so that every request can
    # carry its own encode/decode functions.
    # NOTE: this registration happens before encoding and before push(), so
    # a failure in either leaves the entry in `_requests` for the caller to
    # unwind.
    self._requests[request_id] = (cb, decoder, result_metadata)

    payload = encoder(
        msg, request_id, self.protocol_version,
        compressor=self.compressor,
        allow_beta_protocol_version=self.allow_beta_protocol_version)

    if self._is_checksumming_enabled:
        # Wrap the encoded frame into a segment before it is written out.
        segment_buffer = io.BytesIO()
        self._segment_codec.encode(segment_buffer, payload)
        payload = segment_buffer.getvalue()

    self.push(payload)
    return len(payload)

For control connections, ControlConnection.return_connection() does not decrement in_flight

def return_connection(self, connection):
    """
    Start a reconnect when the connection handed back is the one this
    owner currently holds and it is defunct or closed.

    Any other connection is ignored. Nothing here touches `in_flight`.
    """
    if connection is not self._connection:
        return
    if connection.is_defunct or connection.is_closed:
        self.reconnect()

while HostConnection.return_connection() does, so the direct decrement has to stay here. It can’t be handled in ControlConnection.return_connection() because that method only sees a defunct/closed connection at the end of the heartbeat cycle

def return_connection(self, connection, stream_was_orphaned=False):
    """
    Give back the in-flight slot held on `connection` and wake one waiter
    on the stream-available condition.

    Both steps are skipped when `stream_was_orphaned` is true.
    """
    # NOTE(review): the quoted snippet lost its indentation; both `with`
    # blocks are assumed to sit under the `if` - confirm against the file.
    if not stream_was_orphaned:
        # The counter is only changed while holding the connection lock.
        with connection.lock:
            connection.in_flight = connection.in_flight - 1
        stream_available = self._stream_available_condition
        with stream_available:
            stream_available.notify()

It does not know which request_id was reserved, and it does not have the context that send_msg() already registered the callback in _requests. The leak happens in the send_msg() failure path, while we still have the request_id and can unwind _requests / request_ids immediately.

connection.request_ids.append(request_id)
self._exception = exc
self._event.set()
else:
Comment on lines +1819 to 1832
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment for the first commit. Assumption I have in the comment: if self.push(msg) fails then connection is broken.

There is one more edge case: if we fail after self._requests[request_id] = (cb, decoder, result_metadata), but before self.push(msg), then _requests will have the new request, but it won't be sent, effectively making it orphaned, without being accounted as such.
Can we fix that? It would require having a cleanup path for request_id in connection._requests, which I'm afraid is a bit too dangerous.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, fixed

self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
self._event.set()
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cassandra import OperationTimedOut
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
locally_supported_compressions, ConnectionHeartbeat, HeartbeatFuture, _Frame, Timer, TimerManager,
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
Expand Down Expand Up @@ -463,6 +463,31 @@ def test_no_req_ids(self, *args):
holder.return_connection.assert_has_calls(
[call(max_connection)] * get_holders.call_count)

def test_heartbeat_future_releases_request_id_when_send_fails(self, *args):
    """
    A heartbeat whose socket write fails must leave the connection's
    bookkeeping untouched: same in-flight count, same number of free
    request ids, and no callback left registered.
    """
    conn = Connection(DefaultEndPoint('1.2.3.4'))
    # Make the write itself fail; send_msg() has registered the callback
    # by the time push() is reached.
    conn.push = Mock(side_effect=ConnectionException("write failed"))
    in_flight_before = conn.in_flight
    free_request_ids_before = len(conn.request_ids)

    # Stand-in for HostConnection.return_connection, which releases the
    # heartbeat's in-flight slot.
    def release_in_flight_slot(returned):
        with returned.lock:
            returned.in_flight -= 1

    holder = Mock()
    holder.return_connection.side_effect = release_in_flight_slot

    heartbeat = HeartbeatFuture(conn, holder)

    # The send failure is stored on the future and surfaces from wait().
    with pytest.raises(ConnectionException):
        heartbeat.wait(0)

    holder.return_connection(conn)

    assert conn.in_flight == in_flight_before
    assert len(conn.request_ids) == free_request_ids_before
    assert not conn._requests

def test_unexpected_response(self, *args):
request_id = 999

Expand Down
Loading