Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 80 additions & 11 deletions app/access.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
import functools
import inspect

import sqlalchemy as sa
import tornado.web
from sqlalchemy.orm import joinedload

from baselayer.app.custom_exceptions import AccessError # noqa: F401
from baselayer.app.models import DBSession, Role, Token, User # noqa: F401
from baselayer.app.models import ( # noqa: F401
DBSession,
Role,
Token,
User,
async_plain_session_factory,
)


def _token_select_stmt(token_id):
    """Build the SELECT used to look up a single token by its id.

    The statement joined-loads ``Token.created_by`` and, on that user, the
    ``acls`` and ``roles`` relationships. It is only constructed here; the
    caller decides which session (sync or async) executes it.

    Parameters
    ----------
    token_id : str
        Value compared against ``Token.id``.

    Returns
    -------
    sqlalchemy.sql.Select
        The un-executed statement.
    """
    # Loader options for the token's owner and the owner's ACLs/roles.
    owner_loader = joinedload(Token.created_by).options(
        joinedload(User.acls),
        joinedload(User.roles),
    )
    stmt = sa.select(Token).options(owner_loader)
    return stmt.where(Token.id == token_id)


def auth_or_token(method):
Expand All @@ -19,24 +39,58 @@ def auth_or_token(method):

$ curl -v -H "Authorization: token 123efghj" http://localhost:5000/api/endpoint

If `method` is a coroutine function, the token lookup runs against the
async DB engine; otherwise the original sync path is used. The cookie
auth path delegates to `tornado.web.authenticated` in both cases.
"""

if inspect.iscoroutinefunction(method):

@functools.wraps(method)
async def async_wrapper(self, *args, **kwargs):
    """Authenticate via token header or cookie, then await `method`.

    If the request carries an ``Authorization: token <id>`` header, the
    token is looked up through the async plain session factory and, when
    found, becomes ``self.current_user``. Otherwise the already-populated
    ``self.current_user`` (cookie auth) is checked and the call is handed
    to ``tornado.web.authenticated``.

    Raises
    ------
    tornado.web.HTTPError
        401 if the token is unknown, or if there is neither a token
        header nor a current user; 403 if the account is not active.
    """
    token_prefix = "token "
    token_header = self.request.headers.get("Authorization", None)
    if token_header is not None and token_header.startswith(token_prefix):
        # Strip only the leading scheme. `str.replace("token", "")` would
        # also delete that substring from inside the credential itself.
        token_id = token_header[len(token_prefix):].strip()
        # Use the import via models module so monkeypatching/late
        # init by init_db() is reflected here.
        from baselayer.app import models as _models

        async with _models.async_plain_session_factory() as session:
            result = await session.scalars(_token_select_stmt(token_id))
            token = result.first()
            if token is None:
                raise tornado.web.HTTPError(401)
            self.current_user = token
            if not token.created_by.is_active():
                raise tornado.web.HTTPError(403, "User account expired")
        return await method(self, *args, **kwargs)

    # No usable token header: fall back to the cookie-authenticated user.
    if self.current_user is None:
        raise tornado.web.HTTPError(
            401,
            'Credentials malformed; expected form "Authorization: token abc123"',
        )
    if not self.current_user.is_active():
        raise tornado.web.HTTPError(403, "User account expired")

    # tornado.web.authenticated returns whatever the method
    # returns; for an async method that's a coroutine to await.
    result = tornado.web.authenticated(method)(self, *args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result

async_wrapper.__authenticated__ = True
return async_wrapper

@functools.wraps(method)
def wrapper(self, *args, **kwargs):
token_header = self.request.headers.get("Authorization", None)
if token_header is not None and token_header.startswith("token "):
token_id = token_header.replace("token", "").strip()
with DBSession() as session:
token = session.scalars(
sa.select(Token)
.options(
joinedload(Token.created_by).options(
joinedload(User.acls),
joinedload(User.roles),
)
)
.where(Token.id == token_id)
).first()
token = session.scalars(_token_select_stmt(token_id)).first()
if token is not None:
self.current_user = token
if not token.created_by.is_active():
Expand Down Expand Up @@ -65,6 +119,21 @@ def permissions(acl_list):
"""

def check_acls(method):
if inspect.iscoroutinefunction(method):

@auth_or_token
@functools.wraps(method)
async def async_wrapper(self, *args, **kwargs):
    """Await `method` only if the current user holds the required ACLs.

    Access is granted when every entry of ``acl_list`` is among the
    user's permissions, or when the user has the "System admin"
    permission. Otherwise an HTTP 401 is raised.
    """
    lacks_required = not set(acl_list).issubset(self.current_user.permissions)
    if lacks_required and "System admin" not in self.current_user.permissions:
        raise tornado.web.HTTPError(401)
    return await method(self, *args, **kwargs)

async_wrapper.__permissions__ = acl_list
return async_wrapper

@auth_or_token
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
Expand Down
57 changes: 53 additions & 4 deletions app/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import uuid
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from json.decoder import JSONDecodeError

