Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,15 +1219,19 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder, result_metadata)
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
allow_beta_protocol_version=self.allow_beta_protocol_version)
try:
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
allow_beta_protocol_version=self.allow_beta_protocol_version)

if self._is_checksumming_enabled:
buffer = io.BytesIO()
self._segment_codec.encode(buffer, msg)
msg = buffer.getvalue()
if self._is_checksumming_enabled:
buffer = io.BytesIO()
self._segment_codec.encode(buffer, msg)
msg = buffer.getvalue()

self.push(msg)
self.push(msg)
except Exception:
self._requests.pop(request_id, None)
raise
return len(msg)

def wait_for_response(self, msg, timeout=None, **kwargs):
Expand Down Expand Up @@ -1262,9 +1266,16 @@ def wait_for_responses(self, *msgs, **kwargs):
self.in_flight += available

for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
try:
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
except Exception:
unsent_request_ids = request_ids[i:]
with self.lock:
self.in_flight -= len(unsent_request_ids)
self.request_ids.extend(unsent_request_ids)
raise
messages_sent += available
Comment on lines 1222 to 1279
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are now multiple PRs regarding request_ids, in_flight etc.
It is incredible that we need to ever worry about this stuff.
Why is it even responsibility of the caller to adjust those values?
Connection should have a method for sending request. This method should be responsible for managing in_flight, request_ids and other state of Connection. Callers should never worry about that.

This is the only sane solution, and anything else will just require fixing callsites forever.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And yes I know this is a code in connection. But you also have PRs for e.g. hearbeats. Heartbeats should never need to touch this stuff.

Copy link
Copy Markdown
Collaborator Author

@dkropachev dkropachev May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. I removed the async keyspace cleanup from this branch, so this PR is now scoped to the concrete send-failure leak only. The broader connection-level helper/refactor can stay as a separate follow-up.


if messages_sent == len(msgs):
Expand Down Expand Up @@ -1734,7 +1745,13 @@ def process_result(result):
# acquire a new request id
request_id = self.get_request_id()

self.send_msg(query, request_id, process_result)
try:
self.send_msg(query, request_id, process_result)
except Exception as exc:
with self.lock:
if request_id not in self._requests and request_id not in self.request_ids:
self.request_ids.append(request_id)
callback(self, exc)

@property
def is_idle(self):
Expand Down
52 changes: 49 additions & 3 deletions tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from threading import Lock
from unittest.mock import Mock, ANY, call, patch

from cassandra import OperationTimedOut
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
ConnectionException, ConnectionShutdown, ConnectionBusy, DefaultEndPoint, ShardAwarePortGenerator)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ProtocolHandler, ResultMessage,
SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage,
RESULT_KIND_SET_KEYSPACE)

from tests.util import wait_until, assertRegex
Expand Down Expand Up @@ -363,6 +363,52 @@ def test_wait_for_responses_shutdown_includes_last_error(self):
assert "already closed" in error_message
assert "Bad file descriptor" in error_message

def test_wait_for_responses_releases_request_id_when_send_fails(self):
c = self.make_connection()
c._socket_writable = False
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)

with pytest.raises(ConnectionBusy):
c.wait_for_responses(Mock())

assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests

def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self):
c = self.make_connection()
c.push = Mock(side_effect=ConnectionException("write failed"))
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)

with pytest.raises(ConnectionException):
c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE))

assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests

def test_set_keyspace_async_reports_send_failure_and_releases_request_id(self):
c = self.make_connection()
c.push = Mock(side_effect=ConnectionException("write failed"))
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)
callback_errors = []

def callback(conn, error):
callback_errors.append(error)
with conn.lock:
conn.in_flight -= 1

c.set_keyspace_async("ks", callback)

assert len(callback_errors) == 1
assert isinstance(callback_errors[0], ConnectionException)
assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests


@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
class ConnectionHeartbeatTest(unittest.TestCase):
Expand Down
Loading