Skip to content

Commit fbb89e5

Browse files
committed
client_routes: separate API methods for complete refresh vs change events
Split _query_routes into two specialized methods plus a shared helper:

1. _query_all_routes_for_connections: used during initialization and reconnection for a complete refresh of all routes for the configured connection IDs.
2. _query_routes_for_change_event: used for CLIENT_ROUTES_CHANGE event handling, with proper AND logic for related connection_ids/host_ids.
3. _execute_routes_query: common helper for query execution and result parsing, with proxy address override support.

This separation clarifies the different use cases and semantics:
- Complete refresh: loads all routes for our connection IDs.
- Change events: queries the specific routes affected by the event.

Updated unit tests to use the appropriate method mocks. All tests pass.
1 parent aee428d commit fbb89e5

2 files changed

Lines changed: 102 additions & 73 deletions

File tree

cassandra/client_routes.py

Lines changed: 94 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ def initialize(self, control_connection: 'ControlConnection') -> None:
242242
log.info("[client routes] Initializing with %d proxies", len(self.config.proxies))
243243

244244
try:
245-
connection_ids = self._connection_ids
246-
routes = self._query_routes(control_connection, connection_ids=connection_ids)
245+
connection_ids = list(self._connection_ids)
246+
routes = self._query_all_routes_for_connections(control_connection, connection_ids)
247247

