diff --git a/app/access.py b/app/access.py
index b377869c..e7b6b7a6 100644
--- a/app/access.py
+++ b/app/access.py
@@ -1,11 +1,31 @@
 import functools
+import inspect
 
 import sqlalchemy as sa
 import tornado.web
 from sqlalchemy.orm import joinedload
 
 from baselayer.app.custom_exceptions import AccessError  # noqa: F401
-from baselayer.app.models import DBSession, Role, Token, User  # noqa: F401
+from baselayer.app.models import (  # noqa: F401
+    DBSession,
+    Role,
+    Token,
+    User,
+    async_plain_session_factory,
+)
+
+
+def _token_select_stmt(token_id):
+    return (
+        sa.select(Token)
+        .options(
+            joinedload(Token.created_by).options(
+                joinedload(User.acls),
+                joinedload(User.roles),
+            )
+        )
+        .where(Token.id == token_id)
+    )
 
 
 def auth_or_token(method):
@@ -19,24 +39,65 @@ def auth_or_token(method):
 
       $ curl -v -H "Authorization: token 123efghj" http://localhost:5000/api/endpoint
 
+    If `method` is a coroutine function, the token lookup runs against the
+    async DB engine; otherwise the original sync path is used. The cookie
+    auth path delegates to `tornado.web.authenticated` in both cases.
     """
 
+    if inspect.iscoroutinefunction(method):
+
+        @functools.wraps(method)
+        async def async_wrapper(self, *args, **kwargs):
+            token_header = self.request.headers.get("Authorization", None)
+            if token_header is not None and token_header.startswith("token "):
+                # Drop only the leading "token " scheme; str.replace() would also
+                # strip any later "token" substring from the credential itself.
+                token_id = token_header[len("token ") :].strip()
+                # Use the import via models module so monkeypatching/late
+                # init by init_db() is reflected here.
+                from baselayer.app import models as _models
+
+                session_factory = _models.async_plain_session_factory
+                if session_factory is None:
+                    raise RuntimeError(
+                        "Async DB session not initialized. init_db() must run first."
+                    )
+                async with session_factory() as session:
+                    result = await session.scalars(_token_select_stmt(token_id))
+                    token = result.first()
+                if token is not None:
+                    self.current_user = token
+                    if not token.created_by.is_active():
+                        raise tornado.web.HTTPError(403, "User account expired")
+                else:
+                    raise tornado.web.HTTPError(401)
+                return await method(self, *args, **kwargs)
+            else:
+                if self.current_user is not None:
+                    if not self.current_user.is_active():
+                        raise tornado.web.HTTPError(403, "User account expired")
+                else:
+                    raise tornado.web.HTTPError(
+                        401,
+                        'Credentials malformed; expected form "Authorization: token abc123"',
+                    )
+                # tornado.web.authenticated returns whatever the method
+                # returns; for an async method that's a coroutine to await.
+                result = tornado.web.authenticated(method)(self, *args, **kwargs)
+                if inspect.isawaitable(result):
+                    return await result
+                return result
+
+        async_wrapper.__authenticated__ = True
+        return async_wrapper
+
     @functools.wraps(method)
     def wrapper(self, *args, **kwargs):
         token_header = self.request.headers.get("Authorization", None)
         if token_header is not None and token_header.startswith("token "):
-            token_id = token_header.replace("token", "").strip()
+            token_id = token_header[len("token ") :].strip()
             with DBSession() as session:
-                token = session.scalars(
-                    sa.select(Token)
-                    .options(
-                        joinedload(Token.created_by).options(
-                            joinedload(User.acls),
-                            joinedload(User.roles),
-                        )
-                    )
-                    .where(Token.id == token_id)
-                ).first()
+                token = session.scalars(_token_select_stmt(token_id)).first()
             if token is not None:
                 self.current_user = token
                 if not token.created_by.is_active():
@@ -65,6 +126,21 @@ def permissions(acl_list):
     """
 
     def check_acls(method):
