diff --git a/h/security/predicates.py b/h/security/predicates.py index f38043962a2..093e536db44 100644 --- a/h/security/predicates.py +++ b/h/security/predicates.py @@ -221,6 +221,10 @@ def get_authenticated_users_membership(): def group_member_edit( identity, context: EditGroupMembershipContext ): # pylint:disable=too-many-return-statements,too-complex + assert ( + context.new_roles is not None + ), "new_roles must be set before checking permissions" + old_roles = context.membership.roles new_roles = context.new_roles diff --git a/h/traversal/__init__.py b/h/traversal/__init__.py index e1d328462f1..b3e5493f378 100644 --- a/h/traversal/__init__.py +++ b/h/traversal/__init__.py @@ -63,6 +63,7 @@ from h.traversal.annotation import AnnotationContext, AnnotationRoot from h.traversal.group import GroupContext, GroupRequiredRoot, GroupRoot from h.traversal.group_membership import ( + AddGroupMembershipContext, EditGroupMembershipContext, GroupMembershipContext, group_membership_api_factory, @@ -82,6 +83,7 @@ "UserByIDRoot", "UserRoot", "GroupContext", + "AddGroupMembershipContext", "EditGroupMembershipContext", "GroupMembershipContext", "group_membership_api_factory", diff --git a/h/traversal/group_membership.py b/h/traversal/group_membership.py index cb4ed3b0bd1..16bfd639e44 100644 --- a/h/traversal/group_membership.py +++ b/h/traversal/group_membership.py @@ -10,7 +10,14 @@ class GroupMembershipContext: group: Group user: User - membership: GroupMembership | None + membership: GroupMembership + + +@dataclass +class AddGroupMembershipContext: + group: Group + user: User + new_roles: list[GroupMembershipRoles] | None @dataclass @@ -18,7 +25,7 @@ class EditGroupMembershipContext: group: Group user: User membership: GroupMembership - new_roles: list[GroupMembershipRoles] + new_roles: list[GroupMembershipRoles] | None def _get_user(request, userid) -> User | None: @@ -46,13 +53,14 @@ def _get_membership(request, group, user) -> GroupMembership | None: return group_members_service.get_membership(group, user) -def group_membership_api_factory(request) -> GroupMembershipContext: +def group_membership_api_factory( + request, +) -> GroupMembershipContext | AddGroupMembershipContext | EditGroupMembershipContext: userid = request.matchdict["userid"] pubid = request.matchdict["pubid"] user = _get_user(request, userid) group = _get_group(request, pubid) - membership = _get_membership(request, group, user) if not user: raise HTTPNotFound(f"User not found: {userid}") @@ -60,7 +68,16 @@ def group_membership_api_factory(request) -> GroupMembershipContext: if not group: raise HTTPNotFound(f"Group not found: {pubid}") - if not membership and request.method != "POST": + if request.method == "POST": + return AddGroupMembershipContext(group, user, new_roles=None) + + membership = _get_membership(request, group, user) + + if not membership: raise HTTPNotFound(f"Membership not found: ({pubid}, {userid})") - return GroupMembershipContext(group=group, user=user, membership=membership) + if request.method in ("GET", "DELETE"): + return GroupMembershipContext(group=group, user=user, membership=membership) + + assert request.method == "PATCH" + return EditGroupMembershipContext(group, user, membership, new_roles=None) diff --git a/h/views/api/group_members.py b/h/views/api/group_members.py index 7441de2591e..109ceaa69a4 100644 --- a/h/views/api/group_members.py +++ b/h/views/api/group_members.py @@ -9,7 +9,12 @@ from h.schemas.util import validate_query_params from h.security import Permission from h.services.group_members import ConflictError -from h.traversal import EditGroupMembershipContext, GroupContext, GroupMembershipContext +from h.traversal import ( + AddGroupMembershipContext, + EditGroupMembershipContext, + GroupContext, + GroupMembershipContext, +) from h.views.api.config import api_config from h.views.api.helpers.json_payload import json_payload @@ -108,7 +113,7 @@ def remove_member(context: GroupMembershipContext, request): description="Add a user to a group", permission=Permission.Group.MEMBER_ADD, ) -def add_member(context: GroupMembershipContext, request): +def add_member(context: AddGroupMembershipContext, request): if context.user.authority != context.group.authority: raise HTTPNotFound() @@ -139,21 +144,16 @@ def add_member(context: GroupMembershipContext, request): link_name="group.member.edit", description="Change a user's role in a group", ) -def edit_member(context: GroupMembershipContext, request): +def edit_member(context: EditGroupMembershipContext, request): appstruct = EditGroupMembershipAPISchema().validate(json_payload(request)) - new_roles = appstruct["roles"] - - if not request.has_permission( - Permission.Group.MEMBER_EDIT, - EditGroupMembershipContext( - context.group, context.user, context.membership, new_roles - ), - ): + context.new_roles = appstruct["roles"] + + if not request.has_permission(Permission.Group.MEMBER_EDIT, context): raise HTTPNotFound() - if context.membership.roles != new_roles: + if context.membership.roles != context.new_roles: old_roles = context.membership.roles - context.membership.roles = new_roles + context.membership.roles = context.new_roles log.info( "Changed group membership roles: %r (previous roles were: %r)", context.membership, @@ -166,6 +166,6 @@ def edit_member(context: GroupMembershipContext, request): # Otherwise permissions checks will be based on the old roles. for membership in request.identity.user.memberships: if membership.group.id == context.group.id: - membership.roles = new_roles + membership.roles = context.new_roles return GroupMembershipJSONPresenter(request, context.membership).asdict() diff --git a/tests/unit/h/security/predicates_test.py b/tests/unit/h/security/predicates_test.py index f42bca063b7..b931b340e76 100644 --- a/tests/unit/h/security/predicates_test.py +++ b/tests/unit/h/security/predicates_test.py @@ -1027,6 +1027,20 @@ def test_changing_own_role( assert predicates.group_member_edit(identity, context) == expected_result + def test_it_crashes_if_new_roles_is_not_set(self, identity): + context = EditGroupMembershipContext( + group=sentinel.group, + user=sentinel.user, + membership=sentinel.membership, + new_roles=None, + ) + + with pytest.raises( + AssertionError, + match="^new_roles must be set before checking permissions$", + ): + predicates.group_member_edit(identity, context) + @pytest.fixture def authenticated_user(self, db_session, authenticated_user, factories): # Make the authenticated user a member of a *different* group, diff --git a/tests/unit/h/traversal/group_membership_test.py b/tests/unit/h/traversal/group_membership_test.py index 6a614523002..1b1427963ae 100644 --- a/tests/unit/h/traversal/group_membership_test.py +++ b/tests/unit/h/traversal/group_membership_test.py @@ -6,6 +6,8 @@ from h.exceptions import InvalidUserId from h.traversal.group_membership import ( + AddGroupMembershipContext, + EditGroupMembershipContext, GroupMembershipContext, group_membership_api_factory, ) @@ -13,9 +15,17 @@ @pytest.mark.usefixtures("group_service", "user_service", "group_members_service") class TestGroupMembershipAPIFactory: - def test_it( - self, group_service, user_service, group_members_service, pyramid_request + @pytest.mark.parametrize("request_method", ["GET", "DELETE"]) + def test_get_delete( + self, + group_service, + user_service, + group_members_service, + pyramid_request, + request_method, ): + pyramid_request.method = request_method + context = group_membership_api_factory(pyramid_request) group_service.fetch.assert_called_once_with(sentinel.pubid) @@ -28,25 +38,70 @@ def test_it( assert context.user == user_service.fetch.return_value assert context.membership == group_members_service.get_membership.return_value - def test_when_no_matching_group(self, group_service, pyramid_request): + def test_post( + self, group_service, user_service, group_members_service, pyramid_request + ): + pyramid_request.method = "POST" + + context = group_membership_api_factory(pyramid_request) + + group_service.fetch.assert_called_once_with(sentinel.pubid) + user_service.fetch.assert_called_once_with(sentinel.userid) + group_members_service.get_membership.assert_not_called() + assert isinstance(context, AddGroupMembershipContext) + assert context.group == group_service.fetch.return_value + assert context.user == user_service.fetch.return_value + assert context.new_roles is None + + def test_patch( + self, group_service, user_service, group_members_service, pyramid_request + ): + pyramid_request.method = "PATCH" + + context = group_membership_api_factory(pyramid_request) + + group_service.fetch.assert_called_once_with(sentinel.pubid) + user_service.fetch.assert_called_once_with(sentinel.userid) + group_members_service.get_membership.assert_called_once_with( + group_service.fetch.return_value, user_service.fetch.return_value + ) + assert isinstance(context, EditGroupMembershipContext) + assert context.group == group_service.fetch.return_value + assert context.user == user_service.fetch.return_value + assert context.membership == group_members_service.get_membership.return_value + assert context.new_roles is None + + @pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"]) + def test_when_no_matching_group( + self, group_service, pyramid_request, request_method + ): + pyramid_request.method = request_method group_service.fetch.return_value = None with pytest.raises(HTTPNotFound, match="Group not found: sentinel.pubid"): group_membership_api_factory(pyramid_request) - def test_when_no_matching_user(self, user_service, pyramid_request): + @pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"]) + def test_when_no_matching_user(self, user_service, pyramid_request, request_method): + pyramid_request.method = request_method user_service.fetch.return_value = None with pytest.raises(HTTPNotFound, match="User not found: sentinel.userid"): group_membership_api_factory(pyramid_request) - def test_when_invalid_userid(self, user_service, pyramid_request): + @pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"]) + def test_when_invalid_userid(self, user_service, pyramid_request, request_method): + pyramid_request.method = request_method user_service.fetch.side_effect = InvalidUserId(sentinel.userid) with pytest.raises(HTTPNotFound, match="User not found: sentinel.userid"): group_membership_api_factory(pyramid_request) - def test_when_no_matching_membership(self, group_members_service, pyramid_request): + @pytest.mark.parametrize("request_method", ["GET", "PATCH", "DELETE"]) + def test_when_no_matching_membership( + self, group_members_service, pyramid_request, request_method + ): + pyramid_request.method = request_method group_members_service.get_membership.return_value = None with pytest.raises( @@ -55,7 +110,11 @@ def test_when_no_matching_membership(self, group_members_service, pyramid_reques ): group_membership_api_factory(pyramid_request) - def test_me_alias(self, pyramid_config, pyramid_request, user_service): + @pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"]) + def test_me_alias( + self, pyramid_config, pyramid_request, user_service, request_method + ): + pyramid_request.method = request_method pyramid_config.testing_securitypolicy(userid=sentinel.userid) pyramid_request.matchdict["userid"] = "me" @@ -63,7 +122,9 @@ def test_me_alias(self, pyramid_config, pyramid_request, user_service): user_service.fetch.assert_called_once_with(sentinel.userid) - def test_me_alias_when_not_authenticated(self, pyramid_request): + @pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"]) + def test_me_alias_when_not_authenticated(self, pyramid_request, request_method): + pyramid_request.method = request_method pyramid_request.matchdict["userid"] = "me" with pytest.raises(HTTPNotFound, match="User not found: me"): diff --git a/tests/unit/h/views/api/group_members_test.py b/tests/unit/h/views/api/group_members_test.py index a2aad701979..d4626afb4f6 100644 --- a/tests/unit/h/views/api/group_members_test.py +++ b/tests/unit/h/views/api/group_members_test.py @@ -8,9 +8,15 @@ from h import presenters from h.models import GroupMembership from h.schemas.base import ValidationError +from h.security import Permission from h.security.identity import Identity, LongLivedGroup, LongLivedMembership from h.services.group_members import ConflictError -from h.traversal import GroupContext, GroupMembershipContext +from h.traversal import ( + AddGroupMembershipContext, + EditGroupMembershipContext, + GroupContext, + GroupMembershipContext, +) from h.views.api.exceptions import PayloadError @@ -234,7 +240,7 @@ def test_it_with_authority_mismatch(self, pyramid_request, context): def context(self, factories): group = factories.Group.build() user = factories.User.build(authority=group.authority) - return GroupMembershipContext(group=group, user=user, membership=None) + return AddGroupMembershipContext(group=group, user=user, new_roles=None) @pytest.fixture def pyramid_request(self, pyramid_request): @@ -258,12 +264,17 @@ def test_it( EditGroupMembershipAPISchema, GroupMembershipJSONPresenter, caplog, + mocker, ): + has_permission = mocker.spy(pyramid_request, "has_permission") + response = views.edit_member(context, pyramid_request) EditGroupMembershipAPISchema.return_value.validate.assert_called_once_with( sentinel.json_body ) + assert context.new_roles == sentinel.new_roles + has_permission.assert_called_once_with(Permission.Group.MEMBER_EDIT, context) assert context.membership.roles == sentinel.new_roles GroupMembershipJSONPresenter.assert_called_once_with( pyramid_request, context.membership @@ -342,8 +353,9 @@ def context(self, factories): group = factories.Group.build() user = factories.User.build(authority=group.authority) membership = GroupMembership(group=group, user=user, roles=sentinel.old_roles) - - return GroupMembershipContext(group=group, user=user, membership=membership) + return EditGroupMembershipContext( + group=group, user=user, membership=membership, new_roles=None + ) @pytest.fixture def pyramid_request(self, pyramid_request):