Skip to content

Commit ab38a22

Browse files
committed
cluster: use dynamic whitelist policy for proxy access
1 parent 3aa5935 commit ab38a22

11 files changed

Lines changed: 925 additions & 73 deletions

File tree

cassandra/cluster.py

Lines changed: 221 additions & 45 deletions
Large diffs are not rendered by default.

cassandra/datastax/insights/serializers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def initialize_registry(insights_registry):
3737
DCAwareRoundRobinPolicy,
3838
TokenAwarePolicy,
3939
WhiteListRoundRobinPolicy,
40+
DynamicWhiteListRoundRobinPolicy,
4041
HostFilterPolicy,
4142
ConstantReconnectionPolicy,
4243
ExponentialReconnectionPolicy,
@@ -80,6 +81,13 @@ def whitelist_round_robin_policy_insights_serializer(policy):
8081
'options': {'allowed_hosts': policy._allowed_hosts}
8182
}
8283

84+
@insights_registry.register_serializer_for(DynamicWhiteListRoundRobinPolicy)
85+
def dynamic_whitelist_round_robin_policy_insights_serializer(policy):
86+
return {'type': policy.__class__.__name__,
87+
'namespace': namespace(policy.__class__),
88+
'options': {'allowed_host_ids': tuple(str(host_id) for host_id in policy._allowed_host_ids)}
89+
}
90+
8391
@insights_registry.register_serializer_for(HostFilterPolicy)
8492
def host_filter_policy_insights_serializer(policy):
8593
return {

cassandra/policies.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,16 @@ def check_supported(self):
166166
"""
167167
pass
168168

169+
def on_control_connection_host(self, host):
170+
"""
171+
Called when the control connection resolves the metadata host behind
172+
the endpoint it is currently using.
173+
174+
Policies that maintain a dynamic host allowlist can override this to
175+
update their internal view of the cluster.
176+
"""
177+
pass
178+
169179

170180
class RoundRobinPolicy(LoadBalancingPolicy):
171181
"""
@@ -540,6 +550,9 @@ def on_add(self, *args, **kwargs):
540550
def on_remove(self, *args, **kwargs):
541551
return self._child_policy.on_remove(*args, **kwargs)
542552

553+
def on_control_connection_host(self, host):
554+
return self._child_policy.on_control_connection_host(host)
555+
543556

544557
class WhiteListRoundRobinPolicy(RoundRobinPolicy):
545558
"""
@@ -594,6 +607,58 @@ def on_add(self, host):
594607
RoundRobinPolicy.on_add(self, host)
595608

596609

610+
class DynamicWhiteListRoundRobinPolicy(RoundRobinPolicy):
611+
"""
612+
A :class:`.RoundRobinPolicy` variant whose allowlist is updated from the
613+
control connection.
614+
615+
This is intended for proxy deployments where the driver can only reach the
616+
host currently behind the control connection endpoint. The policy keeps
617+
every other discovered node at :attr:`~.HostDistance.IGNORED` until the
618+
control connection resolves a different host.
619+
"""
620+
621+
def __init__(self):
622+
self._allowed_host_ids = frozenset(())
623+
self._cluster = None
624+
RoundRobinPolicy.__init__(self)
625+
626+
def _host_is_allowed(self, host):
627+
return getattr(host, "host_id", None) in self._allowed_host_ids
628+
629+
def _refresh_live_hosts(self, hosts):
630+
self._live_hosts = frozenset(
631+
host for host in hosts
632+
if self._host_is_allowed(host) and host.is_up is not False
633+
)
634+
635+
def populate(self, cluster, hosts):
636+
self._cluster = cluster
637+
self._refresh_live_hosts(hosts)
638+
if len(self._live_hosts) > 1:
639+
self._position = randint(0, len(self._live_hosts) - 1)
640+
else:
641+
self._position = 0
642+
643+
def distance(self, host):
644+
return HostDistance.LOCAL if self._host_is_allowed(host) else HostDistance.IGNORED
645+
646+
def on_up(self, host):
647+
if self._host_is_allowed(host):
648+
RoundRobinPolicy.on_up(self, host)
649+
650+
def on_add(self, host):
651+
if self._host_is_allowed(host):
652+
RoundRobinPolicy.on_add(self, host)
653+
654+
def on_control_connection_host(self, host):
655+
with self._hosts_lock:
656+
allowed_host_id = getattr(host, "host_id", None)
657+
self._allowed_host_ids = frozenset((allowed_host_id,)) if allowed_host_id is not None else frozenset(())
658+
if self._cluster is not None:
659+
self._refresh_live_hosts(self._cluster.metadata.all_hosts())
660+
661+
597662
class HostFilterPolicy(LoadBalancingPolicy):
598663
"""
599664
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
@@ -654,6 +719,9 @@ def on_add(self, host, *args, **kwargs):
654719
def on_remove(self, host, *args, **kwargs):
655720
return self._child_policy.on_remove(host, *args, **kwargs)
656721

722+
def on_control_connection_host(self, host):
723+
return self._child_policy.on_control_connection_host(host)
724+
657725
@property
658726
def predicate(self):
659727
"""
@@ -1322,6 +1390,9 @@ def on_add(self, *args, **kwargs):
13221390
def on_remove(self, *args, **kwargs):
13231391
return self._child_policy.on_remove(*args, **kwargs)
13241392

1393+
def on_control_connection_host(self, host):
1394+
return self._child_policy.on_control_connection_host(host)
1395+
13251396

13261397
class DefaultLoadBalancingPolicy(WrapperPolicy):
13271398
"""

cassandra/pool.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
435435

436436
if self._keyspace:
437437
first_connection.set_keyspace_blocking(self._keyspace)
438-
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
438+
if first_connection.features.sharding_info and not self._session.is_shard_aware_disabled():
439439
self.host.sharding_info = first_connection.features.sharding_info
440440
self._open_connections_for_all_shards(first_connection.features.shard_id)
441441
self.tablets_routing_v1 = first_connection.features.tablets_routing_v1
@@ -451,7 +451,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
451451
raise NoConnectionsAvailable()
452452

453453
shard_id = None
454-
if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key:
454+
if not self._session.is_shard_aware_disabled() and self.host.sharding_info and routing_key:
455455
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
456456

457457
shard_id = None
@@ -554,15 +554,15 @@ def return_connection(self, connection, stream_was_orphaned=False):
554554
if not connection.signaled_error:
555555
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
556556
"marking host %s as down", id(connection), self.host)
557-
is_down = self.host.signal_connection_failure(connection.last_error)
557+
is_down = self._session._signal_connection_failure(self.host, connection.last_error)
558558
connection.signaled_error = True
559559