+        if inspect.iscoroutinefunction(method):
+
+            @auth_or_token
+            @functools.wraps(method)
+            async def async_wrapper(self, *args, **kwargs):
+                if not (
+                    set(acl_list).issubset(self.current_user.permissions)
+                    or "System admin" in self.current_user.permissions
+                ):
+                    raise tornado.web.HTTPError(401)
+                return await method(self, *args, **kwargs)
+
+            async_wrapper.__permissions__ = acl_list
+            return async_wrapper
+
         @auth_or_token
         @functools.wraps(method)
         def wrapper(self, *args, **kwargs):
diff --git a/app/handlers/base.py b/app/handlers/base.py
index 4a66bbe2..a9323461 100644
--- a/app/handlers/base.py
+++ b/app/handlers/base.py
@@ -1,6 +1,6 @@
 import time
 import uuid
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
 from json.decoder import JSONDecodeError
 
 # The Python Social Auth base handler gives us:
@@ -22,7 +22,14 @@
 from ..env import load_env
 from ..flow import Flow
 from ..json_util import to_json
-from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
+from ..models import (
+    AsyncVerifiedSession,
+    DBSession,
+    User,
+    VerifiedSession,
+    bulk_verify,
+    session_context_id,
+)
 
 env, cfg = load_env()
 log = make_log("basehandler")
@@ -156,6 +163,26 @@ def Session(self):
             session.bind = DBSession.session_factory.kw["bind"]
             yield session
 
