From 0ed0ecb8731572a4c54134efc9e502ff506d33f7 Mon Sep 17 00:00:00 2001 From: EvanCasey13 Date: Fri, 13 Dec 2024 14:15:55 +0000 Subject: [PATCH 1/4] Added annotations for principals when filtering groups by exclude_username or username --- rbac/management/querysets.py | 12 +++--- tests/management/group/test_view.py | 59 +++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/rbac/management/querysets.py b/rbac/management/querysets.py index bb88f65bd..6e0b4c235 100644 --- a/rbac/management/querysets.py +++ b/rbac/management/querysets.py @@ -59,9 +59,9 @@ } -def get_annotated_groups(): +def get_annotated_groups(queryset): """Return an annotated set of groups for the tenant.""" - return Group.objects.annotate( + return queryset.annotate( principalCount=Count("principals", filter=Q(principals__type="user"), distinct=True), policyCount=Count("policies", distinct=True), ) @@ -98,7 +98,7 @@ def get_group_queryset(request, args=None, kwargs=None, base_query: Optional[Que def _gather_group_querysets(request, args, kwargs, base_query: Optional[QuerySet] = None): """Decide which groups to provide for request.""" - base_query = base_query if base_query is not None else get_annotated_groups() + base_query = base_query if base_query is not None else get_annotated_groups(Group.objects.all()) username = request.query_params.get("username") @@ -125,13 +125,15 @@ def _gather_group_querysets(request, args, kwargs, base_query: Optional[QuerySet if principal.cross_account: return Group.objects.none() return ( - filter_queryset_by_tenant(Group.objects.filter(principals__username__iexact=username), request.tenant) + filter_queryset_by_tenant( + get_annotated_groups(base_query.filter(principals__username__iexact=username)), request.tenant + ) | default_group_set ) if exclude_username: return filter_queryset_by_tenant( - Group.objects.exclude(principals__username__iexact=exclude_username), request.tenant + get_annotated_groups(base_query.exclude(principals__username__iexact=exclude_username)), request.tenant ) if has_group_all_access(request): diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 539901e5a..66f94666d 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -540,6 +540,59 @@ def test_read_group_list_principalCount(self, mock_request, sa_mock_request): group = response.data.get("data")[0] self.assertEqual(group["principalCount"], 1) + @patch( + "management.principal.proxy.PrincipalProxy.request_filtered_principals", + return_value={ + "status_code": 200, + "data": [ + { + "org_id": "100001", + "is_org_admin": True, + "is_internal": False, + "id": 52567473, + "username": "test_user", + "account_number": "1111111", + "is_active": True, + } + ], + }, + ) + def test_get_group_principal_count_username_filter(self, mock_request): + "Test that when filtering a group with a username filter that principalCount is returned" + url = reverse("v1_management:group-list") + url = "{}?username={}".format(url, self.test_principal.username) + client = APIClient() + response = client.get(url, **self.test_headers) + + principalCount = response.data.get("data")[0]["principalCount"] + self.assertEqual(principalCount, 2) + + def test_get_group_principal_count_exclude_username_filter(self): + "Test that when filtering a group with the exclude_username filter that principalCount is returned" + # Create test group + group_name = "TestGroup" + group = Group(name=group_name, tenant=self.tenant) + group.save() + + # Create user principal + principal_name = "username_filter_test" + user_principal = Principal(username=principal_name, tenant=self.test_tenant) + user_principal.save() + + # Add principal to group + group.principals.add(user_principal) + group.save() + + # Test that principal count exists & is correct when filtering groups when excluding user + url = f"{reverse('v1_management:group-list')}" + url = "{}?exclude_username={}".format(url, "True") + client = APIClient() + response = client.get(url, **self.headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + principalCount = response.data.get("data")[0]["principalCount"] + self.assertEqual(principalCount, 1) + def test_get_group_by_partial_name_by_default(self): """Test that getting groups by name returns partial match by default.""" url = reverse("v1_management:group-list") @@ -1343,7 +1396,7 @@ def test_get_group_by_username(self, mock_request): url = "{}?username={}".format(url, self.test_principal.username) client = APIClient() response = client.get(url, **self.test_headers) - self.assertEqual(response.data.get("meta").get("count"), 4) + self.assertEqual(response.data.get("meta").get("count"), 2) # Return bad request when user does not exist url = reverse("v1_management:group-list") @@ -1376,7 +1429,7 @@ def test_get_group_by_username_no_assigned_group(self, mock_request): url = "{}?username={}".format(url, self.principalC.username) client = APIClient() response = client.get(url, **self.test_headers) - self.assertEqual(response.data.get("meta").get("count"), 2) + self.assertEqual(response.data.get("meta").get("count"), 1) @patch( "management.principal.proxy.PrincipalProxy.request_filtered_principals", @@ -1432,7 +1485,7 @@ def test_get_group_by_username_with_capitalization(self, mock_request): url = "{}?username={}".format(url, username) client = APIClient() response = client.get(url, **self.test_headers) - self.assertEqual(response.data.get("meta").get("count"), 4) + self.assertEqual(response.data.get("meta").get("count"), 2) def test_get_group_roles_success(self): """Test that getting roles for a group returns successfully.""" From 975c46279d47d1d9bb71aa306def11966433c3d0 Mon Sep 17 00:00:00 2001 From: EvanCasey13 Date: Fri, 13 Dec 2024 14:24:10 +0000 Subject: [PATCH 2/4] test fix --- tests/management/group/test_view.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 66f94666d..764d90f9d 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -591,7 +591,7 @@ def test_get_group_principal_count_exclude_username_filter(self): self.assertEqual(response.status_code, status.HTTP_200_OK) principalCount = response.data.get("data")[0]["principalCount"] - self.assertEqual(principalCount, 1) + self.assertEqual(principalCount, 2) def test_get_group_by_partial_name_by_default(self): """Test that getting groups by name returns partial match by default.""" From edb7b21fa392603b5391b050fdda8bb7b73d056b Mon Sep 17 00:00:00 2001 From: EvanCasey13 Date: Thu, 23 Jan 2025 10:42:38 +0000 Subject: [PATCH 3/4] Issue when filtering by username fix - return all groups principal is in not restricted --- rbac/management/querysets.py | 7 ++++++- tests/management/group/test_view.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/rbac/management/querysets.py b/rbac/management/querysets.py index 0ef3bc4dd..47fa5c0b7 100644 --- a/rbac/management/querysets.py +++ b/rbac/management/querysets.py @@ -121,11 +121,16 @@ def _gather_group_querysets(request, args, kwargs, base_query: Optional[QuerySet username = kwargs.get("principals") if username: principal = get_principal(username, request) + principal_groups = principal.group.all() + principal_group_uuids = [] + for group in principal_groups: + principal_group_uuids.append(group.uuid) + if principal.cross_account: return Group.objects.none() return ( filter_queryset_by_tenant( - get_annotated_groups(base_query.filter(principals__username__iexact=username)), request.tenant + get_annotated_groups(base_query.filter(uuid__in=principal_group_uuids)), request.tenant ) | default_group_set ) diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 764d90f9d..66f94666d 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -591,7 +591,7 @@ def test_get_group_principal_count_exclude_username_filter(self): self.assertEqual(response.status_code, status.HTTP_200_OK) principalCount = response.data.get("data")[0]["principalCount"] - self.assertEqual(principalCount, 2) + self.assertEqual(principalCount, 1) def test_get_group_by_partial_name_by_default(self): """Test that getting groups by name returns partial match by default.""" From 41d4e619485d780d8c661b57280bcde1ec3eeddf Mon Sep 17 00:00:00 2001 From: EvanCasey13 Date: Thu, 30 Jan 2025 11:48:19 +0000 Subject: [PATCH 4/4] test fix for exclude_username --- tests/management/group/test_view.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 66f94666d..05e613cbc 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -583,6 +583,15 @@ def test_get_group_principal_count_exclude_username_filter(self): group.principals.add(user_principal) group.save() + # Create user principal + principal_name = "username_filter_test_2" + user_principal_2 = Principal(username=principal_name, tenant=self.test_tenant) + user_principal_2.save() + + # Add principal to group + group.principals.add(user_principal_2) + group.save() + # Test that principal count exists & is correct when filtering groups when excluding user url = f"{reverse('v1_management:group-list')}" url = "{}?exclude_username={}".format(url, "True") @@ -591,7 +600,7 @@ def test_get_group_principal_count_exclude_username_filter(self): self.assertEqual(response.status_code, status.HTTP_200_OK) principalCount = response.data.get("data")[0]["principalCount"] - self.assertEqual(principalCount, 1) + self.assertEqual(principalCount, 2) def test_get_group_by_partial_name_by_default(self): """Test that getting groups by name returns partial match by default."""