# The Python Social Auth base handler gives us:
Expand All @@ -22,7 +22,14 @@
from ..env import load_env
from ..flow import Flow
from ..json_util import to_json
from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
from ..models import (
AsyncVerifiedSession,
DBSession,
User,
VerifiedSession,
bulk_verify,
session_context_id,
)

env, cfg = load_env()
log = make_log("basehandler")
Expand Down Expand Up @@ -156,6 +163,26 @@ def Session(self):
session.bind = DBSession.session_factory.kw["bind"]
yield session

@asynccontextmanager
async def AsyncSession(self):
    """Async counterpart of `Session()`.

    Opens an `AsyncVerifiedSession` for the handler's current user, merges
    that user into the session, stores the merged object on
    ``session.user_or_token`` and yields the session.

    Usage:
        async with self.AsyncSession() as session:
            result = await session.scalars(MyModel.select(session.user_or_token))
            ...
            await session.commit()
    """
    async with AsyncVerifiedSession(self.current_user) as session:
        # `current_user` was loaded by the auth lookup in a different
        # session and is detached here. `load=False` attaches it without
        # issuing SQL.
        # NOTE(review): merge(load=False) expects the instance to carry no
        # pending changes -- confirm nothing mutates current_user earlier.
        session.user_or_token = await session.merge(self.current_user, load=False)
        yield session