+    @asynccontextmanager
+    async def AsyncSession(self):
+        """Async counterpart of `Session()`. Yields an `_AsyncVerifiedSession`
+        bound to the async engine, with the handler's current user merged so
+        that commit-time RLS verification has the right accessor.
+
+        Usage:
+            async with self.AsyncSession() as session:
+                result = await session.scalars(MyModel.select(session.user_or_token))
+                ...
+                await session.commit()
+        """
+        async with AsyncVerifiedSession(self.current_user) as session:
+            # Attach the detached current_user (loaded by the auth lookup in
+            # a different session) without issuing SQL. Relationships that
+            # were already eager-loaded via `selectin` remain accessible.
+            merged_user = await session.merge(self.current_user, load=False)
+            session.user_or_token = merged_user
+            yield session
+
     def verify_permissions(self):
         """Check that the current user has permission to create, read,
         update, or delete rows that are present in the session. If not,
@@ -394,10 +421,32 @@ def push_notification(self, note, notification_type="info"):
             payload={"note": note, "type": notification_type},
         )
 
-    def get_query_argument(self, value, default=NoValue, **kwargs):
+    def get_query_argument(self, value, default=NoValue, type=None, **kwargs):
+        """Get a query-string argument with optional type coercion.
+
+        Parameters
+        ----------
+        value : str
+            Name of the query parameter.
+        default : any, optional
+            Value to return when the parameter is absent.
+        type : callable, optional
+            If provided (e.g. ``float`` / ``int``), the returned string is
+            passed through this callable. Required for parameters that go
+            into SQL comparisons against non-text columns — psycopg v3
+            binds Python strings as VARCHAR, so the database refuses to
+            compare e.g. ``double precision`` to ``character varying``.
+            If the value can't be coerced, ``default`` is returned.
+        """
         if default != NoValue:
             kwargs["default"] = default
         arg = super().get_query_argument(value, **kwargs)
-        if isinstance(kwargs.get("default", None), bool):
+        default_val = kwargs.get("default", None)
+        if isinstance(default_val, bool):
             arg = str(arg).lower() in ["true", "yes", "t", "1"]
+        elif type is not None and arg is not None and arg is not default_val:
+            try:
+                arg = type(arg)
+            except (TypeError, ValueError):
+                arg = default_val
         return arg
diff --git a/app/models.py b/app/models.py
index 88fc8cbe..5bb10824 100644
--- a/app/models.py
+++ b/app/models.py
@@ -3,6 +3,7 @@
 import uuid
 import warnings
 from collections import defaultdict
+from contextlib import asynccontextmanager
 from datetime import datetime
 from hashlib import md5
 
@@ -13,6 +14,11 @@
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.asyncio import (
+    AsyncSession as SAAsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import (
     declarative_base,
@@ -185,6 +191,99 @@ def bulk_verify(mode, collection, accessor):
             handle_inaccessible(mode, inaccessible_row_ids, record_cls, accessor)
 
 
+# --- Async DB layer (parallel to the sync layer above) --------------------
+# Configured by init_db(). Sync engine/session remain authoritative for the
+# rest of the codebase; async path is opt-in per handler.
+async_engine = None
+# Verified factory (RLS check on commit), parallel to VerifiedSession.
+async_session_factory = None
+# Plain factory (no RLS check), parallel to DBSession.
+async_plain_session_factory = None
+
+
+class _AsyncVerifiedSession(SAAsyncSession):
+    """Async counterpart of `_VerifiedSession`. Runs RLS verification on
+    flush/commit using `async_bulk_verify`.
+
+    The `user_or_token` attribute is attached by `AsyncVerifiedSession()`
+    after instantiation; the session is otherwise a plain SQLAlchemy
+    `AsyncSession`.
+    """
+
+    user_or_token = None
+
+    async def verify(self):
+        new_rows = list(self.new)
+        updated_rows = [row for row in self.dirty if self.is_modified(row)]
+        deleted_rows = list(self.deleted)
+        read_rows = [
+            row
+            for row in set(self.identity_map.values())
+            - (set(updated_rows) | set(new_rows) | set(deleted_rows))
+        ]
+
+        for mode, collection in zip(
+            ["read", "update", "delete"],
+            [read_rows, updated_rows, deleted_rows],
+        ):
+            await async_bulk_verify(self, mode, collection, self.user_or_token)
+
+        await self.flush()
+        await async_bulk_verify(self, "create", new_rows, self.user_or_token)
+
+    async def commit(self):
+        await self.verify()
+        await super().commit()
+
+
+@asynccontextmanager
+async def AsyncVerifiedSession(user_or_token):
+    """Async equivalent of `VerifiedSession()`. Yields an
+    `_AsyncVerifiedSession` bound to the configured async engine.
+    """
+    if async_session_factory is None:
+        raise RuntimeError(
+            "Async DB session not initialized. init_db() must run first."
+        )
+    session = async_session_factory()
+    session.user_or_token = user_or_token
+    try:
+        yield session
+    finally:
+        await session.close()
+
+
+async def async_bulk_verify(session, mode, collection, accessor):
+    """Async counterpart of `bulk_verify`. Runs the RLS leak check inside
+    the supplied async session rather than the global sync `DBSession`.
+    """
+    grouped_collection = defaultdict(list)
+    for row in collection:
+        grouped_collection[type(row)].append(row)
+
+    for record_cls, collection in grouped_collection.items():
+        collection_ids = {record.id for record in collection}
+
+        # `cls.select(...)` returns a 2.0-style Select; `.subquery()` is
+        # statement-level (no I/O) and so works under either dialect.
+        accessible_row_ids_sq = record_cls.select(
+            accessor, mode=mode, columns=[record_cls.id]
+        ).subquery()
+
+        result = await session.scalars(
+            sa.select(record_cls.id)
+            .outerjoin(
+                accessible_row_ids_sq, record_cls.id == accessible_row_ids_sq.c.id
+            )
+            .where(record_cls.id.in_(collection_ids))
+            .where(accessible_row_ids_sq.c.id.is_(None))
+        )
+        inaccessible_row_ids = set(result.all())
+
+        if inaccessible_row_ids:
+            handle_inaccessible(mode, inaccessible_row_ids, record_cls, accessor)
+
+
 # SQLA1.4 fix to return SQLA1.3-style aliased entity
 def safe_aliased(entity):
     return sa.orm.aliased(sa.inspect(entity).mapper)
@@ -217,10 +316,10 @@ def handle_inaccessible(mode, row_ids, row_type, accessor):
     raise AccessError(err_msg)
 
 
-# https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#psycopg2-fast-execution-helpers
-# executemany_values_page_size arguments control how many parameter sets
-# should be represented in each execution of an INSERT
-# 50000 was chosen based on recommendations in the docs and on profiling tests
+# Controls how many parameter sets are sent per INSERT under SA 2.0's
+# `insertmanyvalues` machinery. 50000 was chosen based on recommendations
+# in the docs and profiling tests for psycopg2; psycopg3 honors the same
+# dialect-level option.
 EXECUTEMANY_PAGESIZE = 50000
 
 
@@ -256,8 +355,10 @@ def init_db(
         Default 3600.
 
     """
