Skip to content

Commit 116d941

Browse files
committed
Add companion endpoint for backpagination of thread subscriptions
1 parent fa1c01a commit 116d941

File tree

5 files changed

+278
-12
lines changed

5 files changed

+278
-12
lines changed

synapse/handlers/sliding_sync/extensions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ async def get_thread_subscriptions_extension_response(
932932

933933
to_stream_id = to_token.thread_subscriptions_key
934934

935-
updates = await self.store.get_updated_thread_subscriptions_for_user(
935+
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
936936
user_id=sync_config.user.to_string(),
937937
from_id=from_stream_id,
938938
to_id=to_stream_id,

synapse/rest/client/thread_subscriptions.py

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
11
from http import HTTPStatus
2-
from typing import Tuple
2+
from typing import Dict, Tuple
3+
4+
import attr
5+
from typing_extensions import TypeAlias
36

47
from synapse._pydantic_compat import StrictBool
58
from synapse.api.errors import Codes, NotFoundError, SynapseError
69
from synapse.http.server import HttpServer
710
from synapse.http.servlet import (
811
RestServlet,
912
parse_and_validate_json_object_from_request,
13+
parse_integer,
14+
parse_string,
1015
)
1116
from synapse.http.site import SynapseRequest
1217
from synapse.rest.client._base import client_patterns
1318
from synapse.server import HomeServer
14-
from synapse.types import JsonDict, RoomID
19+
from synapse.types import (
20+
JsonDict,
21+
RoomID,
22+
SlidingSyncStreamToken,
23+
ThreadSubscriptionsToken,
24+
)
25+
from synapse.types.handlers.sliding_sync import SlidingSyncResult
1526
from synapse.types.rest import RequestBodyModel
1627

28+
_ThreadSubscription: TypeAlias = (
29+
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
30+
)
31+
_ThreadUnsubscription: TypeAlias = (
32+
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
33+
)
34+
1735

1836
class ThreadSubscriptionsRestServlet(RestServlet):
1937
PATTERNS = client_patterns(
@@ -93,6 +111,129 @@ async def on_DELETE(
93111
return HTTPStatus.OK, {}
94112

95113

114+
class ThreadSubscriptionsPaginationRestServlet(RestServlet):
115+
PATTERNS = client_patterns(
116+
"/io.element.msc4308/thread_subscriptions$",
117+
unstable=True,
118+
releases=(),
119+
)
120+
CATEGORY = "Thread Subscriptions requests (unstable)"
121+
122+
# Maximum number of thread subscriptions to return in one request.
123+
MAX_LIMIT = 512
124+
125+
def __init__(self, hs: "HomeServer"):
126+
self.auth = hs.get_auth()
127+
self.is_mine = hs.is_mine
128+
self.store = hs.get_datastores().main
129+
130+
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
131+
requester = await self.auth.get_user_by_req(request)
132+
133+
limit = min(
134+
parse_integer(request, "limit", default=100, negative=False),
135+
ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
136+
)
137+
from_end_opt = parse_string(request, "from", required=False)
138+
to_start_opt = parse_string(request, "to", required=False)
139+
_direction = parse_string(request, "dir", required=True, allowed_values=("b",))
140+
141+
if limit <= 0:
142+
raise SynapseError(
143+
HTTPStatus.BAD_REQUEST,
144+
"limit must be greater than 0",
145+
errcode=Codes.INVALID_PARAM,
146+
)
147+
148+
if from_end_opt is not None:
149+
try:
150+
# because of backwards pagination, the `from` token is actually the
151+
# bound closest to the end of the stream
152+
end_stream_id = ThreadSubscriptionsToken.from_string(
153+
from_end_opt
154+
).stream_id
155+
except ValueError:
156+
raise SynapseError(
157+
HTTPStatus.BAD_REQUEST,
158+
"`from` is not a valid token",
159+
errcode=Codes.INVALID_PARAM,
160+
)
161+
else:
162+
end_stream_id = self.store.get_max_thread_subscriptions_stream_id()
163+
164+
if to_start_opt is not None:
165+
# because of backwards pagination, the `to` token is actually the
166+
# bound closest to the start of the stream
167+
try:
168+
start_stream_id = ThreadSubscriptionsToken.from_string(
169+
to_start_opt
170+
).stream_id
171+
except ValueError:
172+
# we also accept sliding sync `pos` tokens on this parameter
173+
try:
174+
sliding_sync_pos = await SlidingSyncStreamToken.from_string(
175+
self.store, to_start_opt
176+
)
177+
start_stream_id = (
178+
sliding_sync_pos.stream_token.thread_subscriptions_key
179+
)
180+
except ValueError:
181+
raise SynapseError(
182+
HTTPStatus.BAD_REQUEST,
183+
"`to` is not a valid token",
184+
errcode=Codes.INVALID_PARAM,
185+
)
186+
else:
187+
# the start of time is ID 1; the lower bound is exclusive though
188+
start_stream_id = 0
189+
190+
subscriptions = (
191+
await self.store.get_latest_updated_thread_subscriptions_for_user(
192+
requester.user.to_string(),
193+
from_id=start_stream_id,
194+
to_id=end_stream_id,
195+
limit=limit,
196+
)
197+
)
198+
199+
subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
200+
unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
201+
for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
202+
if subscribed:
203+
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
204+
attr.asdict(
205+
_ThreadSubscription(
206+
automatic=automatic,
207+
bump_stamp=stream_id,
208+
)
209+
)
210+
)
211+
else:
212+
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
213+
attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
214+
)
215+
216+
result: JsonDict = {}
217+
if subscribed_threads:
218+
result["subscribed"] = subscribed_threads
219+
if unsubscribed_threads:
220+
result["unsubscribed"] = unsubscribed_threads
221+
222+
if len(subscriptions) == limit:
223+
# We hit the limit, so there might be more entries to return.
224+
# Generate a new token that has moved backwards, ready for the next
225+
# request.
226+
min_returned_stream_id, _, _, _, _ = subscriptions[0]
227+
result["end"] = ThreadSubscriptionsToken(
228+
# We subtract one because the 'later in the stream' bound is inclusive,
229+
# and we already saw the element at index 0.
230+
stream_id=min_returned_stream_id - 1
231+
).to_string()
232+
233+
return HTTPStatus.OK, result
234+
235+
96236
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
97237
if hs.config.experimental.msc4306_enabled:
98238
ThreadSubscriptionsRestServlet(hs).register(http_server)
239+
ThreadSubscriptionsPaginationRestServlet(hs).register(http_server)

synapse/storage/databases/main/thread_subscriptions.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,30 +350,38 @@ def get_updated_thread_subscriptions_txn(
350350
get_updated_thread_subscriptions_txn,
351351
)
352352

353-
async def get_updated_thread_subscriptions_for_user(
353+
async def get_latest_updated_thread_subscriptions_for_user(
354354
self, user_id: str, from_id: int, to_id: int, limit: int
355355
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
356-
"""Get updates to thread subscriptions for a specific user.
356+
"""Get the latest updates to thread subscriptions for a specific user.
357357
358358
Args:
359359
user_id: The ID of the user
360360
from_id: The starting stream ID (exclusive)
361361
to_id: The ending stream ID (inclusive)
362362
limit: The maximum number of rows to return
363+
If there are too many rows to return, rows from the start (closer to `from_id`)
364+
will be omitted.
363365
364366
Returns:
365367
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
368+
The row with lowest `stream_id` is the first row.
366369
"""
367370

368371
def get_updated_thread_subscriptions_for_user_txn(
369372
txn: LoggingTransaction,
370373
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
371374
sql = """
375+
WITH the_updates AS (
376+
SELECT stream_id, room_id, event_id, subscribed, automatic
377+
FROM thread_subscriptions
378+
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
379+
ORDER BY stream_id DESC
380+
LIMIT ?
381+
)
372382
SELECT stream_id, room_id, event_id, subscribed, automatic
373-
FROM thread_subscriptions
374-
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
383+
FROM the_updates
375384
ORDER BY stream_id ASC
376-
LIMIT ?
377385
"""
378386

379387
txn.execute(sql, (user_id, from_id, to_id, limit))

tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,123 @@ def test_limit_parameter(self) -> None:
353353
len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
354354
)
355355

356+
def test_limit_and_companion_backpagination(self) -> None:
357+
"""
358+
Create 1 thread subscription, do a sync, create 4 more,
359+
then sync with a limit of 2 and fill in the gap
360+
using the companion /thread_subscriptions endpoint.
361+
"""
362+
363+
thread_root_ids: List[str] = []
364+
365+
def make_subscription() -> None:
366+
thread_root_resp = self.helper.send(
367+
room_id, body="Some thread root", tok=user1_tok
368+
)
369+
thread_root_ids.append(thread_root_resp["event_id"])
370+
self._subscribe_to_thread(
371+
user1_id, room_id, thread_root_ids[-1], automatic=False
372+
)
373+
374+
user1_id = self.register_user("user1", "pass")
375+
user1_tok = self.login(user1_id, "pass")
376+
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
377+
378+
# get the baseline stream_id of the thread_subscriptions stream
379+
# before we write any data.
380+
# Required because the initial value differs between SQLite and Postgres.
381+
base = self.store.get_max_thread_subscriptions_stream_id()
382+
383+
# Make our first subscription
384+
make_subscription()
385+
386+
# Sync for the first time
387+
sync_body = {
388+
"lists": {},
389+
"extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
390+
}
391+
392+
sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)
393+
394+
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
395+
self.assertEqual(
396+
thread_subscriptions["subscribed"],
397+
{
398+
room_id: {
399+
thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
400+
}
401+
},
402+
)
403+
404+
# Get our pos for the next sync
405+
first_sync_pos = sync_resp["pos"]
406+
407+
# Create 4 more thread subsrciptions and subscribe to each
408+
for _ in range(5):
409+
make_subscription()
410+
411+
# Now sync again. Our limit is 2,
412+
# so we should get the latest 2 subscriptions,
413+
# with a gap of 3 more subscriptions in the middle
414+
sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)
415+
416+
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
417+
self.assertEqual(
418+
thread_subscriptions["subscribed"],
419+
{
420+
room_id: {
421+
thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
422+
thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
423+
}
424+
},
425+
)
426+
# 1st backpagination: expecting a page with 2 subscriptions
427+
page, end_tok = self._do_backpaginate(
428+
from_tok=thread_subscriptions["prev_batch"],
429+
to_tok=first_sync_pos,
430+
limit=2,
431+
access_token=user1_tok,
432+
)
433+
self.assertIsNotNone(end_tok, "backpagination should continue")
434+
self.assertEqual(
435+
page["subscribed"],
436+
{
437+
room_id: {
438+
thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
439+
thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
440+
}
441+
},
442+
)
443+
444+
# 2nd backpagination: expecting a page with only 1 subscription
445+
# and no other token for further backpagination
446+
assert end_tok is not None
447+
page, end_tok = self._do_backpaginate(
448+
from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
449+
)
450+
self.assertIsNone(end_tok, "backpagination should have finished")
451+
self.assertEqual(
452+
page["subscribed"],
453+
{
454+
room_id: {
455+
thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
456+
}
457+
},
458+
)
459+
460+
def _do_backpaginate(
461+
self, *, from_tok: str, to_tok: str, limit: int, access_token: str
462+
) -> Tuple[JsonDict, Optional[str]]:
463+
channel = self.make_request(
464+
"GET",
465+
"/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
466+
f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b",
467+
access_token=access_token,
468+
)
469+
470+
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
471+
body = channel.json_body
472+
return body, cast(Optional[str], body.get("end"))
356473

357474
def _subscribe_to_thread(
358475
self, user_id: str, room_id: str, thread_root_id: str, automatic: bool

tests/storage/test_thread_subscriptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_purge_thread_subscriptions_for_user(self) -> None:
182182
self._subscribe(self.other_thread_root_id, automatic=False)
183183

184184
subscriptions = self.get_success(
185-
self.store.get_updated_thread_subscriptions_for_user(
185+
self.store.get_latest_updated_thread_subscriptions_for_user(
186186
self.user_id,
187187
from_id=0,
188188
to_id=50,
@@ -205,7 +205,7 @@ def test_purge_thread_subscriptions_for_user(self) -> None:
205205

206206
# Check user has no subscriptions
207207
subscriptions = self.get_success(
208-
self.store.get_updated_thread_subscriptions_for_user(
208+
self.store.get_latest_updated_thread_subscriptions_for_user(
209209
self.user_id,
210210
from_id=0,
211211
to_id=50,
@@ -255,7 +255,7 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None:
255255

256256
# Get updates for main user
257257
updates = self.get_success(
258-
self.store.get_updated_thread_subscriptions_for_user(
258+
self.store.get_latest_updated_thread_subscriptions_for_user(
259259
self.user_id, 0, stream_id2, 10
260260
)
261261
)
@@ -265,7 +265,7 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None:
265265

266266
# Get updates for other user
267267
updates = self.get_success(
268-
self.store.get_updated_thread_subscriptions_for_user(
268+
self.store.get_latest_updated_thread_subscriptions_for_user(
269269
other_user_id, 0, max(stream_id1, stream_id2), 10
270270
)
271271
)

0 commit comments

Comments
 (0)