Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4618,6 +4618,7 @@ def _query(self, host, message=None, cb=None):
self._current_host = host

connection = None
request_id = None
try:
# TODO get connectTimeout from cluster settings
if self.query:
Expand All @@ -4642,16 +4643,26 @@ def _query(self, host, message=None, cb=None):
except ConnectionBusy as exc:
log.debug("Connection for host %s is busy, moving to the next host", host)
self._errors[host] = exc
if connection:
self._return_connection_after_send_failure(pool, connection, request_id)
except Exception as exc:
log.debug("Error querying host %s", host, exc_info=True)
self._errors[host] = exc
if self._metrics is not None:
self._metrics.on_connection_error()
if connection:
pool.return_connection(connection)
self._return_connection_after_send_failure(pool, connection, request_id)

return None

def _return_connection_after_send_failure(self, pool, connection, request_id):
if request_id is not None:
with connection.lock:
connection._requests.pop(request_id, None)
if request_id not in connection.request_ids:
connection.request_ids.append(request_id)
pool.return_connection(connection)

@property
def has_more_pages(self):
"""
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_response_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ def test_unknown_result_class(self):
with pytest.raises(ConnectionException):
rf.result()

def test_query_releases_request_id_when_send_fails_after_registration(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Connection('1.2.3.4')
connection.push = Mock(side_effect=ConnectionException("write failed"))

initial_request_ids = len(connection.request_ids)
request_id = connection.request_ids.popleft()
connection.in_flight += 1
pool.borrow_connection.return_value = (connection, request_id)

def return_connection(conn):
with conn.lock:
conn.in_flight -= 1

pool.return_connection.side_effect = return_connection

query = SimpleStatement("SELECT * FROM foo")
message = QueryMessage(query=query.query_string, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query, 1)

assert rf._query('ip1') is None
pool.return_connection.assert_called_once_with(connection)
assert connection.in_flight == 0
assert len(connection.request_ids) == initial_request_ids
assert request_id in connection.request_ids
assert not connection._requests

def test_set_keyspace_result(self):
session = self.make_session()
rf = self.make_response_future(session)
Expand Down
Loading