Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,13 +785,30 @@ def _replace(self, connection):
self._session.submit(self._replace, connection)
return

stale_endpoint = False
with self._lock:
if self.is_shutdown:
replacement_connection.close()
self._is_replacing = False
return
self._connections[replacement_connection.features.shard_id] = replacement_connection
with self.host.lock:
stale_endpoint = not (
_endpoints_match(
self._session.cluster, self.host.endpoint,
expected_endpoint) and
_host_is_current_for_endpoint(
self._session.cluster, self.host, expected_endpoint))
if not stale_endpoint:
self._connections[replacement_connection.features.shard_id] = replacement_connection
self._is_replacing = False
if stale_endpoint:
log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s",
self.host, expected_endpoint)
replacement_connection.close()
self._remove_stale_pool(expected_endpoint)
with self._stream_available_condition:
self._stream_available_condition.notify()
return
with self._stream_available_condition:
self._stream_available_condition.notify()

Expand Down Expand Up @@ -962,11 +979,28 @@ def _open_connection_to_missing_shard(self, shard_id):
)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connections[conn.features.shard_id] = conn
with self.host.lock:
stale_endpoint = not (
_endpoints_match(
self._session.cluster, self.host.endpoint,
expected_endpoint) and
_host_is_current_for_endpoint(
self._session.cluster, self.host,
expected_endpoint))
if not stale_endpoint:
self._connections[conn.features.shard_id] = conn

if is_shutdown:
conn.close()
return
if stale_endpoint:
log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s",
self.host, expected_endpoint)
conn.close()
self._remove_stale_pool(expected_endpoint)
with self._stream_available_condition:
self._stream_available_condition.notify()
return

if old_conn is not None:
remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids)
Expand Down
88 changes: 87 additions & 1 deletion tests/unit/test_host_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cassandra.shard_info import _ShardingInfo

import unittest
from threading import Thread, Event, Lock
from threading import Thread, Event, Lock, Condition
from unittest.mock import Mock, NonCallableMagicMock, MagicMock

from cassandra.cluster import Cluster, Session, ShardAwareOptions
Expand Down Expand Up @@ -454,3 +454,89 @@ def test_replace_retries_when_replacement_keyspace_set_fails(self):
submitted_fn, submitted_connection = session.submit.call_args.args
assert submitted_fn == pool._replace
assert submitted_connection is initial_connection

def test_replace_discards_replacement_when_endpoint_changes_during_keyspace_set(self):
old_endpoint = DefaultEndPoint('127.0.0.1')
new_endpoint = DefaultEndPoint('127.0.0.2')
host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4())
session = NonCallableMagicMock(spec=Session, keyspace='ks')
session.cluster = MagicMock()
session.cluster.shard_aware_options = ShardAwareOptions()
session.cluster._endpoints_match.side_effect = Cluster._endpoints_match
session.remove_pool.return_value = None
initial_connection = HashableMock(
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
max_request_id=100, signaled_error=False,
orphaned_threshold_reached=False,
features=ProtocolFeatures(shard_id=0))
replacement_connection = HashableMock(
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
max_request_id=100, signaled_error=False,
orphaned_threshold_reached=False,
features=ProtocolFeatures(shard_id=0))
replacement_connection.set_keyspace_blocking.side_effect = (
lambda keyspace: setattr(host, 'endpoint', new_endpoint))
session.cluster.connection_factory.side_effect = [
initial_connection, replacement_connection]

pool = HostConnection(host, HostDistance.LOCAL, session)
pool._is_replacing = True

pool._replace(initial_connection)

replacement_connection.close.assert_called_once_with()
session.remove_pool.assert_called_once_with(
host, expected_host=host, expected_endpoint=old_endpoint,
expected_pool=pool)
assert pool._connections == {}
assert not pool._is_replacing

def test_missing_shard_discards_connection_when_endpoint_changes_during_keyspace_set(self):
old_endpoint = DefaultEndPoint('127.0.0.1')
new_endpoint = DefaultEndPoint('127.0.0.2')
host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4())
host.sharding_info = _ShardingInfo(
shard_id=0, shards_count=1, partitioner='',
sharding_algorithm='', sharding_ignore_msb=0,
shard_aware_port='', shard_aware_port_ssl='')
session = NonCallableMagicMock(spec=Session, keyspace='ks')
session.cluster = MagicMock()
session.cluster.shard_aware_options = ShardAwareOptions()
session.cluster.ssl_options = None
session.cluster._endpoints_match.side_effect = Cluster._endpoints_match
session.remove_pool.return_value = None
connection = HashableMock(
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
max_request_id=100, signaled_error=False,
orphaned_threshold_reached=False,
features=ProtocolFeatures(shard_id=0))
connection.set_keyspace_blocking.side_effect = (
lambda keyspace: setattr(host, 'endpoint', new_endpoint))
session.cluster.connection_factory.return_value = connection

pool = HostConnection.__new__(HostConnection)
pool.host = host
pool.endpoint = old_endpoint
pool.host_distance = HostDistance.LOCAL
pool.is_shutdown = False
pool._session = session
pool._lock = Lock()
pool._stream_available_condition = Condition(Lock())
pool._connections = {}
pool._pending_connections = []
pool._connecting = {0}
pool._excess_connections = set()
pool._trash = set()
pool._shard_connections_futures = []
pool._keyspace = 'ks'
pool.advanced_shardaware_block_until = 0
pool.tablets_routing_v1 = False

pool._open_connection_to_missing_shard(0)

connection.close.assert_called_once_with()
session.remove_pool.assert_called_once_with(
host, expected_host=host, expected_endpoint=old_endpoint,
expected_pool=pool)
assert pool._connections == {}
assert pool._connecting == set()
Loading