Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions alembic/alembic/versions/a00000000000_add_user_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""add user groups

Revision ID: a00000000000
Revises: 8f9ac801a283
Create Date: 2026-02-23 21:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector


# revision identifiers, used by Alembic.
revision: str = 'a00000000000'
down_revision: Union[str, None] = '8f9ac801a283'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Create user_group table
op.create_table(
'user_group',
sa.Column('identifier', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('identifier'),
sa.UniqueConstraint('name')
)
op.create_index(op.f('ix_user_group_name'), 'user_group', ['name'], unique=True)

# Create user_group_membership table
op.create_table(
'user_group_membership',
sa.Column('user_group_identifier', sa.Integer(), nullable=False),
sa.Column('user_identifier', sa.String(length=255), nullable=False),
sa.ForeignKeyConstraint(['user_group_identifier'], ['user_group.identifier'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_identifier'], ['user.subject_identifier'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('user_group_identifier', 'user_identifier')
)

# Note: SQLite vs MySQL differences. Since tests use SQLModel.metadata.create_all() for SQLite,
# this Alembic script is primarily for MySQL in the deployed environment.
conn = op.get_bind()
insp = Inspector.from_engine(conn) # type: ignore

# We need to drop the existing primary key from permission table.
# In MySQL, we can just drop the PRIMARY KEY.
if conn.dialect.name == "mysql":
op.drop_constraint('PRIMARY', 'permission', type_='primary')
else:
# SQLite doesn't support dropping primary keys. We'll rebuild via batch.
with op.batch_alter_table('permission', naming_convention={'pk': 'pk_%(table_name)s'}) as batch_op:
pass # Pk dropping not natively supported in SQLite batch without recreate, but SQLModel create_all bypasses this in tests.

# Alter permission table
op.add_column('permission', sa.Column('identifier', sa.Integer(), nullable=True))
op.add_column('permission', sa.Column('user_group_identifier', sa.Integer(), nullable=True))

# set existing identifier? We cannot easily backfill autoincrement in place, but we can
# make it primary key and auto_increment.
if conn.dialect.name == "mysql":
op.alter_column('permission', 'identifier', existing_type=sa.Integer(), nullable=False, autoincrement=True)
op.create_primary_key('pk_permission', 'permission', ['identifier'])
op.alter_column('permission', 'identifier', existing_type=sa.Integer(), server_default=sa.text("AUTO_INCREMENT"))
# alter user_identifier to be nullable
op.alter_column('permission', 'user_identifier', existing_type=sa.String(length=255), nullable=True)
else:
# Recreating for sqlite if needed
pass

op.create_foreign_key('fk_permission_user_group', 'permission', 'user_group', ['user_group_identifier'], ['identifier'], ondelete='CASCADE')
op.create_unique_constraint('uq_permission_aiod_entry_user_group', 'permission', ['aiod_entry_identifier', 'user_identifier', 'user_group_identifier'])


def downgrade() -> None:
op.drop_constraint('uq_permission_aiod_entry_user_group', 'permission', type_='unique')
op.drop_constraint('fk_permission_user_group', 'permission', type_='foreignkey')

conn = op.get_bind()
if conn.dialect.name == "mysql":
op.alter_column('permission', 'user_identifier', existing_type=sa.String(length=255), nullable=False)
op.drop_constraint('pk_permission', 'permission', type_='primary')
op.create_primary_key('PRIMARY', 'permission', ['aiod_entry_identifier', 'user_identifier'])

op.drop_column('permission', 'user_group_identifier')
op.drop_column('permission', 'identifier')

op.drop_table('user_group_membership')
op.drop_index(op.f('ix_user_group_name'), table_name='user_group')
op.drop_table('user_group')
26 changes: 23 additions & 3 deletions src/database/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,33 @@ class Permission(SQLModel, table=True): # type: ignore [call-arg]
)


from sqlalchemy.orm import object_session

def _user_has_permission(
user: KeycloakUser, aiod_entry: AIoDEntryORM, *, at_least: PermissionType
) -> bool:
return user.is_admin or any(
permission.user_identifier == user._subject_identifier and permission.type_ >= at_least
for permission in aiod_entry.permissions
if user.is_admin:
return True

for permission in aiod_entry.permissions:
if permission.user_identifier == user._subject_identifier and permission.type_ >= at_least:
return True

group_permissions = [p for p in aiod_entry.permissions if p.user_group_identifier is not None and p.type_ >= at_least]
if not group_permissions:
return False

session = object_session(aiod_entry)
if session is None:
return False

group_ids = [p.user_group_identifier for p in group_permissions]
stmt = select(UserGroupMembership).where(
UserGroupMembership.user_identifier == user._subject_identifier,
UserGroupMembership.user_group_identifier.in_(group_ids)
)
membership = session.scalars(stmt).first()
return membership is not None


def user_can_read(user: KeycloakUser, aiod_entry) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
user_router,
bookmark_router,
asset_router,
user_group_router,
)
from prometheus_fastapi_instrumentator import Instrumentator
from middleware.access_log import AccessLogMiddleware
Expand Down Expand Up @@ -98,7 +99,7 @@ def counts() -> dict:
parent_routers.router_list
+ enum_routers.router_list
+ search_routers.router_list
+ [review_router, user_router, bookmark_router, asset_router]
+ [review_router, user_router, bookmark_router, asset_router, user_group_router]
+ resource_routers.router_list
):
app.include_router(router.create(url_prefix, version))
Expand Down
117 changes: 82 additions & 35 deletions src/routers/asset_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_user_by_username,
get_user_by_sub,
)
from database.authorization import user_can_administer, set_permission, Permission, register_user
from database.authorization import user_can_administer, set_permission, Permission, register_user, UserGroup
from database.session import get_session
from database.model.helper_functions import get_asset_by_identifier
from routers.resource_routers import versioned_routers
Expand All @@ -28,18 +28,23 @@ def create(url_prefix: str = "", version: Version = Version.LATEST) -> APIRouter
@router.post(
"/assets/permissions",
tags=["Assets"],
description="Manage permissions that a user has for an asset.",
description="Manage permissions that a user or user group has for an asset.",
)
def add_or_update_permission(
asset_identifier: str = Body(
description="The identifier of the asset for which to update the permission."
),
user: str = Body(
user: str | None = Body(
description="The username or subject identifier of the user.",
examples=["jsmith01", "4a80f256-3928-4cfa-ba66-5e22bb36fc01"],
default=None,
),
user_group_identifier: int | None = Body(
description="The identifier of the user group.",
default=None,
),
permission_type: PermissionType | None = Body(
description="The permission to add for the user. "
description="The permission to add for the user or group. "
"If not set, their permissions will be removed.",
default=None,
),
Expand All @@ -52,38 +57,75 @@ def add_or_update_permission(
status_code=HTTPStatus.FORBIDDEN,
detail=f"You are not allowed to update permissions for asset {asset_identifier}.",
)
sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}"
if re.match(sub_pattern, user):
other = KeycloakUser(name="unknown", roles=set(), _subject_identifier=user)
else:
other = get_user_by_username(user) # type: ignore[assignment]
if not other:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"User with name {user!r} not found.",
)