560560
if self.shutdown_on_error and not is_down:
561561
is_down = True
562562

563563
if is_down:
564564
self.shutdown()
565-
self._session.cluster.on_down(self.host, is_host_addition=False)
565+
self._session._handle_pool_down(self.host, is_host_addition=False)
566566
else:
567567
connection.close()
568568
with self._lock:
@@ -603,7 +603,7 @@ def _replace(self, connection):
603603
try:
604604
if connection.features.shard_id in self._connections:
605605
del self._connections[connection.features.shard_id]
606-
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
606+
if self.host.sharding_info and not self._session.is_shard_aware_disabled():
607607
self._connecting.add(connection.features.shard_id)
608608
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
609609
else:
@@ -678,7 +678,8 @@ def disable_advanced_shard_aware(self, secs):
678678

679679
def _get_shard_aware_endpoint(self):
680680
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \
681-
self._session.cluster.shard_aware_options.disable_shardaware_port:
681+
self._session.cluster.shard_aware_options.disable_shardaware_port or \
682+
self._session.is_shard_aware_disabled():
682683
return None
683684

684685
endpoint = None
@@ -920,5 +921,3 @@ def open_count(self):
920921
@property
921922
def _excess_connection_limit(self):
922923
return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier
923-
924-

docs/api/cassandra/policies.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ Load Balancing
2626
.. autoclass:: WhiteListRoundRobinPolicy
2727
:members:
2828

29+
.. autoclass:: DynamicWhiteListRoundRobinPolicy
30+
:members:
31+
2932
.. autoclass:: TokenAwarePolicy
3033
:members:
3134

tests/integration/standard/test_client_routes.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040

