diff --git a/changelog.d/19005.feature b/changelog.d/19005.feature new file mode 100644 index 00000000000..811d2e31af8 --- /dev/null +++ b/changelog.d/19005.feature @@ -0,0 +1 @@ +Add experimental support for MSC4360: Sliding Sync Threads Extension. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 7a8f546d6bf..86620bda33e 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -272,6 +272,9 @@ class EventContentFields: M_TOPIC: Final = "m.topic" M_TEXT: Final = "m.text" + # Event relations + RELATIONS: Final = "m.relates_to" + class EventUnsignedContentFields: """Fields found inside the 'unsigned' data on events""" @@ -360,3 +363,10 @@ class Direction(enum.Enum): class ProfileFields: DISPLAYNAME: Final = "displayname" AVATAR_URL: Final = "avatar_url" + + +class MRelatesToFields: + """Fields found inside m.relates_to content blocks.""" + + EVENT_ID: Final = "event_id" + REL_TYPE: Final = "rel_type" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 52c3ec0da29..13f5e1fcd88 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -593,3 +593,6 @@ def read_config( # MSC4306: Thread Subscriptions # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + + # MSC4360: Threads Extension to Sliding Sync + self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index fd38ffa920f..8513e897115 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -105,8 +105,6 @@ async def get_relations( ) -> JsonDict: """Get related events of a event, ordered by topological ordering. - TODO Accept a PaginationConfig instead of individual pagination parameters. - Args: requester: The user requesting the relations. event_id: Fetch events that relate to this event ID. diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 6a5d5c7b3cc..748f4854892 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -305,6 +305,7 @@ async def handle_room(room_id: str) -> None: # account data, read receipts, typing indicators, to-device messages, etc). 
actual_room_ids=set(relevant_room_map.keys()), actual_room_response_map=rooms, + room_membership_for_user_at_to_token_map=room_membership_for_user_map, from_token=from_token, to_token=to_token, ) diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index d076bec51a9..d62f2d675f6 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -14,6 +14,7 @@ import itertools import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, AbstractSet, @@ -26,16 +27,28 @@ from typing_extensions import TypeAlias, assert_never -from synapse.api.constants import AccountDataTypes, EduTypes +from synapse.api.constants import ( + AccountDataTypes, + EduTypes, + EventContentFields, + Membership, + MRelatesToFields, + RelationTypes, +) +from synapse.events import EventBase from synapse.handlers.receipts import ReceiptEventSource +from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.logging.opentracing import trace from synapse.storage.databases.main.receipts import ReceiptInRoom +from synapse.storage.databases.main.relations import ThreadUpdateInfo from synapse.types import ( DeviceListUpdates, JsonMapping, MultiWriterStreamToken, + RoomStreamToken, SlidingSyncStreamToken, StrCollection, + StreamKeyType, StreamToken, ThreadSubscriptionsToken, ) @@ -51,6 +64,7 @@ concurrently_execute, gather_optional_coroutines, ) +from synapse.visibility import filter_events_for_client _ThreadSubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription @@ -58,6 +72,7 @@ _ThreadUnsubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription ) +_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate if TYPE_CHECKING: from synapse.server import HomeServer @@ -73,7 +88,10 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self.relations_handler = hs.get_relations_handler() + self._storage_controllers = hs.get_storage_controllers() self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled + self._enable_threads_ext = hs.config.experimental.msc4360_enabled @trace async def get_extensions_response( @@ -84,6 +102,7 @@ async def get_extensions_response( actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList], actual_room_ids: set[str], actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], + room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType], to_token: StreamToken, from_token: SlidingSyncStreamToken | None, ) -> SlidingSyncResult.Extensions: @@ -99,6 +118,8 @@ async def get_extensions_response( actual_room_ids: The actual room IDs in the the Sliding Sync response. actual_room_response_map: A map of room ID to room results in the the Sliding Sync response. + room_membership_for_user_at_to_token_map: A map of room ID to the membership + information for the user in the room at the time of `to_token`. to_token: The latest point in the stream to sync up to. from_token: The point in the stream to sync from. 
""" @@ -174,6 +195,18 @@ async def get_extensions_response( from_token=from_token, ) + threads_coro = None + if sync_config.extensions.threads is not None and self._enable_threads_ext: + threads_coro = self.get_threads_extension_response( + sync_config=sync_config, + threads_request=sync_config.extensions.threads, + actual_room_ids=actual_room_ids, + actual_room_response_map=actual_room_response_map, + room_membership_for_user_at_to_token_map=room_membership_for_user_at_to_token_map, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, @@ -181,6 +214,7 @@ async def get_extensions_response( receipts_response, typing_response, thread_subs_response, + threads_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, @@ -188,6 +222,7 @@ async def get_extensions_response( receipts_coro, typing_coro, thread_subs_coro, + threads_coro, ) return SlidingSyncResult.Extensions( @@ -197,6 +232,7 @@ async def get_extensions_response( receipts=receipts_response, typing=typing_response, thread_subscriptions=thread_subs_response, + threads=threads_response, ) def find_relevant_room_ids_for_extension( @@ -967,3 +1003,273 @@ async def get_thread_subscriptions_extension_response( unsubscribed=unsubscribed_threads, prev_batch=prev_batch, ) + + def _extract_thread_id_from_event(self, event: EventBase) -> str | None: + """Extract thread ID from event if it's a thread reply. + + Args: + event: The event to check. + + Returns: + The thread ID if the event is a thread reply, None otherwise. + """ + relates_to = event.content.get(EventContentFields.RELATIONS) + if isinstance(relates_to, dict): + if relates_to.get(MRelatesToFields.REL_TYPE) == RelationTypes.THREAD: + return relates_to.get(MRelatesToFields.EVENT_ID) + return None + + def _find_threads_in_timeline( + self, + actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], + ) -> set[str]: + """Find all thread IDs that have events in room timelines. + + Args: + actual_room_response_map: A map of room ID to room results. + + Returns: + A set of thread IDs (thread root event IDs) that appear in the timeline. + """ + threads_in_timeline: set[str] = set() + for room_result in actual_room_response_map.values(): + if room_result.timeline_events: + for event in room_result.timeline_events: + thread_id = self._extract_thread_id_from_event(event) + if thread_id: + threads_in_timeline.add(thread_id) + return threads_in_timeline + + def _merge_prev_batch_token( + self, + current_token: StreamToken | None, + new_token: StreamToken | None, + ) -> StreamToken | None: + """Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination. + + Args: + current_token: The current prev_batch token (may be None) + new_token: The new prev_batch token to merge (may be None) + + Returns: + The merged token (maximum of the two, or None if both are None) + """ + if new_token is None: + return current_token + if current_token is None: + return new_token + if new_token.room_key.stream > current_token.room_key.stream: + return new_token + return current_token + + def _merge_thread_updates( + self, + target: dict[str, list[ThreadUpdateInfo]], + source: dict[str, list[ThreadUpdateInfo]], + ) -> None: + """Merge thread updates from source into target. 
+
+        Args:
+            target: The target dict to merge into (modified in place)
+            source: The source dict to merge from
+        """
+        for thread_id, updates in source.items():
+            target.setdefault(thread_id, []).extend(updates)
+
+    async def get_threads_extension_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
+        actual_room_ids: set[str],
+        actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
+        room_membership_for_user_at_to_token_map: Mapping[str, RoomsForUserType],
+        to_token: StreamToken,
+        from_token: SlidingSyncStreamToken | None,
+    ) -> SlidingSyncResult.Extensions.ThreadsExtension | None:
+        """Handle the Threads extension (MSC4360).
+
+        Args:
+            sync_config: Sync configuration.
+            threads_request: The threads extension from the request.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+            actual_room_response_map: A map of room ID to room results in the
+                sliding sync response. Used to determine which threads already have
+                events in the room timeline.
+            room_membership_for_user_at_to_token_map: A map of room ID to the membership
+                information for the user in the room at the time of `to_token`.
+            to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
+
+        Returns:
+            The response, or None if it is empty or the threads extension is disabled.
+        """
+        if not threads_request.enabled:
+            return None
+
+        # Identify which threads already have events in the room timelines.
+        # If include_roots=False, we'll exclude these threads from the DB query
+        # since the client already sees the thread activity in the timeline.
+        # If include_roots=True, we fetch all threads regardless, because the client
+        # wants the thread root events.
+        threads_to_exclude: set[str] | None = None
+        if not threads_request.include_roots:
+            threads_to_exclude = self._find_threads_in_timeline(
+                actual_room_response_map
+            )
+
+        # Separate rooms into groups based on membership status.
+        # For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
+        # that occurred after the user left/was banned.
+        leave_ban_rooms: set[str] = set()
+        other_rooms: set[str] = set()
+
+        for room_id in actual_room_ids:
+            membership_info = room_membership_for_user_at_to_token_map.get(room_id)
+            if membership_info and membership_info.membership in (
+                Membership.LEAVE,
+                Membership.BAN,
+            ):
+                leave_ban_rooms.add(room_id)
+            else:
+                other_rooms.add(room_id)
+
+        # Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
+        all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
+        prev_batch_token: StreamToken | None = None
+        remaining_limit = threads_request.limit
+
+        # Query for rooms where the user has left or been banned, using their leave/ban
+        # event position as the upper bound to prevent seeing events after they left.
+        if leave_ban_rooms:
+            for room_id in leave_ban_rooms:
+                if remaining_limit <= 0:
+                    # We've already fetched enough updates, but we still need to set
+                    # prev_batch to indicate there are more results.
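+                    # (Note: `to_token` is a coarse prev_batch bound here: a
+                    # back-pagination request from it may overlap with updates
+                    # already returned in this response, but it cannot skip any
+                    # that fall within the requested range.)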
+ prev_batch_token = to_token + break + + membership_info = room_membership_for_user_at_to_token_map[room_id] + bounded_to_token = membership_info.event_pos.to_room_stream_token() + + ( + room_thread_updates, + room_prev_batch, + ) = await self.store.get_thread_updates_for_rooms( + room_ids={room_id}, + from_token=from_token.stream_token.room_key if from_token else None, + to_token=bounded_to_token, + limit=remaining_limit, + exclude_thread_ids=threads_to_exclude, + ) + + # Count how many updates we fetched and reduce the remaining limit + num_updates = sum( + len(updates) for updates in room_thread_updates.values() + ) + remaining_limit -= num_updates + + self._merge_thread_updates(all_thread_updates, room_thread_updates) + prev_batch_token = self._merge_prev_batch_token( + prev_batch_token, room_prev_batch + ) + + # Query for rooms where the user is joined, invited, or knocking, using the + # normal to_token as the upper bound. + if other_rooms and remaining_limit > 0: + ( + other_thread_updates, + other_prev_batch, + ) = await self.store.get_thread_updates_for_rooms( + room_ids=other_rooms, + from_token=from_token.stream_token.room_key if from_token else None, + to_token=to_token.room_key, + limit=remaining_limit, + exclude_thread_ids=threads_to_exclude, + ) + + self._merge_thread_updates(all_thread_updates, other_thread_updates) + prev_batch_token = self._merge_prev_batch_token( + prev_batch_token, other_prev_batch + ) + + if len(all_thread_updates) == 0: + return None + + # Build a mapping of event_id -> (thread_id, update) for efficient lookup + # during visibility filtering. + event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {} + for thread_id, updates in all_thread_updates.items(): + for update in updates: + event_to_thread_map[update.event_id] = (thread_id, update) + + # Fetch and filter events for visibility + all_events = await self.store.get_events_as_list(event_to_thread_map.keys()) + filtered_events = await filter_events_for_client( + self._storage_controllers, sync_config.user.to_string(), all_events + ) + + # Rebuild thread updates from filtered events + filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list) + for event in filtered_events: + if event.event_id in event_to_thread_map: + thread_id, update = event_to_thread_map[event.event_id] + filtered_updates[thread_id].append(update) + + if not filtered_updates: + return None + + # Note: Updates are already sorted by stream_ordering DESC from the database query, + # and filter_events_for_client preserves order, so updates[0] is guaranteed to be + # the latest event for each thread. + + # Optionally fetch thread root events and their bundled aggregations + thread_root_event_map = {} + aggregations_map = {} + if threads_request.include_roots: + thread_root_events = await self.store.get_events_as_list( + filtered_updates.keys() + ) + thread_root_event_map = {e.event_id: e for e in thread_root_events} + + if thread_root_event_map: + aggregations_map = ( + await self.relations_handler.get_bundled_aggregations( + thread_root_event_map.values(), + sync_config.user.to_string(), + ) + ) + + thread_updates: dict[str, dict[str, _ThreadUpdate]] = {} + for thread_root, updates in filtered_updates.items(): + # We only care about the latest update for the thread. + # After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering). + latest_update = updates[0] + + # Generate per-thread prev_batch token if this thread has multiple visible updates. 
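+            # For example, if a thread has visible updates at stream orderings
+            # 3, 5 and 7, the update at 7 is reported inline and the per-thread
+            # prev_batch points at position 6, so a `/relations?dir=b` request
+            # resumes from the update at 5.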
+ # When we hit the global limit, we generate prev_batch tokens for all threads, even if + # we only saw 1 update for them. This is to cover the case where we only saw + # a single update for a given thread, but the global limit prevents us from + # obtaining other updates which would have otherwise been included in the + # range. + per_thread_prev_batch = None + if len(updates) > 1 or prev_batch_token is not None: + # Create a token pointing to one position before the latest event's stream position. + # This makes it exclusive - /relations with dir=b won't return the latest event again. + # Use StreamToken.START as base (all other streams at 0) since only room position matters. + per_thread_prev_batch = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, + RoomStreamToken(stream=latest_update.stream_ordering - 1), + ) + + thread_updates.setdefault(latest_update.room_id, {})[thread_root] = ( + _ThreadUpdate( + thread_root=thread_root_event_map.get(thread_root), + prev_batch=per_thread_prev_batch, + bundled_aggregations=aggregations_map.get(thread_root), + ) + ) + + return SlidingSyncResult.Extensions.ThreadsExtension( + updates=thread_updates, + prev_batch=prev_batch_token, + ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 458bf08a19f..b02ac8e4e18 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -31,6 +31,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.ratelimiting import Ratelimiter from synapse.events.utils import ( + EventClientSerializer, SerializeEventConfig, format_event_for_client_v2_without_room_id, format_event_raw, @@ -56,6 +57,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname from synapse.rest.admin.experimental_features import ExperimentalFeature +from synapse.storage.databases.main import DataStore from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken from synapse.types.rest.client import SlidingSyncBody from synapse.util.caches.lrucache import LruCache @@ -646,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet): - receipts (MSC3960) - account data (MSC3959) - thread subscriptions (MSC4308) + - threads (MSC4360) Request query parameters: timeout: How long to wait for new events in milliseconds. 
@@ -849,7 +852,10 @@ async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: logger.info("Client has disconnected; not serializing response.") return 200, {} - response_content = await self.encode_response(requester, sliding_sync_results) + time_now = self.clock.time_msec() + response_content = await self.encode_response( + requester, sliding_sync_results, time_now + ) return 200, response_content @@ -858,6 +864,7 @@ async def encode_response( self, requester: Requester, sliding_sync_result: SlidingSyncResult, + time_now: int, ) -> JsonDict: response: JsonDict = defaultdict(dict) @@ -866,10 +873,10 @@ async def encode_response( if serialized_lists: response["lists"] = serialized_lists response["rooms"] = await self.encode_rooms( - requester, sliding_sync_result.rooms + requester, sliding_sync_result.rooms, time_now ) response["extensions"] = await self.encode_extensions( - requester, sliding_sync_result.extensions + requester, sliding_sync_result.extensions, time_now ) return response @@ -901,9 +908,8 @@ async def encode_rooms( self, requester: Requester, rooms: dict[str, SlidingSyncResult.RoomResult], + time_now: int, ) -> JsonDict: - time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig( event_format=format_event_for_client_v2_without_room_id, requester=requester, @@ -1019,7 +1025,10 @@ async def encode_rooms( @trace_with_opname("sliding_sync.encode_extensions") async def encode_extensions( - self, requester: Requester, extensions: SlidingSyncResult.Extensions + self, + requester: Requester, + extensions: SlidingSyncResult.Extensions, + time_now: int, ) -> JsonDict: serialized_extensions: JsonDict = {} @@ -1089,6 +1098,17 @@ async def encode_extensions( _serialise_thread_subscriptions(extensions.thread_subscriptions) ) + # excludes both None and falsy `threads` + if extensions.threads: + serialized_extensions[ + "io.element.msc4360.threads" + ] = await _serialise_threads( + self.event_serializer, + time_now, + extensions.threads, + self.store, + ) + return serialized_extensions @@ -1125,6 +1145,72 @@ def _serialise_thread_subscriptions( return out +async def _serialise_threads( + event_serializer: EventClientSerializer, + time_now: int, + threads: SlidingSyncResult.Extensions.ThreadsExtension, + store: "DataStore", +) -> JsonDict: + """ + Serialize the threads extension response for sliding sync. + + Args: + event_serializer: The event serializer to use for serializing thread root events. + time_now: The current time in milliseconds, used for event serialization. + threads: The threads extension data containing thread updates and pagination tokens. + store: The datastore, needed for serializing stream tokens. + + Returns: + A JSON-serializable dict containing: + - "updates": A nested dict mapping room_id -> thread_root_id -> thread update. + Each thread update may contain: + - "thread_root": The serialized thread root event (if include_roots was True), + with bundled aggregations including the latest_event in unsigned.m.relations.m.thread. + - "prev_batch": A pagination token for fetching older events in the thread. + - "prev_batch": A pagination token for fetching older thread updates (if available). 
+ """ + out: JsonDict = {} + + if threads.updates: + updates_dict: JsonDict = {} + for room_id, thread_updates in threads.updates.items(): + room_updates: JsonDict = {} + for thread_root_id, update in thread_updates.items(): + # Serialize the update + update_dict: JsonDict = {} + + # Serialize the thread_root event if present + if update.thread_root is not None: + # Create a mapping of event_id to bundled_aggregations + bundle_aggs_map = ( + {thread_root_id: update.bundled_aggregations} + if update.bundled_aggregations + else None + ) + serialized_events = await event_serializer.serialize_events( + [update.thread_root], + time_now, + bundle_aggregations=bundle_aggs_map, + ) + if serialized_events: + update_dict["thread_root"] = serialized_events[0] + + # Add prev_batch if present + if update.prev_batch is not None: + update_dict["prev_batch"] = await update.prev_batch.to_string(store) + + room_updates[thread_root_id] = update_dict + + updates_dict[room_id] = room_updates + + out["updates"] = updates_dict + + if threads.prev_batch: + out["prev_batch"] = await threads.prev_batch.to_string(store) + + return out + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 9d9c37e2a41..c367c8a0718 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -19,6 +19,7 @@ # import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -40,13 +41,19 @@ LoggingTransaction, make_in_list_sql_clause, ) +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.stream import ( generate_next_token, generate_pagination_bounds, generate_pagination_where_clause, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, StreamKeyType, StreamToken +from synapse.types import ( + JsonDict, + RoomStreamToken, + StreamKeyType, + StreamToken, +) from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -88,7 +95,23 @@ class _RelatedEvent: sender: str -class RelationsWorkerStore(SQLBaseStore): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadUpdateInfo: + """ + Information about a thread update for the sliding sync threads extension. + + Attributes: + event_id: The event ID of the event in the thread. + room_id: The room ID where this thread exists. + stream_ordering: The stream ordering of this event. + """ + + event_id: str + room_id: str + stream_ordering: int + + +class RelationsWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -584,14 +607,18 @@ def _get_applicable_edits_txn(txn: LoggingTransaction) -> dict[str, str]: "get_applicable_edits", _get_applicable_edits_txn ) - edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined] + edits = await self.get_events(edit_ids.values()) # Map to the original event IDs to the edit events. # # There might not be an edit event due to there being no edits or # due to the event not being known, either case is treated the same. 
         return {
-            original_event_id: edits.get(edit_ids.get(original_event_id))
+            original_event_id: (
+                edits.get(edit_id)
+                if (edit_id := edit_ids.get(original_event_id))
+                else None
+            )
             for original_event_id in event_ids
         }
@@ -699,7 +726,7 @@ def _get_thread_summaries_txn(
             "get_thread_summaries", _get_thread_summaries_txn
         )
 
-        latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
+        latest_events = await self.get_events(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -1111,6 +1138,148 @@ def _get_related_thread_id(txn: LoggingTransaction) -> str:
             "get_related_thread_id", _get_related_thread_id
         )
 
+    async def get_thread_updates_for_rooms(
+        self,
+        *,
+        room_ids: Collection[str],
+        from_token: RoomStreamToken | None = None,
+        to_token: RoomStreamToken | None = None,
+        limit: int = 5,
+        exclude_thread_ids: Collection[str] | None = None,
+    ) -> tuple[dict[str, list[ThreadUpdateInfo]], StreamToken | None]:
+        """Get a list of updated threads in the given rooms, ordered by the
+        stream ordering of their latest reply. (Any membership-based bounding
+        of the token range is the caller's responsibility.)
+
+        Args:
+            room_ids: The room IDs to fetch thread updates for.
+            from_token: The lower bound (exclusive) for thread updates. If None,
+                fetch from the start of the room timeline.
+            to_token: The upper bound (inclusive) for thread updates. If None,
+                fetch up to the current position in the room timeline.
+            limit: Maximum number of thread updates to return.
+            exclude_thread_ids: Optional collection of thread root event IDs to exclude
+                from the results. Useful for filtering out threads already visible
+                in the room timeline.
+
+        Returns:
+            A tuple of:
+                A dict mapping thread_id to list of ThreadUpdateInfo objects,
+                ordered by stream_ordering descending (most recent first).
+                A prev_batch StreamToken (exclusive) if there are more results available,
+                None otherwise.
+        """
+        # Ensure bad limits aren't being passed in.
+        assert limit > 0
+
+        if len(room_ids) == 0:
+            return {}, None
+
+        def _get_thread_updates_for_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> tuple[list[tuple[str, str, str, int]], int | None]:
+            room_clause, room_id_values = make_in_list_sql_clause(
+                txn.database_engine, "e.room_id", room_ids
+            )
+
+            # Generate the pagination clause, if necessary.
+            pagination_clause = ""
+            pagination_args: list[str] = []
+            if from_token:
+                from_bound = from_token.stream
+                pagination_clause += " AND stream_ordering > ?"
+                pagination_args.append(str(from_bound))
+
+            if to_token:
+                to_bound = to_token.stream
+                pagination_clause += " AND stream_ordering <= ?"
+                pagination_args.append(str(to_bound))
+
+            # Generate the exclusion clause for thread IDs, if necessary.
+            exclusion_clause = ""
+            exclusion_args: list[str] = []
+            if exclude_thread_ids:
+                exclusion_clause, exclusion_args = make_in_list_sql_clause(
+                    txn.database_engine,
+                    "er.relates_to_id",
+                    exclude_thread_ids,
+                    negative=True,
+                )
+                exclusion_clause = f" AND {exclusion_clause}"
+
+            # TODO: multiple rows for the same thread all count towards `limit`,
+            # which reduces the number of distinct threads we can return in the
+            # sliding sync response.
+
+            # Find any thread events between the stream ordering bounds.
+            sql = f"""
+                SELECT e.event_id, er.relates_to_id, e.room_id, e.stream_ordering
+                FROM event_relations AS er
+                INNER JOIN events AS e ON er.event_id = e.event_id
+                WHERE er.relation_type = '{RelationTypes.THREAD}'
+                AND {room_clause}
+                {exclusion_clause}
+                {pagination_clause}
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+ """ + + # Fetch `limit + 1` rows as a way to detect if there are more results beyond + # what we're returning. If we get exactly `limit + 1` rows back, we know there + # are more results available and we can set `next_token`. We only return the + # first `limit` rows to the caller. This avoids needing a separate COUNT query. + txn.execute( + sql, + ( + *room_id_values, + *exclusion_args, + *pagination_args, + limit + 1, + ), + ) + + # SQL returns: event_id, thread_id, room_id, stream_ordering + rows = cast(list[tuple[str, str, str, int]], txn.fetchall()) + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(rows) > limit: + # Set the next_token to be the second last row in the result set since + # that will be the last row we return from this function. + # This works as an exclusive bound that can be backpaginated from. + # Use the stream_ordering field (index 2 in original rows) + next_token = rows[-2][3] + + return rows[:limit], next_token + + thread_infos, next_token_int = await self.db_pool.runInteraction( + "get_thread_updates_for_user", _get_thread_updates_for_user_txn + ) + + # Convert the next_token int (stream ordering) to a StreamToken. + # Use StreamToken.START as base (all other streams at 0) since only room + # position matters. + # Subtract 1 to make it exclusive - the client can paginate from this point without + # receiving the last thread update that was already returned. + next_token = None + if next_token_int is not None: + next_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1) + ) + + # Build ThreadUpdateInfo objects. + thread_update_infos: dict[str, list[ThreadUpdateInfo]] = defaultdict(list) + for event_id, thread_id, room_id, stream_ordering in thread_infos: + thread_update_infos[thread_id].append( + ThreadUpdateInfo( + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + ) + ) + + return (thread_update_infos, next_token) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql b/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql new file mode 100644 index 00000000000..b566c6e5461 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql @@ -0,0 +1,33 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Add indexes to improve performance of the thread_updates endpoint and +-- sliding sync threads extension (MSC4360). + +-- Index for efficiently finding all events that relate to a specific event +-- (e.g., all replies to a thread root). This is used by the correlated subquery +-- in get_thread_updates_for_user that counts thread updates. +-- Also useful for other relation queries (edits, reactions, etc.). +CREATE INDEX IF NOT EXISTS event_relations_relates_to_id_type + ON event_relations(relates_to_id, relation_type); + +-- Index for the /thread_updates endpoint's cross-room query. 
+-- Allows efficient descending ordering and range filtering of threads +-- by stream_ordering across all rooms. +CREATE INDEX IF NOT EXISTS threads_stream_ordering_desc + ON threads(stream_ordering DESC); + +-- Index for the EXISTS clause that filters threads to only joined rooms. +-- Allows efficient lookup of a user's current room memberships. +CREATE INDEX IF NOT EXISTS local_current_membership_user_room + ON local_current_membership(user_id, membership, room_id); diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 52688a8b6bc..5b59dd6145d 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -57,19 +57,41 @@ async def from_request( from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") + # Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format + def extract_stream_token(token_str: str) -> str: + """ + Extract the StreamToken portion from a token string. + + Handles both: + - StreamToken format: "s123_456_..." + - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /) + + This allows clients using sliding sync to use their pos tokens + with endpoints like /relations and /messages. + """ + if "/" in token_str: + # SlidingSyncStreamToken format: "connection_position/stream_token" + # Split and return just the stream_token part + parts = token_str.split("/", 1) + if len(parts) == 2: + return parts[1] + return token_str + try: from_tok = None if from_tok_str == "END": from_tok = None # For backwards compat. elif from_tok_str: - from_tok = await StreamToken.from_string(store, from_tok_str) + stream_token_str = extract_stream_token(from_tok_str) + from_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: to_tok = None if to_tok_str: - to_tok = await StreamToken.from_string(store, to_tok_str) + stream_token_str = extract_stream_token(to_tok_str) + to_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 494e3570d05..a5d90252b76 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -35,6 +35,9 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase + +if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.types import ( DeviceListUpdates, JsonDict, @@ -388,12 +391,60 @@ def __bool__(self) -> bool: or bool(self.prev_batch) ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadsExtension: + """The Threads extension (MSC4360) + + Provides thread updates for threads that have new activity across all of the + user's joined rooms within the sync window. + + Attributes: + updates: A nested mapping of room_id -> thread_root_id -> ThreadUpdate. + Each ThreadUpdate contains information about a thread that has new activity, + including the thread root event (if requested) and a pagination token + for fetching older events in that specific thread. + prev_batch: A pagination token for fetching more thread updates across all rooms. + If present, indicates there are more thread updates available beyond what + was returned in this response. This token can be used with a future request + to paginate through older thread updates. 
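+
+        For illustration only (token values abbreviated, event contents
+        elided), the serialized form produced by `_serialise_threads` in
+        `synapse/rest/client/sync.py` looks roughly like:
+
+            "io.element.msc4360.threads": {
+                "updates": {
+                    "!room:example.org": {
+                        "$thread_root_event_id": {
+                            "thread_root": { ... },
+                            "prev_batch": "s41_..."
+                        }
+                    }
+                },
+                "prev_batch": "s37_..."
+            }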
+ """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadUpdate: + """Information about a single thread that has new activity. + + Attributes: + thread_root: The thread root event, if requested via include_roots in the + request. This is the event that started the thread. + prev_batch: A pagination token (exclusive) for fetching older events in this + specific thread. Only present if the thread has multiple updates in the + sync window. This token can be used with the /relations endpoint with + dir=b to paginate backwards through the thread's history. + bundled_aggregations: Bundled aggregations for the thread root event, + including the latest_event in the thread (found in + unsigned.m.relations.m.thread). Only present if thread_root is included. + """ + + thread_root: EventBase | None + prev_batch: StreamToken | None + bundled_aggregations: "BundledAggregations | None" = None + + def __bool__(self) -> bool: + return bool(self.thread_root) or bool(self.prev_batch) + + updates: Mapping[str, Mapping[str, ThreadUpdate]] | None + prev_batch: StreamToken | None + + def __bool__(self) -> bool: + return bool(self.updates) or bool(self.prev_batch) + to_device: ToDeviceExtension | None = None e2ee: E2eeExtension | None = None account_data: AccountDataExtension | None = None receipts: ReceiptsExtension | None = None typing: TypingExtension | None = None thread_subscriptions: ThreadSubscriptionsExtension | None = None + threads: ThreadsExtension | None = None def __bool__(self) -> bool: return bool( @@ -403,6 +454,7 @@ def __bool__(self) -> bool: or self.receipts or self.typing or self.thread_subscriptions + or self.threads ) next_pos: SlidingSyncStreamToken @@ -852,6 +904,7 @@ class PerConnectionState: Attributes: rooms: The status of each room for the events stream. receipts: The status of each room for the receipts stream. + account_data: The status of each room for the account data stream. room_configs: Map from room_id to the `RoomSyncConfig` of all rooms that we have previously sent down. """ diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 49782b52348..ac2e6cecd66 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -383,6 +383,19 @@ class ThreadSubscriptionsExtension(RequestBodyModel): enabled: StrictBool | None = False limit: StrictInt = 100 + class ThreadsExtension(RequestBodyModel): + """The Threads extension (MSC4360) + + Attributes: + enabled: Whether the threads extension is enabled. + include_roots: whether to include thread root events in the extension response. + limit: maximum number of thread updates to return across all joined rooms. 
+ """ + + enabled: StrictBool | None = False + include_roots: StrictBool = False + limit: StrictInt = 100 + to_device: ToDeviceExtension | None = None e2ee: E2eeExtension | None = None account_data: AccountDataExtension | None = None @@ -391,6 +404,9 @@ class ThreadSubscriptionsExtension(RequestBodyModel): thread_subscriptions: ThreadSubscriptionsExtension | None = Field( None, alias="io.element.msc4308.thread_subscriptions" ) + threads: ThreadsExtension | None = Field( + None, alias="io.element.msc4360.threads" + ) conn_id: StrictStr | None = None lists: ( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 825fb10acfd..678ed103be9 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -340,6 +340,7 @@ async def yieldable_gather_results_delaying_cancellation( T4 = TypeVar("T4") T5 = TypeVar("T5") T6 = TypeVar("T6") +T7 = TypeVar("T7") @overload @@ -469,6 +470,30 @@ async def gather_optional_coroutines( ) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + tuple[ + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, + Coroutine[Any, Any, T4] | None, + Coroutine[Any, Any, T5] | None, + Coroutine[Any, Any, T6] | None, + Coroutine[Any, Any, T7] | None, + ] + ], +) -> tuple[ + T1 | None, + T2 | None, + T3 | None, + T4 | None, + T5 | None, + T6 | None, + T7 | None, +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]], ) -> tuple[T1 | None, ...]: diff --git a/tests/rest/client/sliding_sync/test_extension_threads.py b/tests/rest/client/sliding_sync/test_extension_threads.py new file mode 100644 index 00000000000..cfbc3a2155b --- /dev/null +++ b/tests/rest/client/sliding_sync/test_extension_threads.py @@ -0,0 +1,1165 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +import logging + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.constants import RelationTypes +from synapse.rest.client import login, relations, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict, StreamKeyType +from synapse.util.clock import Clock + +from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase +from tests.server import TimedOutException + +logger = logging.getLogger(__name__) + + +# The name of the extension. Currently unstable-prefixed. +EXT_NAME = "io.element.msc4360.threads" + + +class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase): + """ + Test the threads extension in the Sliding Sync API. 
+ """ + + maxDiff = None + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + relations.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4360_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + super().prepare(reactor, clock, hs) + + def test_no_data_initial_sync(self) -> None: + """ + Test enabling threads extension during initial sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + sync_body = { + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertNotIn(EXT_NAME, response_body["extensions"]) + + def test_no_data_incremental_sync(self) -> None: + """ + Test enabling threads extension during incremental sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + initial_sync_body: JsonDict = {} + + # Initial sync + response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok) + + # Incremental sync with extension enabled + sync_body = { + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert + self.assertNotIn( + EXT_NAME, + response_body["extensions"], + response_body, + ) + + def test_threads_initial_sync(self) -> None: + """ + Test threads appear in initial sync response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + _latest_event_id = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": user1_id, + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + )["event_id"] + + # # get the baseline stream_id of the thread_subscriptions stream + # # before we write any data. + # # Required because the initial value differs between SQLite and Postgres. + # base = self.store.get_max_thread_subscriptions_stream_id() + + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, # Set to 0, otherwise events will be in timeline, not extension + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + {"updates": {room_id: {thread_root_id: {}}}}, + ) + + def test_threads_incremental_sync(self) -> None: + """ + Test new thread updates appear in incremental sync response. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the room events stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + # base = self.store.get_room_max_stream_ordering() + + # Initial sync + _, sync_pos = self.do_sync(sync_body, tok=user1_tok) + logger.info("Synced to: %r, now subscribing to thread", sync_pos) + + # Do thing + _latest_event_id = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": user1_id, + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + )["event_id"] + + # Incremental sync + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + logger.info("Synced to: %r", sync_pos) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + {"updates": {room_id: {thread_root_id: {}}}}, + ) + + def test_threads_only_from_joined_rooms(self) -> None: + """ + Test that thread updates are only returned for rooms the user is joined to + at the time of the thread update. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # User1 creates two rooms + room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # User2 joins only Room A + self.helper.join(room_a_id, user2_id, tok=user2_tok) + + # Create threads in both rooms + thread_a_root = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[ + "event_id" + ] + thread_b_root = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[ + "event_id" + ] + + # Add replies to both threads + self.helper.send_event( + room_a_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to A", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_a_root, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_b_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to B", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_b_root, + }, + }, + tok=user1_tok, + ) + + # User2 syncs with threads extension enabled + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user2_tok) + + # Assert: User2 should only see thread from Room A (where they are joined) + self.assertEqual( + response_body["extensions"][EXT_NAME], + {"updates": {room_a_id: {thread_a_root: {}}}}, + "User2 should only see threads from Room A where they are joined, not Room B", + ) + + def test_threads_not_returned_after_leaving_room(self) -> None: + """ + Test that thread updates are properly bounded when a user leaves a room. 
+ + Users should see thread updates that occurred up to the point they left, + but NOT updates that occurred after they left. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create room and both users join + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + self.helper.join(room_id, user2_id, tok=user2_tok) + + # Create thread + thread_root = self.helper.send(room_id, body="Thread root", tok=user1_tok)[ + "event_id" + ] + + # Initial sync for user2 + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + _, sync_pos = self.do_sync(sync_body, tok=user2_tok) + + # Reply in thread while user2 is joined, but after initial sync + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 1 while user2 joined", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 leaves the room + self.helper.leave(room_id, user2_id, tok=user2_tok) + + # Another reply after user2 left + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 2 after user2 left", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 incremental sync + response_body, _ = self.do_sync(sync_body, tok=user2_tok, since=sync_pos) + + # Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving) + self.assertIn( + EXT_NAME, + response_body["extensions"], + "User2 should see thread updates up to the point they left", + ) + self.assertIn( + room_id, + response_body["extensions"][EXT_NAME]["updates"], + "Thread updates should include the room user2 left", + ) + self.assertIn( + thread_root, + response_body["extensions"][EXT_NAME]["updates"][room_id], + "Thread root should be in the updates", + ) + + # Verify that only a single update was seen (Reply 1) by checking that there's + # no prev_batch token. If Reply 2 was also included, there would be multiple + # updates and a prev_batch token would be present. + thread_update = response_body["extensions"][EXT_NAME]["updates"][room_id][ + thread_root + ] + self.assertNotIn( + "prev_batch", + thread_update, + "No prev_batch should be present since only one update (Reply 1) is visible", + ) + + def test_threads_with_include_roots_true(self) -> None: + """ + Test that include_roots=True returns thread root events with latest_event + in the unsigned field. 
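+
+        The thread root is expected to carry bundled aggregations roughly of
+        this shape (other summary fields elided):
+
+            "unsigned": {
+                "m.relations": {
+                    "m.thread": {
+                        "latest_event": { ... }
+                    }
+                }
+            }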
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Add reply to thread + latest_event_resp = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Latest reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + latest_event_id = latest_event_resp["event_id"] + + # Sync with include_roots=True + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + "include_roots": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert thread root is present + thread_root = response_body["extensions"][EXT_NAME]["updates"][room_id][ + thread_root_id + ]["thread_root"] + + # Verify it's the correct event + self.assertEqual(thread_root["event_id"], thread_root_id) + self.assertEqual(thread_root["content"]["body"], "Thread root") + + # Verify latest_event is in unsigned.m.relations.m.thread + latest_event = thread_root["unsigned"]["m.relations"]["m.thread"][ + "latest_event" + ] + self.assertEqual(latest_event["event_id"], latest_event_id) + self.assertEqual(latest_event["content"]["body"], "Latest reply") + + def test_threads_with_include_roots_false(self) -> None: + """ + Test that include_roots=False (or omitted) does not return thread root events. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Add reply + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Sync with include_roots=False (explicitly) + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + "include_roots": False, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert thread update exists but has no thread_root + thread_update = response_body["extensions"][EXT_NAME]["updates"][room_id][ + thread_root_id + ] + self.assertNotIn("thread_root", thread_update) + + # Also test with include_roots omitted (should behave the same) + sync_body_no_param = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body_no_param, _ = self.do_sync(sync_body_no_param, tok=user1_tok) + + thread_update_no_param = response_body_no_param["extensions"][EXT_NAME][ + "updates" + ][room_id][thread_root_id] + self.assertNotIn("thread_root", thread_update_no_param) + + def test_per_thread_prev_batch_single_update(self) -> None: + """ + Test that threads with only a single update do NOT get a prev_batch token. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Initial sync to establish baseline + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + _, sync_pos = self.do_sync(sync_body, tok=user1_tok) + + # Add ONE reply to thread + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Single reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Incremental sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert: Thread update should NOT have prev_batch (only 1 update) + thread_update = response_body["extensions"][EXT_NAME]["updates"][room_id][ + thread_root_id + ] + self.assertNotIn( + "prev_batch", + thread_update, + "Threads with single update should not have prev_batch", + ) + + def test_per_thread_prev_batch_multiple_updates(self) -> None: + """ + Test that threads with multiple updates get a prev_batch token that can be + used with /relations endpoint to paginate backwards. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Initial sync to establish baseline + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + _, sync_pos = self.do_sync(sync_body, tok=user1_tok) + + # Add MULTIPLE replies to thread + reply1_resp = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "First reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + reply1_id = reply1_resp["event_id"] + + reply2_resp = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Second reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + reply2_id = reply2_resp["event_id"] + + reply3_resp = self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Third reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + reply3_id = reply3_resp["event_id"] + + # Incremental sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert: Thread update SHOULD have prev_batch (3 updates) + prev_batch = response_body["extensions"][EXT_NAME]["updates"][room_id][ + thread_root_id + ]["prev_batch"] + self.assertIsNotNone(prev_batch, "prev_batch should not be None") + + # Now use the prev_batch token with /relations endpoint to paginate backwards + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}?from={prev_batch}&to={sync_pos}&dir=b", + 
+    def test_per_thread_prev_batch_multiple_updates(self) -> None:
+        """
+        Test that threads with multiple updates get a prev_batch token that can be
+        used with the /relations endpoint to paginate backwards.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create thread root
+        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
+        thread_root_id = thread_root_resp["event_id"]
+
+        # Initial sync to establish baseline
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                }
+            },
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                }
+            },
+        }
+        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)
+
+        # Add MULTIPLE replies to thread
+        reply1_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "First reply",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply1_id = reply1_resp["event_id"]
+
+        reply2_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Second reply",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply2_id = reply2_resp["event_id"]
+
+        reply3_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Third reply",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply3_id = reply3_resp["event_id"]
+
+        # Incremental sync
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
+
+        # Assert: Thread update SHOULD have prev_batch (3 updates)
+        prev_batch = response_body["extensions"][EXT_NAME]["updates"][room_id][
+            thread_root_id
+        ]["prev_batch"]
+        self.assertIsNotNone(prev_batch, "prev_batch should not be None")
+
+        # Now use the prev_batch token with the /relations endpoint to paginate backwards
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}?from={prev_batch}&to={sync_pos}&dir=b",
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        relations_response = channel.json_body
+        returned_event_ids = [
+            event["event_id"] for event in relations_response["chunk"]
+        ]
+
+        # Assert: Only the older replies should be returned (not the latest one we
+        # already saw). The prev_batch token should be exclusive, pointing just
+        # before the latest event.
+        self.assertIn(
+            reply1_id,
+            returned_event_ids,
+            "First reply should be in relations response",
+        )
+        self.assertIn(
+            reply2_id,
+            returned_event_ids,
+            "Second reply should be in relations response",
+        )
+        self.assertNotIn(
+            reply3_id,
+            returned_event_ids,
+            "Third reply (latest) should NOT be in relations response - already returned in sliding sync",
+        )
+
+    def test_per_thread_prev_batch_on_initial_sync(self) -> None:
+        """
+        Test that threads with multiple updates get prev_batch tokens on initial sync
+        so clients can paginate through the full thread history.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create thread with multiple replies BEFORE any sync
+        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
+        thread_root_id = thread_root_resp["event_id"]
+
+        reply1_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 1",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply1_id = reply1_resp["event_id"]
+
+        reply2_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 2",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply2_id = reply2_resp["event_id"]
+
+        # Initial sync (no from_token)
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                }
+            },
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                }
+            },
+        }
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+        # Assert: Thread update SHOULD have prev_batch on initial sync (2+ updates exist)
+        prev_batch = response_body["extensions"][EXT_NAME]["updates"][room_id][
+            thread_root_id
+        ]["prev_batch"]
+        self.assertIsNotNone(prev_batch)
+
+        # Use prev_batch with /relations to fetch the thread history
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}?from={prev_batch}&dir=b",
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        relations_response = channel.json_body
+        returned_event_ids = [
+            event["event_id"] for event in relations_response["chunk"]
+        ]
+
+        # Assert: Only the older reply should be returned (not the latest one we
+        # already saw). The prev_batch token should be exclusive, pointing just
+        # before the latest event.
+        self.assertIn(
+            reply1_id,
+            returned_event_ids,
+            "First reply should be in relations response",
+        )
+        self.assertNotIn(
+            reply2_id,
+            returned_event_ids,
+            "Second reply (latest) should NOT be in relations response - already returned in sliding sync",
+        )
+
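+    # Sketch of the client-side pagination loop the two prev_batch tests above
+    # imply (illustrative only; `client` and `handle` are hypothetical placeholders,
+    # not part of this change):
+    #
+    #     update = sync_response["extensions"][EXT_NAME]["updates"][room_id][thread_root_id]
+    #     token = update.get("prev_batch")
+    #     while token is not None:
+    #         # Walk backwards from the token; the latest event already delivered
+    #         # by sliding sync is excluded because the token is exclusive.
+    #         page = client.get(
+    #             f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}",
+    #             params={"from": token, "dir": "b"},
+    #         ).json()
+    #         handle(page["chunk"])
+    #         token = page.get("next_batch")  # absent once history is exhausted
+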
+    def test_thread_in_timeline_omitted_without_include_roots(self) -> None:
+        """
+        Test that threads with events in the room timeline are omitted from the
+        extension response when include_roots=False. When all threads are filtered
+        out, the entire extension should be omitted from the response.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create thread root
+        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
+        thread_root_id = thread_root_resp["event_id"]
+
+        # Initial sync to establish baseline
+        sync_body: JsonDict = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 5,
+                }
+            },
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                    "include_roots": False,
+                }
+            },
+        }
+        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)
+
+        # Send a reply to the thread
+        self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 1",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+
+        # Incremental sync - the reply should be in the timeline
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
+
+        # Assert: Extension should be omitted entirely since the only thread with updates
+        # is already visible in the timeline (include_roots=False)
+        self.assertNotIn(
+            EXT_NAME,
+            response_body.get("extensions", {}),
+            "Extension should be omitted when all threads are filtered out (in timeline with include_roots=False)",
+        )
+
+    def test_thread_in_timeline_included_with_include_roots(self) -> None:
+        """
+        Test that threads with events in the room timeline are still included in the
+        extension response when include_roots=True, because the client wants the root event.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        # Create thread root
+        thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
+        thread_root_id = thread_root_resp["event_id"]
+
+        # Initial sync to establish baseline
+        sync_body: JsonDict = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 5,
+                }
+            },
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                    "include_roots": True,
+                }
+            },
+        }
+        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)
+
+        # Send a reply to the thread
+        reply_resp = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply 1",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root_id,
+                },
+            },
+            tok=user1_tok,
+        )
+        reply_id = reply_resp["event_id"]
+
+        # Incremental sync - the reply should be in the timeline
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
+
+        # Assert: The thread reply should be in the room timeline
+        room_response = response_body["rooms"][room_id]
+        timeline_event_ids = [event["event_id"] for event in room_response["timeline"]]
+        self.assertIn(
+            reply_id,
+            timeline_event_ids,
+            "Thread reply should be in the room timeline",
+        )
+
+        # Assert: Thread SHOULD be in extension (include_roots=True)
+        thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id]
+        self.assertIn(
+            thread_root_id,
+            thread_updates,
+            "Thread should be included in extension when include_roots=True, even if in timeline",
+        )
+        # Verify the thread root event is present
+        self.assertIn("thread_root", thread_updates[thread_root_id])
+
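+    # Rough restatement of the filtering rule the two tests above pin down
+    # (informal, not normative):
+    #
+    #     for each thread with new updates:
+    #         if its updates are already in the room's timeline and
+    #                 include_roots is False:
+    #             drop it from the extension (the data would be duplicated)
+    #         else:
+    #             keep it, attaching `thread_root` when include_roots is True
+    #     if nothing survives, omit the extension from the response entirely
+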
""" + Test that thread updates are only returned for rooms that are in the + sliding sync response, not from all rooms the user is joined to. + + This tests the scenario where a user is joined to multiple rooms but + the room list range/limit means only some rooms are in the response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create three rooms + room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room_c_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in all three rooms + thread_a_root = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[ + "event_id" + ] + thread_b_root = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[ + "event_id" + ] + thread_c_root = self.helper.send(room_c_id, body="Thread C", tok=user1_tok)[ + "event_id" + ] + + # Do an initial sync to get the sync position and see room ordering + initial_sync_body = { + "lists": { + "all-rooms": { + "ranges": [[0, 2]], + "required_state": [], + "timeline_limit": 0, + } + }, + } + response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok) + + # Add replies to all threads after the initial sync + self.helper.send_event( + room_a_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to A", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_a_root, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_b_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to B", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_b_root, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_c_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to C", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_c_root, + }, + }, + tok=user1_tok, + ) + + # Now do a sync with a limited range that excludes the last room + sync_body = { + "lists": { + "limited-list": { + "ranges": [[0, 1]], # Only include first 2 rooms + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Get which rooms were included in this limited response + included_rooms = set( + response_body["lists"]["limited-list"]["ops"][0]["room_ids"] + ) + excluded_room = ({room_a_id, room_b_id, room_c_id} - included_rooms).pop() + + # Assert: Only threads from rooms in the response should be included + thread_updates = response_body["extensions"][EXT_NAME]["updates"] + + # Check that included rooms have thread updates + for room_id in included_rooms: + self.assertIn( + room_id, + thread_updates, + f"Room {room_id} should have thread updates since it's in the room list", + ) + + # Check that the excluded room is NOT present + self.assertNotIn( + excluded_room, + thread_updates, + f"Room {excluded_room} should NOT have thread updates since it's excluded from the room list", + ) + + def test_wait_for_new_data(self) -> None: + """ + Test to make sure that the Sliding Sync request waits for new data to arrive. 
+    def test_wait_for_new_data(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive.
+
+        (Only applies to incremental syncs with a `timeout` specified)
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        # Create a thread
+        thread_root = self.helper.send(room_id, body="Thread root", tok=user1_tok)[
+            "event_id"
+        ]
+
+        sync_body = {
+            "lists": {},
+            "room_subscriptions": {
+                room_id: {
+                    "required_state": [],
+                    "timeline_limit": 0,
+                },
+            },
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+        # Make an incremental Sliding Sync request with the threads extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint + f"?timeout=10000&pos={from_token}",
+            content=sync_body,
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are in `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Send a thread reply to trigger new results
+        self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Reply in thread",
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": thread_root,
+                },
+            },
+            tok=user2_tok,
+        )
+        # Should respond before the 10 second timeout
+        channel.await_result(timeout_ms=3000)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # We should see the new thread update
+        self.assertIn(
+            thread_root,
+            channel.json_body["extensions"][EXT_NAME]["updates"][room_id],
+        )
+
+    def test_wait_for_new_data_timeout(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive
+        but, when no data ever arrives, we time out. We're also making sure that the
+        default data from the threads extension doesn't trigger a false positive for
+        new data.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                EXT_NAME: {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint + f"?timeout=10000&pos={from_token}",
+            content=sync_body,
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are in `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Wake up `notifier.wait_for_events(...)`, which will cause us to test
+        # `SlidingSyncResult.__bool__` for new results.
+        self._bump_notifier_wait_for_events(
+            user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
+        )
+        # Block for a little bit more to ensure we don't see any new results.
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=4000)
+        # Wait for the sync to complete (wait for the rest of the 10 second timeout:
+        # 5000 + 4000 + 1200 > 10000)
+        channel.await_result(timeout_ms=1200)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Should be no thread updates
+        self.assertNotIn(
+            EXT_NAME,
+            channel.json_body.get("extensions", {}),
+        )
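+
+    # A minimal client-side long-polling loop matching what the two tests above
+    # exercise (illustrative only; `client`, `sync_endpoint`, and `sync_body` are
+    # placeholders, not part of this change):
+    #
+    #     pos = None
+    #     while True:
+    #         url = sync_endpoint + (f"?timeout=10000&pos={pos}" if pos else "")
+    #         resp = client.post(url, json=sync_body).json()
+    #         pos = resp["pos"]
+    #         # Absent when the server timed out with nothing new to report:
+    #         updates = resp.get("extensions", {}).get(EXT_NAME, {}).get("updates", {})
+    #         for room_id, threads in updates.items():
+    #             ...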