diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..e3e01cf6b8 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4618,6 +4618,7 @@ def _query(self, host, message=None, cb=None): self._current_host = host connection = None + request_id = None try: # TODO get connectTimeout from cluster settings if self.query: @@ -4642,16 +4643,26 @@ def _query(self, host, message=None, cb=None): except ConnectionBusy as exc: log.debug("Connection for host %s is busy, moving to the next host", host) self._errors[host] = exc + if connection: + self._return_connection_after_send_failure(pool, connection, request_id) except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc if self._metrics is not None: self._metrics.on_connection_error() if connection: - pool.return_connection(connection) + self._return_connection_after_send_failure(pool, connection, request_id) return None + def _return_connection_after_send_failure(self, pool, connection, request_id): + if request_id is not None: + with connection.lock: + connection._requests.pop(request_id, None) + if request_id not in connection.request_ids: + connection.request_ids.append(request_id) + pool.return_connection(connection) + @property def has_more_pages(self): """ diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..7ca1d39838 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -97,6 +97,34 @@ def test_unknown_result_class(self): with pytest.raises(ConnectionException): rf.result() + def test_query_releases_request_id_when_send_fails_after_registration(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Connection('1.2.3.4') + connection.push = Mock(side_effect=ConnectionException("write failed")) + + initial_request_ids = len(connection.request_ids) + request_id = connection.request_ids.popleft() + connection.in_flight += 1 + pool.borrow_connection.return_value = (connection, request_id) + + def return_connection(conn): + with conn.lock: + conn.in_flight -= 1 + + pool.return_connection.side_effect = return_connection + + query = SimpleStatement("SELECT * FROM foo") + message = QueryMessage(query=query.query_string, consistency_level=ConsistencyLevel.ONE) + rf = ResponseFuture(session, message, query, 1) + + assert rf._query('ip1') is None + pool.return_connection.assert_called_once_with(connection) + assert connection.in_flight == 0 + assert len(connection.request_ids) == initial_request_ids + assert request_id in connection.request_ids + assert not connection._requests + def test_set_keyspace_result(self): session = self.make_session() rf = self.make_response_future(session)