4141
from cassandra.cluster import Cluster
4242
from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy
43-
from cassandra.connection import ClientRoutesEndPoint
44-
from cassandra.policies import RoundRobinPolicy
43+
from cassandra.connection import ClientRoutesEndPoint, ConnectionException
44+
from cassandra.policies import DynamicWhiteListRoundRobinPolicy, RoundRobinPolicy
4545
from tests.integration import (
4646
TestCluster,
4747
get_cluster,
@@ -54,6 +54,28 @@
5454

5555
log = logging.getLogger(__name__)
5656

57+
58+
class ProxyOnlyReachableConnection(Cluster.connection_class):
59+
"""
60+
Simulates a private-link client that can reach only the proxy endpoint.
61+
62+
The CCM node addresses are reachable from the local test runner, which means
63+
the existing client-routes tests cannot reproduce bugs that only appear when
64+
direct node IPs are private. This connection class rejects those direct node
65+
addresses while still allowing the NLB address.
66+
"""
67+
68+
@classmethod
69+
def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs):
70+
address, _ = endpoint.resolve()
71+
if address.startswith("127.0.0."):
72+
raise ConnectionException(
73+
"Simulated private node address %s is unreachable from the client" % address,
74+
endpoint=endpoint,
75+
)
76+
return super().factory(endpoint, timeout, host_conn=host_conn, *args, **kwargs)
77+
78+
5779
class TcpProxy:
5880
"""
5981
A simple TCP proxy that forwards connections from a local listen port
@@ -535,6 +557,57 @@ def teardown_module():
535557
else:
536558
os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts
537559

560+
561+
class TestProxyConnectivityWithoutClientRoutes(unittest.TestCase):
562+
"""
563+
Reproducer for connecting through a generic proxy when node addresses are
564+
not reachable from the client.
565+
566+
The initial control connection can reach the cluster through the proxy, but
567+
the driver later tries to open pools to the discovered node addresses
568+
directly. In a proxy-only environment that makes connect/query fail.
569+
"""
570+
571+
@classmethod
572+
def setUpClass(cls):
573+
cls.node_addrs = {
574+
1: "127.0.0.1",
575+
2: "127.0.0.2",
576+
3: "127.0.0.3",
577+
}
578+
cls.proxy_node_id = 1
579+
cls.nlb = NLBEmulator()
580+
cls.nlb.start(cls.node_addrs)
581+
582+
@classmethod
583+
def tearDownClass(cls):
584+
cls.nlb.stop()
585+
586+
def _make_proxy_cluster(self):
587+
return Cluster(
588+
contact_points=[NLBEmulator.LISTEN_HOST],
589+
port=self.nlb.node_port(self.proxy_node_id),
590+
connection_class=ProxyOnlyReachableConnection,
591+
load_balancing_policy=DynamicWhiteListRoundRobinPolicy(),
592+
)
593+
594+
def test_dynamic_whitelist_session_succeeds_when_only_proxy_is_reachable(self):
595+
cluster = self._make_proxy_cluster()
596+
self.addCleanup(cluster.shutdown)
597+
598+
session = cluster.connect()
599+
row = session.execute(
600+
"SELECT release_version FROM system.local WHERE key='local'"
601+
).one()
602+
603+
self.assertIsNotNone(row)
604+
pool_state = session.get_pool_state()
605+
self.assertEqual(len(pool_state), 1)
606+
607+
session_host = next(iter(pool_state))
608+
self.assertEqual(session_host.endpoint.address, NLBEmulator.LISTEN_HOST)
609+
self.assertEqual(session_host.endpoint.port, self.nlb.node_port(self.proxy_node_id))
610+
538611
@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported',
539612
scylla_version="2026.1.0")
540613
class TestGetHostPortMapping(unittest.TestCase):

tests/unit/advanced/test_insights.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import logging
1919
import sys
20+
import uuid
2021
from unittest.mock import sentinel
2122

2223
from cassandra import ConsistencyLevel
@@ -37,6 +38,7 @@
3738
DCAwareRoundRobinPolicy,
3839
TokenAwarePolicy,
3940
WhiteListRoundRobinPolicy,
41+
DynamicWhiteListRoundRobinPolicy,
4042
HostFilterPolicy,
4143
ConstantReconnectionPolicy,
4244
ExponentialReconnectionPolicy,
@@ -203,6 +205,14 @@ def test_whitelist_round_robin_policy(self):
203205
'options': {'allowed_hosts': ('127.0.0.3',)},
204206
'type': 'WhiteListRoundRobinPolicy'}
205207

208+
def test_dynamic_whitelist_round_robin_policy(self):
209+
policy = DynamicWhiteListRoundRobinPolicy()
210+
host_id = uuid.uuid4()
211+
policy._allowed_host_ids = (host_id,)
212+
assert insights_registry.serialize(policy) == {'namespace': 'cassandra.policies',
213+
'options': {'allowed_host_ids': (str(host_id),)},
214+
'type': 'DynamicWhiteListRoundRobinPolicy'}
215+
206216
def test_host_filter_policy(self):
207217
def my_predicate(s):
208218
return False

0 commit comments

Comments
 (0)