def verify_permissions(self):
"""Check that the current user has permission to create, read,
update, or delete rows that are present in the session. If not,
Expand Down Expand Up @@ -394,10 +421,32 @@ def push_notification(self, note, notification_type="info"):
payload={"note": note, "type": notification_type},
)

def get_query_argument(self, value, default=NoValue, type=None, **kwargs):
    """Get a query-string argument with optional type coercion.

    Parameters
    ----------
    value : str
        Name of the query parameter.
    default : any, optional
        Value to return when the parameter is absent. If it is a
        ``bool``, the argument is interpreted as a boolean flag
        ("true", "yes", "t" and "1" are truthy, case-insensitively).
    type : callable, optional
        If provided (e.g. ``float`` / ``int``), the returned string is
        passed through this callable. Required for parameters that go
        into SQL comparisons against non-text columns -- psycopg v3
        binds Python strings as VARCHAR, so the database refuses to
        compare e.g. ``double precision`` to ``character varying``.
        If the value can't be coerced, ``default`` is returned (``None``
        when no default was supplied).
    **kwargs
        Forwarded to the parent class's ``get_query_argument``.

    Returns
    -------
    any
        The raw string, a ``bool`` for boolean defaults, the coerced
        value, or the default.
    """
    # NoValue is a sentinel; compare by identity so that defaults with
    # unusual equality semantics cannot be mistaken for "not given".
    if default is not NoValue:
        kwargs["default"] = default
    arg = super().get_query_argument(value, **kwargs)
    default_val = kwargs.get("default", None)
    if isinstance(default_val, bool):
        arg = str(arg).lower() in ["true", "yes", "t", "1"]
    elif type is not None and arg is not None and arg is not default_val:
        # Only coerce values that actually came from the request; the
        # default is handed back untouched.
        try:
            arg = type(arg)
        except (TypeError, ValueError):
            arg = default_val
    return arg
134 changes: 126 additions & 8 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
import warnings
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime
from hashlib import md5

Expand All @@ -13,6 +14,11 @@
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.asyncio import (
AsyncSession as SAAsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import (
declarative_base,
Expand Down Expand Up @@ -185,6 +191,99 @@ def bulk_verify(mode, collection, accessor):
handle_inaccessible(mode, inaccessible_row_ids, record_cls, accessor)


# --- Async DB layer (parallel to the sync layer above) --------------------
# Configured by init_db(). Sync engine/session remain authoritative for the
# rest of the codebase; async path is opt-in per handler.
async_engine = None
# Verified factory (RLS check on commit), parallel to VerifiedSession.
async_session_factory = None
# Plain factory (no RLS check), parallel to DBSession.
async_plain_session_factory = None


class _AsyncVerifiedSession(SAAsyncSession):
    """Async counterpart of `_VerifiedSession`.

    `commit()` first runs `verify()`, which checks every row tracked by the
    session against the accessor in `user_or_token` via
    `async_bulk_verify`, and only then commits.

    `user_or_token` defaults to None at class level and is meant to be
    assigned on the instance after it is created; apart from that and the
    overridden `commit()`, this is a plain SQLAlchemy `AsyncSession`.
    """

    # Accessor (user or token) the access checks are evaluated for.
    user_or_token = None

    async def verify(self):
        """Check read/update/delete access, flush, then check create access."""
        created = list(self.new)
        modified = [obj for obj in self.dirty if self.is_modified(obj)]
        removed = list(self.deleted)

        # Everything in the identity map that is not pending a write is
        # treated as having been read.
        touched = set(modified) | set(created) | set(removed)
        untouched = list(set(self.identity_map.values()) - touched)

        checks = (
            ("read", untouched),
            ("update", modified),
            ("delete", removed),
        )
        for mode, rows in checks:
            await async_bulk_verify(self, mode, rows, self.user_or_token)

        # New rows are verified only after the flush has been emitted.
        await self.flush()
        await async_bulk_verify(self, "create", created, self.user_or_token)

    async def commit(self):
        """Verify access for all tracked rows, then commit."""
        await self.verify()
        await super().commit()


@asynccontextmanager
async def AsyncVerifiedSession(user_or_token):
    """Async equivalent of `VerifiedSession()`.

    Creates a session from the module-level `async_session_factory`, sets
    its `user_or_token`, yields it, and closes it when the block exits
    (normally or with an exception).

    Raises
    ------
    RuntimeError
        If `async_session_factory` has not been configured yet.
    """
    factory = async_session_factory
    if factory is None:
        raise RuntimeError(
            "Async DB session not initialized. init_db() must run first."
        )

    verified = factory()
    verified.user_or_token = user_or_token
    try:
        yield verified
    finally:
        await verified.close()


async def async_bulk_verify(session, mode, collection, accessor):
    """Async counterpart of `bulk_verify`.

    Groups the rows of `collection` by class and, for each class, queries
    through `session` for ids in the group that are absent from the set of
    rows `accessor` may access in `mode`. Any such ids are passed to
    `handle_inaccessible`.

    Parameters
    ----------
    session
        Async session the check queries are executed on.
    mode : str
        Access mode forwarded to ``record_cls.select``.
    collection : iterable
        Rows to check; each must expose an ``id`` attribute.
    accessor
        User or token the access check is evaluated for.
    """
    rows_by_class = defaultdict(list)
    for row in collection:
        rows_by_class[type(row)].append(row)

    for record_cls, rows in rows_by_class.items():
        wanted_ids = {row.id for row in rows}

        # Building the subquery is statement-level only; no I/O happens
        # until the outer select is executed below.
        accessible_sq = record_cls.select(
            accessor, mode=mode, columns=[record_cls.id]
        ).subquery()

        # Anti-join: keep ids from this group with no accessible match.
        leak_query = (
            sa.select(record_cls.id)
            .outerjoin(accessible_sq, record_cls.id == accessible_sq.c.id)
            .where(record_cls.id.in_(wanted_ids))
            .where(accessible_sq.c.id.is_(None))
        )
        found = await session.scalars(leak_query)
        inaccessible_row_ids = set(found.all())

        if inaccessible_row_ids:
            handle_inaccessible(mode, inaccessible_row_ids, record_cls, accessor)


# SQLA1.4 fix to return SQLA1.3-style aliased entity
def safe_aliased(entity):
    """Return ``sa.orm.aliased`` applied to the mapper behind `entity`."""
    mapper = sa.inspect(entity).mapper
    return sa.orm.aliased(mapper)
Expand Down Expand Up @@ -217,10 +316,10 @@ def handle_inaccessible(mode, row_ids, row_type, accessor):
raise AccessError(err_msg)


# https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#psycopg2-fast-execution-helpers
# executemany_values_page_size arguments control how many parameter sets
# should be represented in each execution of an INSERT
# 50000 was chosen based on recommendations in the docs and on profiling tests
# Controls how many parameter sets are sent per INSERT under SA 2.0's
# `insertmanyvalues` machinery. 50000 was chosen based on recommendations
# in the docs and profiling tests for psycopg2; psycopg3 honors the same
# dialect-level option.
EXECUTEMANY_PAGESIZE = 50000


Expand Down Expand Up @@ -256,8 +355,10 @@ def init_db(
Default 3600.

"""
url = "postgresql://{}:{}@{}:{}/{}"
url = url.format(user, password or "", host or "", port or "", database)
# Unified on psycopg v3 for both sync and async paths.
url = "postgresql+psycopg://{}:{}@{}:{}/{}".format(
user, password or "", host or "", port or "", database
)

default_engine_args = {
"pool_size": 5,
Expand All @@ -266,8 +367,6 @@ def init_db(
}
conn = sa.create_engine(
url,
client_encoding="utf8",
executemany_mode="values_plus_batch",
insertmanyvalues_page_size=EXECUTEMANY_PAGESIZE,
echo=log_database,
echo_pool=log_database_pool,
Expand All @@ -277,6 +376,25 @@ def init_db(
DBSession.configure(bind=conn, autoflush=autoflush, future=True)
Base.metadata.bind = conn

global async_engine, async_session_factory, async_plain_session_factory
async_engine = create_async_engine(
url,
echo=log_database,
echo_pool=log_database_pool,
**{**default_engine_args, **engine_args},
)
async_session_factory = async_sessionmaker(
bind=async_engine,
class_=_AsyncVerifiedSession,
autoflush=autoflush,
expire_on_commit=False,
)
async_plain_session_factory = async_sessionmaker(
bind=async_engine,
autoflush=autoflush,
expire_on_commit=False,
)

return conn


Expand Down
Loading
Loading