diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8b8dd178..08f2f777 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,10 @@ Unreleased ---------- ========================= +[10.21.16] - 2025-11-21 +----------------------- + * feat: filter out unenrollment and update filter order in Learner Progress Report + [10.21.15] - 2025-11-21 ----------------------- * fix: in operator construction for sql queries diff --git a/enterprise_data/__init__.py b/enterprise_data/__init__.py index e350d96c..0925fb68 100644 --- a/enterprise_data/__init__.py +++ b/enterprise_data/__init__.py @@ -2,4 +2,4 @@ Enterprise data api application. This Django app exposes API endpoints used by enterprises. """ -__version__ = "10.21.15" +__version__ = "10.21.16" diff --git a/enterprise_data/api/v1/views/enterprise_learner.py b/enterprise_data/api/v1/views/enterprise_learner.py index a713f6e0..645a39a6 100644 --- a/enterprise_data/api/v1/views/enterprise_learner.py +++ b/enterprise_data/api/v1/views/enterprise_learner.py @@ -114,6 +114,7 @@ def _stream_serialized_data(self): for page_number in paginator.page_range: yield from serializer(paginator.page(page_number).object_list, many=True).data + # pylint: disable=too-many-statements def apply_filters(self, queryset): """ Filters enrollments based on query params. @@ -181,6 +182,10 @@ def apply_filters(self, queryset): if group_uuid: queryset = self.filter_by_group_uuid(queryset, group_uuid) + search_enrollment = query_filters.get('search_enrollment') + if search_enrollment in ("enrolled", "unenrolled"): + queryset = self.filter_search_enrollment(queryset, search_enrollment) + return queryset def filter_by_group_uuid(self, queryset, group_uuid): @@ -303,6 +308,27 @@ def get_max_created_date(self, queryset): created_max = queryset.aggregate(Max('created')) return created_max['created__max'] + def filter_search_enrollment(self, queryset, status): + """ + Filter enrollments based on enrollment `status`. + + Args: + status (str): Enrollment status to filter by. Can be one of: + 'enrolled' : unenrollment_date is NULL (currently enrolled) + 'unenrolled' : unenrollment_date is NOT NULL (no longer enrolled) + + Returns: + QuerySet: Filtered queryset of enrollments. + """ + + if status == "enrolled": + return queryset.filter(unenrollment_date__isnull=True) + + if status == "unenrolled": + return queryset.filter(unenrollment_date__isnull=False) + + return queryset + @action(detail=False) def overview(self, request, **kwargs): """ diff --git a/enterprise_data/management/commands/create_dummy_data_lpr_v1.py b/enterprise_data/management/commands/create_dummy_data_lpr_v1.py index 9dc7e335..77736085 100644 --- a/enterprise_data/management/commands/create_dummy_data_lpr_v1.py +++ b/enterprise_data/management/commands/create_dummy_data_lpr_v1.py @@ -31,6 +31,7 @@ def handle(self, *args, **options): ) for _ in range(5): EnterpriseLearnerEnrollmentFactory( + enterprise_user=ent_user, enterprise_customer_uuid=enterprise_customer_uuid, enterprise_user_id=ent_user.enterprise_user_id, is_consent_granted=choice([True, False]), diff --git a/enterprise_data/management/commands/create_enterprise_learner_enrollment_lpr_v1.py b/enterprise_data/management/commands/create_enterprise_learner_enrollment_lpr_v1.py index db115302..899b3549 100644 --- a/enterprise_data/management/commands/create_enterprise_learner_enrollment_lpr_v1.py +++ b/enterprise_data/management/commands/create_enterprise_learner_enrollment_lpr_v1.py @@ -8,6 +8,7 @@ from django.core.management.base import BaseCommand, CommandError import enterprise_data.tests.test_utils +from enterprise_data.models import EnterpriseLearner class Command(BaseCommand): @@ -36,9 +37,14 @@ def handle(self, *args, **options): is_consent_granted = options.get('consent_granted') try: + enterprise_learner = EnterpriseLearner.objects.get( + enterprise_customer_uuid=enterprise_customer_uuid, + enterprise_user_id=enterprise_user_id, + ) enterprise_data.tests.test_utils.EnterpriseLearnerEnrollmentFactory( enterprise_customer_uuid=enterprise_customer_uuid, enterprise_user_id=enterprise_user_id, + enterprise_user=enterprise_learner, is_consent_granted=is_consent_granted, ) info = ( diff --git a/enterprise_data/tests/api/v1/test_views.py b/enterprise_data/tests/api/v1/test_views.py index f6082026..2f93dc8f 100644 --- a/enterprise_data/tests/api/v1/test_views.py +++ b/enterprise_data/tests/api/v1/test_views.py @@ -13,6 +13,8 @@ from rest_framework.reverse import reverse from rest_framework.test import APITransactionTestCase +from django.utils import timezone + from enterprise_data.api.v1.serializers import EnterpriseOfferSerializer from enterprise_data.models import EnterpriseLearnerEnrollment, EnterpriseOffer from enterprise_data.tests.factories import ( @@ -369,3 +371,114 @@ def test_retrieve_enterprise_admin_insights_no_access(self): url = reverse('v1:enterprise-admin-insights', kwargs={'enterprise_id': enterprise_customer_uuid}) response = self.client.get(url) assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@ddt.ddt +@mark.django_db +class TestSearchEnrollmentFilter(JWTTestMixin, APITransactionTestCase): + """ + Tests for filtering enrolled vs unenrolled learners using search_enrollment param. + """ + + def setUp(self): + super().setUp() + self.user = UserFactory(is_staff=True) + role, __ = EnterpriseDataFeatureRole.objects.get_or_create(name=ENTERPRISE_DATA_ADMIN_ROLE) + self.role_assignment = EnterpriseDataRoleAssignment.objects.create( + role=role, + user=self.user + ) + self.client.force_authenticate(user=self.user) + + mocked_get_enterprise_customer = mock.patch( + 'enterprise_data.filters.EnterpriseApiClient.get_enterprise_customer', + return_value=get_dummy_enterprise_api_data() + ) + self.mocked_get_enterprise_customer = mocked_get_enterprise_customer.start() + self.addCleanup(mocked_get_enterprise_customer.stop) + + self.enterprise_id = 'fd0d9cd4-bc35-45e8-ba35-e73be3fc5a07' + self.url = reverse( + 'v1:enterprise-learner-enrollment-list', + kwargs={'enterprise_id': self.enterprise_id} + ) + self.enterprise_learner = EnterpriseLearnerFactory( + enterprise_customer_uuid=self.enterprise_id + ) + self.set_jwt_cookie() + + def tearDown(self): + super().tearDown() + EnterpriseLearnerEnrollment.objects.all().delete() + + def create_enrolled(self): + """Enrollment with unenrollment_date = NULL""" + return EnterpriseLearnerEnrollmentFactory( + enterprise_customer_uuid=self.enterprise_id, + enterprise_user_id=self.enterprise_learner.enterprise_user_id, + unenrollment_date=None, + ) + + def create_unenrolled(self, dt=None): + """Enrollment with unenrollment_date != NULL""" + dt = dt or timezone.now() + return EnterpriseLearnerEnrollmentFactory( + enterprise_customer_uuid=self.enterprise_id, + enterprise_user_id=self.enterprise_learner.enterprise_user_id, + unenrollment_date=dt, + ) + + def test_filter_enrolled(self): + """Test for enrolled learners""" + # Create enrolled learner (unenrollment_date = NULL) + enrolled = self.create_enrolled() + # Unenrolled learner(NOT NULL) + self.create_unenrolled() + + response = self.client.get( + self.url, + data={"search_enrollment": "enrolled"} + ) + + results = response.json()["results"] + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["enrollment_id"], enrolled.enrollment_id) + + def test_filter_unenrolled(self): + """Test for unenrolled learners""" + # Enrolled learner (NULL) + self.create_enrolled() + # Unenrolled learner (NOT NULL) + unenrolled = self.create_unenrolled() + + response = self.client.get( + self.url, + data={"search_enrollment": "unenrolled"} + ) + + results = response.json()["results"] + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["enrollment_id"], unenrolled.enrollment_id) + + def test_no_search_enrollment_filter(self): + """Test no filter - return all items""" + self.create_enrolled() + self.create_unenrolled() + + response = self.client.get(self.url) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 2) + + def test_invalid_search_enrollment_value(self): + """Invalid value → return all""" + self.create_enrolled() + self.create_unenrolled() + + response = self.client.get( + self.url, + data={"search_enrollment": "not-valid"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 2) diff --git a/enterprise_data/tests/test_utils.py b/enterprise_data/tests/test_utils.py index e79f471a..3a93db68 100644 --- a/enterprise_data/tests/test_utils.py +++ b/enterprise_data/tests/test_utils.py @@ -154,6 +154,10 @@ class EnterpriseLearnerEnrollmentFactory(factory.django.DjangoModelFactory): class Meta: model = EnterpriseLearnerEnrollment + enterprise_user = factory.SubFactory( + EnterpriseLearnerFactory, + enterprise_user_id=factory.Sequence(lambda n: n+1) + ) enrollment_id = factory.lazy_attribute( lambda x: FAKER.random_int(min=1, max=999999) # pylint: disable=no-member ) @@ -179,6 +183,7 @@ class Meta: letter_grade = factory.lazy_attribute(lambda x: ' '.join(FAKER.words(nb=2)).title()) progress_status = factory.lazy_attribute(lambda x: ' '.join(FAKER.words(nb=2)).title()) enterprise_user_id = factory.Sequence(lambda n: n) + is_consent_granted = True user_email = factory.lazy_attribute(lambda x: FAKER.email()) # pylint: disable=no-member user_username = factory.Sequence('robot{}'.format) user_first_name = factory.Sequence('Robot First {}'.format)