Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,7 +2020,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
Intended for internal use only.
"""
if self.is_shutdown:
return
return False

with host.lock:
was_up = host.is_up
Expand All @@ -2035,14 +2035,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
if pool_state:
connected |= pool_state['open_count'] > 0
if connected:
return
return False

host.set_down()
if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting():
return
return False
log.warning("Host %s has been marked down", host)

self.on_down_potentially_blocking(host, is_host_addition)
return True

def on_add(self, host, refresh_nodes=True):
if self.is_shutdown:
Expand Down Expand Up @@ -2134,8 +2135,8 @@ def on_remove(self, host):
def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False):
is_down = host.signal_connection_failure(connection_exc)
if is_down:
self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down
return self.on_down(host, is_host_addition, expect_host_to_be_down)
return False

def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None):
"""
Expand Down Expand Up @@ -4226,9 +4227,10 @@ def _signal_error(self):
# host may be None if it's already been removed, but that indicates
# that errors have already been reported, so we're fine
if host:
self._cluster.signal_connection_failure(
is_down = self._cluster.signal_connection_failure(
host, self._connection.last_error, is_host_addition=False)
return
if is_down:
return

# if the connection is not defunct or the host already left, reconnect
# manually
Expand Down
29 changes: 27 additions & 2 deletions tests/unit/test_control_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.

import unittest
import uuid

from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, ANY, call

from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.pool import Host
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionException
from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
ConstantReconnectionPolicy, IdentityTranslator)

Expand Down Expand Up @@ -301,6 +302,30 @@ def test_wait_for_schema_agreement_none_timeout(self):
cc._time = self.time
assert cc.wait_for_schema_agreement()

def test_signal_error_reconnects_when_host_down_signal_is_discounted(self):
cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
self.addCleanup(cluster.shutdown)

host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
host.set_up()
cluster.metadata.add_or_return_host(host)

session = Mock()
session.get_pool_state.return_value = {host: {"open_count": 1}}
cluster.sessions.add(session)

connection_error = ConnectionException("control connection failed", endpoint=host.endpoint)
cluster.control_connection._connection = Mock(
endpoint=host.endpoint,
is_defunct=True,
last_error=connection_error)
cluster.control_connection.reconnect = Mock()

cluster.control_connection._signal_error()

assert host.is_up is True
cluster.control_connection.reconnect.assert_called_once_with()

def test_refresh_nodes_and_tokens(self):
self.control_connection.refresh_node_list_and_token_map()
meta = self.cluster.metadata
Expand Down
Loading