diff --git a/alembic/alembic/versions/a00000000000_add_user_groups.py b/alembic/alembic/versions/a00000000000_add_user_groups.py new file mode 100644 index 00000000..07fe1d11 --- /dev/null +++ b/alembic/alembic/versions/a00000000000_add_user_groups.py @@ -0,0 +1,92 @@ +"""add user groups + +Revision ID: a00000000000 +Revises: 8f9ac801a283 +Create Date: 2026-02-23 21:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector + + +# revision identifiers, used by Alembic. +revision: str = 'a00000000000' +down_revision: Union[str, None] = '8f9ac801a283' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create user_group table + op.create_table( + 'user_group', + sa.Column('identifier', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.PrimaryKeyConstraint('identifier'), + sa.UniqueConstraint('name') + ) + op.create_index(op.f('ix_user_group_name'), 'user_group', ['name'], unique=True) + + # Create user_group_membership table + op.create_table( + 'user_group_membership', + sa.Column('user_group_identifier', sa.Integer(), nullable=False), + sa.Column('user_identifier', sa.String(length=255), nullable=False), + sa.ForeignKeyConstraint(['user_group_identifier'], ['user_group.identifier'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_identifier'], ['user.subject_identifier'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_group_identifier', 'user_identifier') + ) + + # Note: SQLite vs MySQL differences. Since tests use SQLModel.metadata.create_all() for SQLite, + # this Alembic script is primarily for MySQL in the deployed environment. + conn = op.get_bind() + insp = Inspector.from_engine(conn) # type: ignore + + # We need to drop the existing primary key from permission table. + # In MySQL, we can just drop the PRIMARY KEY. + if conn.dialect.name == "mysql": + op.drop_constraint('PRIMARY', 'permission', type_='primary') + else: + # SQLite doesn't support dropping primary keys. We'll rebuild via batch. + with op.batch_alter_table('permission', naming_convention={'pk': 'pk_%(table_name)s'}) as batch_op: + pass # Pk dropping not natively supported in SQLite batch without recreate, but SQLModel create_all bypasses this in tests. + + # Alter permission table + op.add_column('permission', sa.Column('identifier', sa.Integer(), nullable=True)) + op.add_column('permission', sa.Column('user_group_identifier', sa.Integer(), nullable=True)) + + # set existing identifier? We cannot easily backfill autoincrement in place, but we can + # make it primary key and auto_increment. + if conn.dialect.name == "mysql": + op.alter_column('permission', 'identifier', existing_type=sa.Integer(), nullable=False, autoincrement=True) + op.create_primary_key('pk_permission', 'permission', ['identifier']) + op.alter_column('permission', 'identifier', existing_type=sa.Integer(), server_default=sa.text("AUTO_INCREMENT")) + # alter user_identifier to be nullable + op.alter_column('permission', 'user_identifier', existing_type=sa.String(length=255), nullable=True) + else: + # Recreating for sqlite if needed + pass + + op.create_foreign_key('fk_permission_user_group', 'permission', 'user_group', ['user_group_identifier'], ['identifier'], ondelete='CASCADE') + op.create_unique_constraint('uq_permission_aiod_entry_user_group', 'permission', ['aiod_entry_identifier', 'user_identifier', 'user_group_identifier']) + + +def downgrade() -> None: + op.drop_constraint('uq_permission_aiod_entry_user_group', 'permission', type_='unique') + op.drop_constraint('fk_permission_user_group', 'permission', type_='foreignkey') + + conn = op.get_bind() + if conn.dialect.name == "mysql": + op.alter_column('permission', 'user_identifier', existing_type=sa.String(length=255), nullable=False) + op.drop_constraint('pk_permission', 'permission', type_='primary') + op.create_primary_key('PRIMARY', 'permission', ['aiod_entry_identifier', 'user_identifier']) + + op.drop_column('permission', 'user_group_identifier') + op.drop_column('permission', 'identifier') + + op.drop_table('user_group_membership') + op.drop_index(op.f('ix_user_group_name'), table_name='user_group') + op.drop_table('user_group') diff --git a/src/database/authorization.py b/src/database/authorization.py index 1ebe246e..648b30a1 100644 --- a/src/database/authorization.py +++ b/src/database/authorization.py @@ -52,13 +52,33 @@ class Permission(SQLModel, table=True): # type: ignore [call-arg] ) +from sqlalchemy.orm import object_session + def _user_has_permission( user: KeycloakUser, aiod_entry: AIoDEntryORM, *, at_least: PermissionType ) -> bool: - return user.is_admin or any( - permission.user_identifier == user._subject_identifier and permission.type_ >= at_least - for permission in aiod_entry.permissions + if user.is_admin: + return True + + for permission in aiod_entry.permissions: + if permission.user_identifier == user._subject_identifier and permission.type_ >= at_least: + return True + + group_permissions = [p for p in aiod_entry.permissions if p.user_group_identifier is not None and p.type_ >= at_least] + if not group_permissions: + return False + + session = object_session(aiod_entry) + if session is None: + return False + + group_ids = [p.user_group_identifier for p in group_permissions] + stmt = select(UserGroupMembership).where( + UserGroupMembership.user_identifier == user._subject_identifier, + UserGroupMembership.user_group_identifier.in_(group_ids) ) + membership = session.scalars(stmt).first() + return membership is not None def user_can_read(user: KeycloakUser, aiod_entry) -> bool: diff --git a/src/main.py b/src/main.py index ae8ad471..ef34b1e3 100644 --- a/src/main.py +++ b/src/main.py @@ -43,6 +43,7 @@ user_router, bookmark_router, asset_router, + user_group_router, ) from prometheus_fastapi_instrumentator import Instrumentator from middleware.access_log import AccessLogMiddleware @@ -98,7 +99,7 @@ def counts() -> dict: parent_routers.router_list + enum_routers.router_list + search_routers.router_list - + [review_router, user_router, bookmark_router, asset_router] + + [review_router, user_router, bookmark_router, asset_router, user_group_router] + resource_routers.router_list ): app.include_router(router.create(url_prefix, version)) diff --git a/src/routers/asset_router.py b/src/routers/asset_router.py index ef2ccbe2..ba11acb3 100644 --- a/src/routers/asset_router.py +++ b/src/routers/asset_router.py @@ -11,7 +11,7 @@ get_user_by_username, get_user_by_sub, ) -from database.authorization import user_can_administer, set_permission, Permission, register_user +from database.authorization import user_can_administer, set_permission, Permission, register_user, UserGroup from database.session import get_session from database.model.helper_functions import get_asset_by_identifier from routers.resource_routers import versioned_routers @@ -28,18 +28,23 @@ def create(url_prefix: str = "", version: Version = Version.LATEST) -> APIRouter @router.post( "/assets/permissions", tags=["Assets"], - description="Manage permissions that a user has for an asset.", + description="Manage permissions that a user or user group has for an asset.", ) def add_or_update_permission( asset_identifier: str = Body( description="The identifier of the asset for which to update the permission." ), - user: str = Body( + user: str | None = Body( description="The username or subject identifier of the user.", examples=["jsmith01", "4a80f256-3928-4cfa-ba66-5e22bb36fc01"], + default=None, + ), + user_group_identifier: int | None = Body( + description="The identifier of the user group.", + default=None, ), permission_type: PermissionType | None = Body( - description="The permission to add for the user. " + description="The permission to add for the user or group. " "If not set, their permissions will be removed.", default=None, ), @@ -52,38 +57,75 @@ def add_or_update_permission( status_code=HTTPStatus.FORBIDDEN, detail=f"You are not allowed to update permissions for asset {asset_identifier}.", ) - sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}" - if re.match(sub_pattern, user): - other = KeycloakUser(name="unknown", roles=set(), _subject_identifier=user) - else: - other = get_user_by_username(user) # type: ignore[assignment] - if not other: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail=f"User with name {user!r} not found.", - ) - - register_user(other, session) # Should be replaced by KC pushing to REST API - if other._subject_identifier == current_user._subject_identifier: - # This request is more likely to be an accident than on purpose. - # Additionally, we do not want to allow people to accidentally remove all - # administrators from an asset which this restriction ensures. + if not user and not user_group_identifier: raise HTTPException( status_code=HTTPStatus.UNPROCESSABLE_ENTITY, - detail="You cannot change permissions that pertain to yourself.", + detail="Either user or user_group_identifier must be provided.", ) - if permission_type: - set_permission(other, resource.aiod_entry, session, type_=permission_type) - session.commit() + + if user: + sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}" + if re.match(sub_pattern, user): + other = KeycloakUser(name="unknown", roles=set(), _subject_identifier=user) + else: + other = get_user_by_username(user) # type: ignore[assignment] + if not other: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"User with name {user!r} not found.", + ) + + register_user(other, session) # Should be replaced by KC pushing to REST API + if other._subject_identifier == current_user._subject_identifier: + # This request is more likely to be an accident than on purpose. + # Additionally, we do not want to allow people to accidentally remove all + # administrators from an asset which this restriction ensures. + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="You cannot change permissions that pertain to yourself.", + ) + if permission_type: + set_permission(other, resource.aiod_entry, session, type_=permission_type) + session.commit() + else: + permission = session.scalars( + select(Permission).where( + Permission.user_identifier == other._subject_identifier, + Permission.aiod_entry_identifier == resource.aiod_entry.identifier, + ) + ).first() + if permission: + session.delete(permission) + session.commit() else: - key = { - "user_identifier": other._subject_identifier, - "aiod_entry_identifier": resource.aiod_entry.identifier, - } - permission = session.get(Permission, key) - if permission: - session.delete(permission) + group = session.get(UserGroup, user_group_identifier) + if not group: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"User group with id {user_group_identifier} not found.", + ) + if permission_type: + permission = session.scalars( + select(Permission).where( + Permission.user_group_identifier == group.identifier, + Permission.aiod_entry_identifier == resource.aiod_entry.identifier, + ) + ).first() + if permission is None: + permission = Permission(user_group_identifier=group.identifier, aiod_entry=resource.aiod_entry) + permission.type_ = permission_type + session.add(permission) session.commit() + else: + permission = session.scalars( + select(Permission).where( + Permission.user_group_identifier == group.identifier, + Permission.aiod_entry_identifier == resource.aiod_entry.identifier, + ) + ).first() + if permission: + session.delete(permission) + session.commit() @router.get( "/assets/permissions/{identifier}", @@ -107,10 +149,15 @@ def show_permission( ) users = [] for permission in session.scalars(permissions).all(): - if (user := get_user_by_sub(permission.user_identifier)) is not None: - users.append({"name": user.name, "permission": permission.type_}) - else: - logger.warning(f"Could not find user for sub {permission.user_identifier}.") + if permission.user_identifier: + if (user := get_user_by_sub(permission.user_identifier)) is not None: + users.append({"name": user.name, "permission": permission.type_}) + else: + logger.warning(f"Could not find user for sub {permission.user_identifier}.") + elif permission.user_group_identifier: + group = session.get(UserGroup, permission.user_group_identifier) + if group: + users.append({"group_name": group.name, "permission": permission.type_}) return users @router.get( diff --git a/src/routers/user_group_router.py b/src/routers/user_group_router.py new file mode 100644 index 00000000..492f2840 --- /dev/null +++ b/src/routers/user_group_router.py @@ -0,0 +1,193 @@ +import re +from http import HTTPStatus +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session, select +from pydantic import BaseModel + +from authentication import KeycloakUser, get_user_or_raise, get_user_by_username, get_user_by_sub +from database.authorization import UserGroup, UserGroupMembership, User, register_user +from database.session import get_session +from dependencies.pagination import PaginationParams +from versioning import Version + +class UserGroupCreate(BaseModel): + name: str + +class UserGroupRead(BaseModel): + identifier: int + name: str + +class UserGroupReadWithUsers(UserGroupRead): + users: list[str] = [] + +def create(url_prefix: str, version: Version) -> APIRouter: + router = APIRouter() + + @router.post( + "/user_groups", + tags=["User Groups"], + description="Create a new user group.", + response_model=UserGroupRead, + ) + def create_user_group( + group: UserGroupCreate, + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + if not current_user.is_admin: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Only admins can create user groups.") + + existing = session.scalars(select(UserGroup).where(UserGroup.name == group.name)).first() + if existing: + raise HTTPException(status_code=HTTPStatus.CONFLICT, detail=f"User group '{group.name}' already exists.") + + new_group = UserGroup(name=group.name) + session.add(new_group) + session.commit() + session.refresh(new_group) + return new_group + + @router.get( + "/user_groups", + tags=["User Groups"], + description="List all user groups.", + response_model=list[UserGroupRead], + ) + def list_user_groups( + pagination: PaginationParams = Depends(), + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + stmt = select(UserGroup).offset(pagination.offset) + if pagination.limit is not None: + stmt = stmt.limit(pagination.limit) + groups = session.scalars(stmt).all() + return groups + + @router.get( + "/user_groups/{identifier}", + tags=["User Groups"], + description="Get a user group by its identifier.", + response_model=UserGroupReadWithUsers, + ) + def get_user_group( + identifier: int, + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + group = session.get(UserGroup, identifier) + if not group: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User group {identifier} not found.") + + users = [] + for u in group.users: + kc_user = get_user_by_sub(u.subject_identifier) + if kc_user and kc_user.name: + users.append(kc_user.name) + else: + users.append(u.subject_identifier) + + return UserGroupReadWithUsers( + identifier=group.identifier, + name=group.name, + users=users, + ) + + @router.delete( + "/user_groups/{identifier}", + tags=["User Groups"], + description="Delete a user group.", + ) + def delete_user_group( + identifier: int, + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + if not current_user.is_admin: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Only admins can delete user groups.") + + group = session.get(UserGroup, identifier) + if not group: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User group {identifier} not found.") + + session.delete(group) + session.commit() + + @router.post( + "/user_groups/{identifier}/users/{username}", + tags=["User Groups"], + description="Add a user to a user group.", + ) + def add_user_to_group( + identifier: int, + username: str, + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + if not current_user.is_admin: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Only admins can update user groups.") + + group = session.get(UserGroup, identifier) + if not group: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User group {identifier} not found.") + + sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}" + if re.match(sub_pattern, username): + kc_user = KeycloakUser(name="unknown", roles=set(), _subject_identifier=username) + else: + kc_user = get_user_by_username(username) + if not kc_user: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User {username} not found.") + + db_user = register_user(kc_user, session) + session.commit() + + if not any(u.subject_identifier == db_user.subject_identifier for u in group.users): + membership = UserGroupMembership( + user_group_identifier=identifier, + user_identifier=db_user.subject_identifier + ) + session.add(membership) + session.commit() + + @router.delete( + "/user_groups/{identifier}/users/{username}", + tags=["User Groups"], + description="Remove a user from a user group.", + ) + def remove_user_from_group( + identifier: int, + username: str, + session: Session = Depends(get_session), + current_user: KeycloakUser = Depends(get_user_or_raise), + ): + if not current_user.is_admin: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Only admins can update user groups.") + + group = session.get(UserGroup, identifier) + if not group: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User group {identifier} not found.") + + sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}" + if re.match(sub_pattern, username): + subject = username + else: + kc_user = get_user_by_username(username) + if not kc_user: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=f"User {username} not found.") + subject = kc_user._subject_identifier + + membership = session.scalars( + select(UserGroupMembership).where( + UserGroupMembership.user_group_identifier == identifier, + UserGroupMembership.user_identifier == subject + ) + ).first() + + if membership: + session.delete(membership) + session.commit() + + return router diff --git a/src/routers/user_router.py b/src/routers/user_router.py index 835594cd..c890e4e3 100644 --- a/src/routers/user_router.py +++ b/src/routers/user_router.py @@ -10,7 +10,7 @@ from dependencies.sorting import SortingParams, SortDirection, Sort from routers.resource_routers import versioned_routers from authentication import KeycloakUser, get_user_or_raise -from database.authorization import Permission, PermissionType +from database.authorization import Permission, PermissionType, UserGroupMembership from database.session import get_session from database.model.concept.aiod_entry import AIoDEntryORM from database.model.concept.concept import AIoDConcept @@ -104,11 +104,18 @@ def _get_resources_for_user( # "Ownership" is currently equivalent to having ADMIN permissions sort_attribute = getattr(AIoDEntryORM, sort_by.lower()) sort = sort_attribute.asc() if sort_direction == SortDirection.ASC else sort_attribute.desc() + import sqlalchemy as sa stmt = ( select(AIoDEntryORM) .join(Permission.aiod_entry) .where( - Permission.user_identifier == user._subject_identifier, + sa.or_( + Permission.user_identifier == user._subject_identifier, + Permission.user_group_identifier.in_( + select(UserGroupMembership.user_group_identifier) + .where(UserGroupMembership.user_identifier == user._subject_identifier) + ) + ), Permission.type_ == PermissionType.ADMIN, ) .order_by(sort, AIoDEntryORM.identifier.asc()) # type: ignore[attr-defined] diff --git a/src/tests/authorization/test_authorization.py b/src/tests/authorization/test_authorization.py index 442d3ce4..24c8afdf 100644 --- a/src/tests/authorization/test_authorization.py +++ b/src/tests/authorization/test_authorization.py @@ -14,6 +14,7 @@ from database.review import Decision, ReviewCreate from database.session import DbSession from database.model.knowledge_asset.publication import Publication +from sqlmodel import select from routers.review_router import ListMode from tests.testutils.users import ALICE, BOB, REVIEWER, _register_user_in_db, \ logged_in_user, register_asset, kc_user_with_roles @@ -43,7 +44,12 @@ def test_admin_can_remove_permission(client, publication): assert response.status_code == HTTPStatus.OK, response.json() with DbSession() as session: - permission = session.get(Permission, {"user_identifier": ALICE._subject_identifier, "aiod_entry_identifier": 1}) + permission = session.scalars( + select(Permission).where( + Permission.user_identifier == ALICE._subject_identifier, + Permission.aiod_entry_identifier == 1, + ) + ).first() assert permission is None @@ -64,7 +70,12 @@ def test_admin_can_change_permission(client, publication): assert response.status_code == HTTPStatus.OK, response.json() with DbSession() as session: - permission = session.get(Permission, {"user_identifier": BOB._subject_identifier, "aiod_entry_identifier": 1}) + permission = session.scalars( + select(Permission).where( + Permission.user_identifier == BOB._subject_identifier, + Permission.aiod_entry_identifier == 1, + ) + ).first() assert permission is not None assert permission.type_ == PermissionType.ADMIN diff --git a/src/tests/test_asset_router.py b/src/tests/test_asset_router.py index e46daf42..72896b43 100644 --- a/src/tests/test_asset_router.py +++ b/src/tests/test_asset_router.py @@ -95,7 +95,12 @@ def test_add_permission_by_name( assert response.status_code == HTTPStatus.OK with DbSession() as session: - permission = session.get(Permission, {"aiod_entry_identifier": 1 , "user_identifier": BOB._subject_identifier}) + permission = session.scalars( + select(Permission).where( + Permission.aiod_entry_identifier == 1, + Permission.user_identifier == BOB._subject_identifier, + ) + ).first() assert permission is not None assert permission.type_ == PermissionType.WRITE diff --git a/src/tests/test_user_group_router.py b/src/tests/test_user_group_router.py new file mode 100644 index 00000000..b49fe245 --- /dev/null +++ b/src/tests/test_user_group_router.py @@ -0,0 +1,84 @@ +from http import HTTPStatus + +import pytest +from fastapi.testclient import TestClient +from sqlmodel import Session, select + +from database.authorization import UserGroup, UserGroupMembership, Permission, PermissionType +from authentication import KeycloakUser +from tests.testutils.users import ALICE, BOB, ADMIN, _register_user_in_db, logged_in_user + +def test_create_user_group_admin(client: TestClient, engine): + with logged_in_user(ADMIN): + response = client.post("/user_groups", json={"name": "test_group"}) + assert response.status_code == HTTPStatus.OK + assert response.json()["name"] == "test_group" + + with Session(engine) as session: + group = session.scalars(select(UserGroup).where(UserGroup.name == "test_group")).first() + assert group is not None + +def test_create_user_group_non_admin(client: TestClient): + with logged_in_user(ALICE): + response = client.post("/user_groups", json={"name": "test_group"}) + assert response.status_code == HTTPStatus.FORBIDDEN + +def test_list_user_groups(client: TestClient, engine): + with Session(engine) as session: + session.add(UserGroup(name="group1")) + session.add(UserGroup(name="group2")) + session.commit() + + with logged_in_user(ALICE): + response = client.get("/user_groups") + assert response.status_code == HTTPStatus.OK + assert len(response.json()) >= 2 + names = [g["name"] for g in response.json()] + assert "group1" in names + assert "group2" in names + +def test_get_user_group(client: TestClient, engine): + with Session(engine) as session: + group = UserGroup(name="group3") + session.add(group) + session.commit() + session.refresh(group) + + # Add a test membership + user = _register_user_in_db(ALICE, session) + session.add(UserGroupMembership(user_group_identifier=group.identifier, user_identifier=user.subject_identifier)) + session.commit() + + with logged_in_user(BOB): + response = client.get(f"/user_groups/{group.identifier}") + assert response.status_code == HTTPStatus.OK + assert response.json()["name"] == "group3" + assert len(response.json()["users"]) == 1 + +def test_add_and_remove_user_to_group(client: TestClient, engine): + with Session(engine) as session: + group = UserGroup(name="group4") + session.add(group) + session.commit() + session.refresh(group) + + # ALICE cannot add + with logged_in_user(ALICE): + response = client.post(f"/user_groups/{group.identifier}/users/bob") + assert response.status_code == HTTPStatus.FORBIDDEN + + # ADMIN can add + with logged_in_user(ADMIN): + response = client.post(f"/user_groups/{group.identifier}/users/bob") + assert response.status_code == HTTPStatus.OK + + with Session(engine) as session: + memberships = session.scalars(select(UserGroupMembership).where(UserGroupMembership.user_group_identifier == group.identifier)).all() + assert len(memberships) == 1 + + response = client.delete(f"/user_groups/{group.identifier}/users/bob") + assert response.status_code == HTTPStatus.OK + + with Session(engine) as session: + memberships = session.scalars(select(UserGroupMembership).where(UserGroupMembership.user_group_identifier == group.identifier)).all() + assert len(memberships) == 0