Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 63d96bf

Browse files
authored
ModuleAPI SSO auth callbacks (#15207)
Signed-off-by: Andrii Yasynyshyn [email protected]
1 parent 579c6be commit 63d96bf

File tree

8 files changed

+56
-2
lines changed

8 files changed

+56
-2
lines changed

changelog.d/15207.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth.

docs/modules/account_validity_callbacks.md

+13
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user
4242
represented by their Matrix user ID.
4343

4444
If multiple modules implement this callback, Synapse runs them all in order.
45+
46+
### `on_user_login`
47+
48+
_First introduced in Synapse v1.98.0_
49+
50+
```python
51+
async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None
52+
```
53+
54+
Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth.
55+
represented by their Matrix user ID.
56+
57+
If multiple modules implement this callback, Synapse runs them all in order.

rust/src/push/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue {
296296
match l.iter().map(SimpleJsonValue::extract).collect() {
297297
Ok(a) => Ok(JsonValue::Array(a)),
298298
Err(e) => Err(PyTypeError::new_err(format!(
299-
"Can't convert to JsonValue::Array: {}",
300-
e
299+
"Can't convert to JsonValue::Array: {e}"
301300
))),
302301
}
303302
} else if let Ok(v) = SimpleJsonValue::extract(ob) {

synapse/handlers/account_validity.py

+16
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,22 @@ async def on_user_registration(self, user_id: str) -> None:
9898
for callback in self._module_api_callbacks.on_user_registration_callbacks:
9999
await callback(user_id)
100100

101+
async def on_user_login(
102+
self,
103+
user_id: str,
104+
auth_provider_type: Optional[str],
105+
auth_provider_id: Optional[str],
106+
) -> None:
107+
"""Tell third-party modules about a user logins.
108+
109+
Args:
110+
user_id: The mxID of the user.
111+
auth_provider_type: The type of login.
112+
auth_provider_id: The ID of the auth provider.
113+
"""
114+
for callback in self._module_api_callbacks.on_user_login_callbacks:
115+
await callback(user_id, auth_provider_type, auth_provider_id)
116+
101117
@wrap_as_background_process("send_renewals")
102118
async def _send_renewal_emails(self) -> None:
103119
"""Gets the list of users whose account is expiring in the amount of time

synapse/handlers/auth.py

+8
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def __init__(self, hs: "HomeServer"):
212212
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
213213
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
214214
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
215+
self._account_validity_handler = hs.get_account_validity_handler()
215216

216217
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
217218
# as per `rc_login.failed_attempts`.
@@ -1783,6 +1784,13 @@ async def complete_sso_login(
17831784
client_redirect_url, "loginToken", login_token
17841785
)
17851786

1787+
# Run post-login module callback handlers
1788+
await self._account_validity_handler.on_user_login(
1789+
user_id=registered_user_id,
1790+
auth_provider_type=LoginType.SSO,
1791+
auth_provider_id=auth_provider_id,
1792+
)
1793+
17861794
# if the client is whitelisted, we can redirect straight to it
17871795
if client_redirect_url.startswith(self._whitelisted_sso_clients):
17881796
request.redirect(redirect_url)

synapse/module_api/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
ON_LEGACY_ADMIN_REQUEST,
8181
ON_LEGACY_RENEW_CALLBACK,
8282
ON_LEGACY_SEND_MAIL_CALLBACK,
83+
ON_USER_LOGIN_CALLBACK,
8384
ON_USER_REGISTRATION_CALLBACK,
8485
)
8586
from synapse.module_api.callbacks.spamchecker_callbacks import (
@@ -334,6 +335,7 @@ def register_account_validity_callbacks(
334335
*,
335336
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
336337
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
338+
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
337339
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
338340
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
339341
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@@ -345,6 +347,7 @@ def register_account_validity_callbacks(
345347
return self._callbacks.account_validity.register_callbacks(
346348
is_user_expired=is_user_expired,
347349
on_user_registration=on_user_registration,
350+
on_user_login=on_user_login,
348351
on_legacy_send_mail=on_legacy_send_mail,
349352
on_legacy_renew=on_legacy_renew,
350353
on_legacy_admin_request=on_legacy_admin_request,

synapse/module_api/callbacks/account_validity_callbacks.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# Types for callbacks to be registered via the module api
2323
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
2424
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
25+
ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable]
2526
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
2627
# to `/_synapse/client/account_validity`. See `register_callbacks` below.
2728
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
@@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks:
3334
def __init__(self) -> None:
3435
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
3536
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
37+
self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = []
3638
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
3739
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
3840

@@ -44,6 +46,7 @@ def register_callbacks(
4446
self,
4547
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
4648
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
49+
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
4750
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
4851
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
4952
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@@ -55,6 +58,9 @@ def register_callbacks(
5558
if on_user_registration is not None:
5659
self.on_user_registration_callbacks.append(on_user_registration)
5760

61+
if on_user_login is not None:
62+
self.on_user_login_callbacks.append(on_user_login)
63+
5864
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
5965
# an admin one). As part of moving the feature into a module, we need to change
6066
# the path from /_matrix/client/unstable/account_validity/... to

synapse/rest/client/login.py

+8
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self, hs: "HomeServer"):
115115
self.registration_handler = hs.get_registration_handler()
116116
self._sso_handler = hs.get_sso_handler()
117117
self._spam_checker = hs.get_module_api_callbacks().spam_checker
118+
self._account_validity_handler = hs.get_account_validity_handler()
118119

119120
self._well_known_builder = WellKnownBuilder(hs)
120121
self._address_ratelimiter = Ratelimiter(
@@ -470,6 +471,13 @@ async def _complete_login(
470471
device_id=device_id,
471472
)
472473

474+
# execute the callback
475+
await self._account_validity_handler.on_user_login(
476+
user_id,
477+
auth_provider_type=login_submission.get("type"),
478+
auth_provider_id=auth_provider_id,
479+
)
480+
473481
if valid_until_ms is not None:
474482
expires_in_ms = valid_until_ms - self.clock.time_msec()
475483
result["expires_in_ms"] = expires_in_ms

0 commit comments

Comments
 (0)