-    url = "postgresql://{}:{}@{}:{}/{}"
-    url = url.format(user, password or "", host or "", port or "", database)
+    # Unified on psycopg v3 for both sync and async paths.
+    url = "postgresql+psycopg://{}:{}@{}:{}/{}".format(
+        user, password or "", host or "", port or "", database
+    )
 
     default_engine_args = {
         "pool_size": 5,
@@ -266,8 +367,6 @@ def init_db(
     }
     conn = sa.create_engine(
         url,
-        client_encoding="utf8",
-        executemany_mode="values_plus_batch",
         insertmanyvalues_page_size=EXECUTEMANY_PAGESIZE,
         echo=log_database,
         echo_pool=log_database_pool,
@@ -277,6 +376,27 @@ def init_db(
     DBSession.configure(bind=conn, autoflush=autoflush, future=True)
     Base.metadata.bind = conn
 
+    global async_engine, async_session_factory, async_plain_session_factory
+    async_engine = create_async_engine(
+        url,
+        # Keep bulk-INSERT batching identical to the sync engine above.
+        insertmanyvalues_page_size=EXECUTEMANY_PAGESIZE,
+        echo=log_database,
+        echo_pool=log_database_pool,
+        **{**default_engine_args, **engine_args},
+    )
+    async_session_factory = async_sessionmaker(
+        bind=async_engine,
+        class_=_AsyncVerifiedSession,
+        autoflush=autoflush,
+        expire_on_commit=False,
+    )
+    async_plain_session_factory = async_sessionmaker(
+        bind=async_engine,
+        autoflush=autoflush,
+        expire_on_commit=False,
+    )
+
     return conn
 
 
diff --git a/doc/setup.md b/doc/setup.md
index ce2733b7..f9b15ddd 100644
--- a/doc/setup.md
+++ b/doc/setup.md
@@ -40,13 +40,10 @@ See [below](#configuration) for more information on modifying the baselayer conf
 - Using `apt-get`:
   `sudo apt-get install supervisor postgresql libpq-dev nodejs`
 
-  If you want to use [brotli compression](https://en.wikipedia.org/wiki/Brotli) with NGINX (better compression rates for the frontend), you have to install NGINX and the brotli module from another source with:
+  To install NGINX with [brotli compression](https://en.wikipedia.org/wiki/Brotli) support (better compression rates for the frontend), use the brotli modules from Ubuntu's `universe` component (available in 24.04 and later):
 
   ```
-  sudo apt remove -y nginx nginx-common nginx-core
-  sudo add-apt-repository ppa:ondrej/nginx-mainline -y
-  sudo apt update -y
-  sudo apt install -y nginx libnginx-mod-brotli
+  sudo apt-get install -y nginx libnginx-mod-http-brotli-static libnginx-mod-http-brotli-filter
   ```
 
   Otherwise, you can install NGINX normally with `sudo apt-get install nginx`.
diff --git a/pyproject.toml b/pyproject.toml
index b36baecd..8e317e20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,8 @@ dependencies = [
     "numpy>=1.21.4",
     "packaging>=23.0",
     "phonenumbers>=8.12.15",
-    "psycopg2-binary>=2.8.6",
+    "psycopg[binary]>=3.2.0",
+    "greenlet>=3.0.0",
     "pyjwt>=2.0.1",
     "python-dateutil>=2.8.1",
     "python-slugify>=4.0.1",