248248
self._routes.update(routes)
249249
except Exception as e:
@@ -264,30 +264,36 @@ def handle_client_routes_change(self, control_connection: 'ControlConnection',
264264
:param connection_ids: Affected connection ID strings; empty means all.
265265
:param host_ids: Affected host ID strings; empty means all.
266266
"""
267-
filtered_conn_ids = None
268-
if connection_ids:
269-
configured_ids = self._connection_ids
270-
filtered = [cid for cid in connection_ids if cid in configured_ids]
271-
if not filtered:
267+
268+
try:
269+
# Both arrays must be present and same length, or both empty
270+
271+
full_refresh = False
272+
if not connection_ids and not host_ids:
273+
log.error(
274+
"[client routes] CLIENT_ROUTES_CHANGE has no connection_ids and host_ids, doing full refresh")
275+
full_refresh = True
276+
elif len(connection_ids) != len(host_ids):
277+
log.error("[client routes] CLIENT_ROUTES_CHANGE has mismatched lengths (conn: %d, host: %d), doing full refresh",
278+
len(connection_ids), len(host_ids))
279+
full_refresh = True
280+
281+
if full_refresh:
282+
routes = self._query_all_routes_for_connections(control_connection, list(self._connection_ids))
283+
self._routes.update(routes)
272284
return
273-
filtered_conn_ids = filtered
274285

275-
host_uuids = [uuid.UUID(hid) for hid in host_ids] if host_ids else None
286+
host_uuids = [uuid.UUID(hid) for hid in host_ids]
287+
pairs = [(cid, hid) for cid, hid in zip(connection_ids, host_uuids)
288+
if cid in self._connection_ids]
276289

277-
try:
278-
routes = self._query_routes(
279-
control_connection,
280-
connection_ids=filtered_conn_ids,
281-
host_ids=host_uuids
282-
)
283-
except Exception as e:
284-
log.warning("[client routes] Failed to query routes for CLIENT_ROUTES_CHANGE: %s", e, exc_info=True)
285-
return
290+
if not pairs:
291+
return # No relevant connection IDs
286292

287-
if host_uuids:
293+
routes = self._query_routes_for_change_event(control_connection, pairs)
288294
self._routes.merge(routes, affected_host_ids=set(host_uuids))
289-
else:
290-
self._routes.update(routes)
295+
except Exception as e:
296+
log.warning("[client routes] Failed to handle CLIENT_ROUTES_CHANGE: %s", e, exc_info=True)
291297

292298
def handle_control_connection_reconnect(self, control_connection: 'ControlConnection') -> None:
293299
"""
@@ -307,57 +313,80 @@ def handle_control_connection_reconnect(self, control_connection: 'ControlConnec
307313
"""
308314
log.info("[client routes] Control connection reconnected, reloading all routes")
309315

310-
connection_ids = self._connection_ids
311-
routes = self._query_routes(control_connection, connection_ids=connection_ids)
316+
connection_ids = list(self._connection_ids)
317+
routes = self._query_all_routes_for_connections(control_connection, connection_ids)
312318
self._routes.update(routes)
313319

314-
def _query_routes(self, control_connection: 'ControlConnection', connection_ids: Optional[List[str]] = None,
315-
host_ids: Optional[List[uuid.UUID]] = None) -> List[_Route]:
320+
def _query_all_routes_for_connections(self, control_connection: 'ControlConnection',
321+
connection_ids: List[str]) -> List[_Route]:
316322
"""
317-
Query system.client_routes table.
318-
319-
When both connection_ids and host_ids are provided, uses AND logic:
320-
returns routes where connection_id IN (...) AND host_id IN (...).
321-
This matches the semantics of CLIENT_ROUTES_CHANGE events which contain
322-
related connection_ids and host_ids from the same operation.
323-
323+
Query all routes for the given connection IDs (complete refresh).
324+
325+
Used when control connection reconnects or as a fallback when
326+
CLIENT_ROUTES_CHANGE event has malformed data.
327+
324328
:param control_connection: ControlConnection to execute query
325-
:param connection_ids: Optional list of connection ID strings to filter by
326-
:param host_ids: Optional list of host UUIDs to filter by
329+
:param connection_ids: List of connection ID strings
327330
:return: List of _Route
328331
"""
329-
query = "SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes"
330-
where_clauses = []
332+
if not connection_ids:
333+
return []
334+
335+
# Deduplicate connection_ids to avoid redundant parameters
336+
unique_connection_ids = list(dict.fromkeys(connection_ids))
337+
placeholders = ', '.join('?' for _ in unique_connection_ids)
338+
query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE connection_id IN ({placeholders})"
339+
params = [cid.encode('utf-8') for cid in unique_connection_ids]
340+
341+
log.debug("[client routes] Querying all routes for connection_ids=%s", unique_connection_ids)
342+
return self._execute_routes_query(control_connection, query, params)
343+
344+
def _query_routes_for_change_event(self, control_connection: 'ControlConnection',
345+
route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]:
346+
"""
347+
Query specific routes affected by a CLIENT_ROUTES_CHANGE event.
348+
349+
Takes a list of (connection_id, host_id) pairs that represent the exact
350+
routes affected by an operation. This provides precise updates without
351+
fetching unrelated routes.
352+
353+
If the pairs list is empty or None, falls back to a complete refresh
354+
of all routes for safety.
355+
356+
:param control_connection: ControlConnection to execute query
357+
:param route_pairs: List of (connection_id, host_id) tuples
358+
:return: List of _Route
359+
"""
360+
# Deduplicate pairs while preserving order
361+
unique_pairs = list(dict.fromkeys(route_pairs))
362+
363+
# Build query with compound WHERE clause: WHERE (connection_id, host_id) IN ((?, ?), (?, ?), ...)
364+
placeholders = ', '.join('(?, ?)' for _ in unique_pairs)
365+
query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE (connection_id, host_id) IN ({placeholders})"
366+
367+
# Flatten pairs into params list: [conn_id1, host_id1, conn_id2, host_id2, ...]
331368
params = []
332-
333-
if connection_ids:
334-
# Deduplicate connection_ids to avoid redundant parameters
335-
unique_connection_ids = list(dict.fromkeys(connection_ids))
336-
placeholders = ', '.join('?' for _ in unique_connection_ids)
337-
where_clauses.append("connection_id IN (%s)" % placeholders)
338-
params.extend(cid.encode('utf-8') for cid in unique_connection_ids)
339-
340-
if host_ids:
341-
# Deduplicate host_ids to avoid redundant parameters
342-
unique_host_ids = list(dict.fromkeys(host_ids))
343-
placeholders = ', '.join('?' for _ in unique_host_ids)
344-
where_clauses.append("host_id IN (%s)" % placeholders)
345-
params.extend(hid.bytes for hid in unique_host_ids)
346-
347-
if where_clauses:
348-
query += " WHERE " + " AND ".join(where_clauses)
349-
else:
350-
query += " ALLOW FILTERING"
351-
352-
# Log the query for debugging CLIENT_ROUTES_CHANGE event handling
353-
if connection_ids and host_ids:
354-
log.debug("[client routes] Querying with AND logic: connection_ids=%s, host_ids=%s",
355-
connection_ids, host_ids)
356-
elif connection_ids:
357-
log.debug("[client routes] Querying by connection_ids=%s", connection_ids)
358-
elif host_ids:
359-
log.debug("[client routes] Querying by host_ids=%s", host_ids)
360-
369+
for conn_id, host_id in unique_pairs:
370+
params.append(conn_id.encode('utf-8'))
371+
params.append(host_id.bytes)
372+
373+
log.debug("[client routes] Querying specific route pairs from CLIENT_ROUTES_CHANGE: %s", unique_pairs[:5]) # Log first 5 pairs for debugging
374+
return self._execute_routes_query(control_connection, query, params)
375+
376+
def _execute_routes_query(self, control_connection: 'ControlConnection',
377+
query: str, params: List) -> List[_Route]:
378+
"""
379+
Execute a routes query and parse results.
380+
381+
Common helper for both complete refresh and change event queries.
382+
383+
:param control_connection: ControlConnection to execute query
384+
:param query: CQL query string
385+
:param params: Query parameters
386+
:return: List of _Route
387+
"""
388+
log.debug("[client routes] Executing query: %s with %d parameters", query, len(params))
389+
361390
query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE,
362391
query_params=params if params else None)
363392
result = control_connection._connection.wait_for_response(

tests/unit/test_client_routes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_handler_initialization(self):
165165
self.assertIsNotNone(handler)
166166
self.assertEqual(handler.ssl_enabled, False)
167167

168-
@patch.object(_ClientRoutesHandler, '_query_routes')
168+
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
169169
def test_initialize(self, mock_query):
170170
host_id = uuid.uuid4()
171171
mock_query.return_value = [
@@ -187,7 +187,7 @@ def test_initialize(self, mock_query):
187187
self.assertIsNotNone(route)
188188
self.assertEqual(route.address, "node1.example.com")
189189

190-
@patch.object(_ClientRoutesHandler, '_query_routes')
190+
@patch.object(_ClientRoutesHandler, '_query_routes_for_change_event')
191191
def test_handle_change_filters_by_configured_connection_ids(self, mock_query):
192192
"""Events with unrelated connection_ids should be ignored."""
193193
handler = _ClientRoutesHandler(self.config)
@@ -203,7 +203,7 @@ def test_handle_change_filters_by_configured_connection_ids(self, mock_query):
203203
)
204204
mock_query.assert_not_called()
205205

206-
@patch.object(_ClientRoutesHandler, '_query_routes')
206+
@patch.object(_ClientRoutesHandler, '_query_routes_for_change_event')
207207
def test_handle_change_merges_when_host_ids_present(self, mock_query):
208208
"""When host_ids are provided, routes should be merged (not full replace)."""
209209
handler = _ClientRoutesHandler(self.config)
@@ -233,7 +233,7 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query):
233233
self.assertIsNotNone(handler._routes.get_by_host_id(existing_host))
234234
self.assertIsNotNone(handler._routes.get_by_host_id(new_host))
235235

236-
@patch.object(_ClientRoutesHandler, '_query_routes')
236+
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
237237
def test_handle_change_updates_when_no_host_ids(self, mock_query):
238238
"""When no host_ids are provided, routes should be fully replaced."""
239239
handler = _ClientRoutesHandler(self.config)
@@ -261,7 +261,7 @@ def test_handle_change_updates_when_no_host_ids(self, mock_query):
261261
self.assertIsNone(handler._routes.get_by_host_id(old_host))
262262
self.assertIsNotNone(handler._routes.get_by_host_id(new_host))
263263

264-
@patch.object(_ClientRoutesHandler, '_query_routes')
264+
@patch.object(_ClientRoutesHandler, '_query_routes_for_change_event')
265265
def test_handle_change_swallows_query_failure(self, mock_query):
266266
"""If _query_routes raises, handle_client_routes_change should not propagate."""
267267
handler = _ClientRoutesHandler(self.config)
@@ -276,7 +276,7 @@ def test_handle_change_swallows_query_failure(self, mock_query):
276276
host_ids=[str(uuid.uuid4())],
277277
)
278278

279-
@patch.object(_ClientRoutesHandler, '_query_routes')
279+
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
280280
def test_reconnect_propagates_exception_on_failure(self, mock_query):
281281
"""handle_control_connection_reconnect should propagate exceptions to caller."""
282282
handler = _ClientRoutesHandler(self.config)
@@ -288,7 +288,7 @@ def test_reconnect_propagates_exception_on_failure(self, mock_query):
288288
self.assertIn("query failed", str(ctx.exception))
289289
self.assertEqual(mock_query.call_count, 1)
290290

291-
@patch.object(_ClientRoutesHandler, '_query_routes')
291+
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
292292
def test_reconnect_keeps_old_routes_on_failure(self, mock_query):
293293
"""On failure, existing routes must be preserved (critical for PL clusters)."""
294294
handler = _ClientRoutesHandler(self.config)
@@ -307,7 +307,7 @@ def test_reconnect_keeps_old_routes_on_failure(self, mock_query):
307307
# Old route must still be there
308308
self.assertIsNotNone(handler._routes.get_by_host_id(host_id))
309309

310-
@patch.object(_ClientRoutesHandler, '_query_routes')
310+
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
311311
def test_reconnect_updates_routes_on_success(self, mock_query):
312312
"""handle_control_connection_reconnect should update routes on success."""
313313
handler = _ClientRoutesHandler(self.config)

0 commit comments

Comments
 (0)