diff --git a/cassandra/pool.py b/cassandra/pool.py index 91b990a979..35cb107ef5 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -785,13 +785,30 @@ def _replace(self, connection): self._session.submit(self._replace, connection) return + stale_endpoint = False with self._lock: if self.is_shutdown: replacement_connection.close() self._is_replacing = False return - self._connections[replacement_connection.features.shard_id] = replacement_connection + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, expected_endpoint)) + if not stale_endpoint: + self._connections[replacement_connection.features.shard_id] = replacement_connection self._is_replacing = False + if stale_endpoint: + log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + replacement_connection.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return with self._stream_available_condition: self._stream_available_condition.notify() @@ -962,11 +979,28 @@ def _open_connection_to_missing_shard(self, shard_id): ) if self._keyspace: conn.set_keyspace_blocking(self._keyspace) - self._connections[conn.features.shard_id] = conn + with self.host.lock: + stale_endpoint = not ( + _endpoints_match( + self._session.cluster, self.host.endpoint, + expected_endpoint) and + _host_is_current_for_endpoint( + self._session.cluster, self.host, + expected_endpoint)) + if not stale_endpoint: + self._connections[conn.features.shard_id] = conn if is_shutdown: conn.close() return + if stale_endpoint: + log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s", + self.host, expected_endpoint) + conn.close() + self._remove_stale_pool(expected_endpoint) + with self._stream_available_condition: + self._stream_available_condition.notify() + return if old_conn is not None: remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 29355bd679..0a7c6d07eb 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -20,7 +20,7 @@ from cassandra.shard_info import _ShardingInfo import unittest -from threading import Thread, Event, Lock +from threading import Thread, Event, Lock, Condition from unittest.mock import Mock, NonCallableMagicMock, MagicMock from cassandra.cluster import Cluster, Session, ShardAwareOptions @@ -454,3 +454,89 @@ def test_replace_retries_when_replacement_keyspace_set_fails(self): submitted_fn, submitted_connection = session.submit.call_args.args assert submitted_fn == pool._replace assert submitted_connection is initial_connection + + def test_replace_discards_replacement_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + initial_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + replacement_connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.side_effect = [ + initial_connection, replacement_connection] + + pool = HostConnection(host, HostDistance.LOCAL, session) + pool._is_replacing = True + + pool._replace(initial_connection) + + replacement_connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert not pool._is_replacing + + def test_missing_shard_discards_connection_when_endpoint_changes_during_keyspace_set(self): + old_endpoint = DefaultEndPoint('127.0.0.1') + new_endpoint = DefaultEndPoint('127.0.0.2') + host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.sharding_info = _ShardingInfo( + shard_id=0, shards_count=1, partitioner='', + sharding_algorithm='', sharding_ignore_msb=0, + shard_aware_port='', shard_aware_port_ssl='') + session = NonCallableMagicMock(spec=Session, keyspace='ks') + session.cluster = MagicMock() + session.cluster.shard_aware_options = ShardAwareOptions() + session.cluster.ssl_options = None + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.remove_pool.return_value = None + connection = HashableMock( + spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False, + orphaned_threshold_reached=False, + features=ProtocolFeatures(shard_id=0)) + connection.set_keyspace_blocking.side_effect = ( + lambda keyspace: setattr(host, 'endpoint', new_endpoint)) + session.cluster.connection_factory.return_value = connection + + pool = HostConnection.__new__(HostConnection) + pool.host = host + pool.endpoint = old_endpoint + pool.host_distance = HostDistance.LOCAL + pool.is_shutdown = False + pool._session = session + pool._lock = Lock() + pool._stream_available_condition = Condition(Lock()) + pool._connections = {} + pool._pending_connections = [] + pool._connecting = {0} + pool._excess_connections = set() + pool._trash = set() + pool._shard_connections_futures = [] + pool._keyspace = 'ks' + pool.advanced_shardaware_block_until = 0 + pool.tablets_routing_v1 = False + + pool._open_connection_to_missing_shard(0) + + connection.close.assert_called_once_with() + session.remove_pool.assert_called_once_with( + host, expected_host=host, expected_endpoint=old_endpoint, + expected_pool=pool) + assert pool._connections == {} + assert pool._connecting == set()