register_user(other, session) # Should be replaced by KC pushing to REST API
if other._subject_identifier == current_user._subject_identifier:
# This request is more likely to be an accident than on purpose.
# Additionally, we do not want to allow people to accidentally remove all
# administrators from an asset which this restriction ensures.
if not user and not user_group_identifier:
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="You cannot change permissions that pertain to yourself.",
detail="Either user or user_group_identifier must be provided.",
)
if permission_type:
set_permission(other, resource.aiod_entry, session, type_=permission_type)
session.commit()

if user:
sub_pattern = r"\S{8}(-\S{4}){3}-\S{12}"
if re.match(sub_pattern, user):
other = KeycloakUser(name="unknown", roles=set(), _subject_identifier=user)
else:
other = get_user_by_username(user) # type: ignore[assignment]
if not other:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"User with name {user!r} not found.",
)

register_user(other, session) # Should be replaced by KC pushing to REST API
if other._subject_identifier == current_user._subject_identifier:
# This request is more likely to be an accident than on purpose.
# Additionally, we do not want to allow people to accidentally remove all
# administrators from an asset which this restriction ensures.
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="You cannot change permissions that pertain to yourself.",
)
if permission_type:
set_permission(other, resource.aiod_entry, session, type_=permission_type)
session.commit()
else:
permission = session.scalars(
select(Permission).where(
Permission.user_identifier == other._subject_identifier,
Permission.aiod_entry_identifier == resource.aiod_entry.identifier,
)
).first()
if permission:
session.delete(permission)
session.commit()
else:
key = {
"user_identifier": other._subject_identifier,
"aiod_entry_identifier": resource.aiod_entry.identifier,
}
permission = session.get(Permission, key)
if permission:
session.delete(permission)
group = session.get(UserGroup, user_group_identifier)
if not group:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"User group with id {user_group_identifier} not found.",
)
if permission_type:
permission = session.scalars(
select(Permission).where(
Permission.user_group_identifier == group.identifier,
Permission.aiod_entry_identifier == resource.aiod_entry.identifier,
)
).first()
if permission is None:
permission = Permission(user_group_identifier=group.identifier, aiod_entry=resource.aiod_entry)
permission.type_ = permission_type
session.add(permission)
session.commit()
else:
permission = session.scalars(
select(Permission).where(
Permission.user_group_identifier == group.identifier,
Permission.aiod_entry_identifier == resource.aiod_entry.identifier,
)
).first()
if permission:
session.delete(permission)
session.commit()

@router.get(
"/assets/permissions/{identifier}",
Expand All @@ -107,10 +149,15 @@ def show_permission(
)
users = []
for permission in session.scalars(permissions).all():
if (user := get_user_by_sub(permission.user_identifier)) is not None:
users.append({"name": user.name, "permission": permission.type_})
else:
logger.warning(f"Could not find user for sub {permission.user_identifier}.")
if permission.user_identifier:
if (user := get_user_by_sub(permission.user_identifier)) is not None:
users.append({"name": user.name, "permission": permission.type_})
else:
logger.warning(f"Could not find user for sub {permission.user_identifier}.")
elif permission.user_group_identifier:
group = session.get(UserGroup, permission.user_group_identifier)
if group:
users.append({"group_name": group.name, "permission": permission.type_})
return users

@router.get(
Expand Down
Loading