diff --git a/authentik/enterprise/providers/scim/tests.py b/authentik/enterprise/providers/scim/tests.py index 0680c53d0f36..6b44377b8acf 100644 --- a/authentik/enterprise/providers/scim/tests.py +++ b/authentik/enterprise/providers/scim/tests.py @@ -2,7 +2,7 @@ from base64 import b64encode from datetime import timedelta -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import PropertyMock, patch from django.urls import reverse from django.utils.timezone import now @@ -12,9 +12,7 @@ from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application, Group, User from authentik.core.tests.utils import create_test_admin_user -from authentik.enterprise.license import LicenseKey -from authentik.enterprise.models import License -from authentik.enterprise.tests.test_license import expiry_valid +from authentik.enterprise.tests import enterprise_test from authentik.lib.generators import generate_id from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection @@ -146,20 +144,8 @@ def test_user_create(self, mock: Mocker): }, ) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() def test_api_create(self): - License.objects.create(key=generate_id()) self.client.force_login(create_test_admin_user()) res = self.client.post( reverse("authentik_api:scimprovider-list"), diff --git a/authentik/enterprise/tests/__init__.py b/authentik/enterprise/tests/__init__.py index e69de29bb2d1..49477666d462 100644 --- a/authentik/enterprise/tests/__init__.py +++ b/authentik/enterprise/tests/__init__.py @@ -0,0 +1,55 @@ +from collections.abc import Callable +from datetime import timedelta +from functools import wraps +from time import mktime +from unittest.mock import MagicMock, patch + +from django.utils.timezone import now + +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import THRESHOLD_READ_ONLY_WEEKS, License +from authentik.lib.generators import generate_id + +# Valid license expiry +expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple())) +# Valid license expiry, expires soon +expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple())) +# Invalid license expiry, recently expired +expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple())) +# Invalid license expiry, expired longer ago +expiry_expired_read_only = int( + mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple()) +) + + +def enterprise_test( + expiry: int = expiry_valid, + internal_users: int = 100, + external_users: int = 100, + create_key=True, +): + """Install testing enterprise license""" + + def wrapper_outer(func: Callable): + + @wraps(func) + def wrapper(*args, **kwargs): + with patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry, + name=generate_id(), + internal_users=internal_users, + external_users=external_users, + ) + ), + ): + if create_key: + License.objects.create(key=generate_id()) + return func(*args, **kwargs) + + return wrapper + + return wrapper_outer diff --git a/authentik/enterprise/tests/test_license.py b/authentik/enterprise/tests/test_license.py index 6ab2ced0c78d..bf77cbe0bd4f 100644 --- a/authentik/enterprise/tests/test_license.py +++ b/authentik/enterprise/tests/test_license.py @@ -1,7 +1,6 @@ """Enterprise license tests""" from datetime import timedelta -from time import mktime from unittest.mock import MagicMock, patch from django.test import TestCase @@ -18,35 +17,20 @@ LicenseUsage, LicenseUsageStatus, ) -from authentik.lib.generators import generate_id - -# Valid license expiry -expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple())) -# Valid license expiry, expires soon -expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple())) -# Invalid license expiry, recently expired -expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple())) -# Invalid license expiry, expired longer ago -expiry_expired_read_only = int( - mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple()) +from authentik.enterprise.tests import ( + enterprise_test, + expiry_expired, + expiry_expired_read_only, + expiry_soon, + expiry_valid, ) +from authentik.lib.generators import generate_id class TestEnterpriseLicense(TestCase): """Enterprise license tests""" - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() def test_valid(self): """Check license verification""" lic = License.objects.create(key=generate_id()) @@ -58,18 +42,7 @@ def test_invalid(self): with self.assertRaises(ValidationError): License.objects.create(key=generate_id()) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test(create_key=False) def test_valid_multiple(self): """Check license verification""" lic = License.objects.create(key=generate_id(), expiry=expiry_valid) @@ -82,18 +55,7 @@ def test_valid_multiple(self): self.assertEqual(total.exp, expiry_valid) self.assertTrue(total.status().is_valid) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -108,7 +70,6 @@ def test_valid_multiple(self): ) def test_limit_exceeded_read_only(self): """Check license verification""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100, @@ -118,18 +79,7 @@ def test_limit_exceeded_read_only(self): usage.save(update_fields=["record_date"]) self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -144,7 +94,6 @@ def test_limit_exceeded_read_only(self): ) def test_limit_exceeded_user_warning(self): """Check license verification""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100, @@ -156,18 +105,7 @@ def test_limit_exceeded_user_warning(self): LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER ) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -182,7 +120,6 @@ def test_limit_exceeded_user_warning(self): ) def test_limit_exceeded_admin_warning(self): """Check license verification""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100, @@ -194,39 +131,16 @@ def test_limit_exceeded_admin_warning(self): LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN ) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_expired_read_only, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test(expiry=expiry_expired_read_only) @patch( "authentik.enterprise.license.LicenseKey.record_usage", MagicMock(), ) def test_expiry_read_only(self): """Check license verification""" - License.objects.create(key=generate_id()) self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_expired, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test(expiry=expiry_expired) @patch( "authentik.enterprise.license.LicenseKey.record_usage", MagicMock(), @@ -238,23 +152,11 @@ def test_expiry_expired(self): License.objects.create(key=generate_id(), expiry=expiry_expired) self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_soon, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test(expiry=expiry_soon) @patch( "authentik.enterprise.license.LicenseKey.record_usage", MagicMock(), ) def test_expiry_soon(self): """Check license verification""" - License.objects.create(key=generate_id()) self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON) diff --git a/authentik/enterprise/tests/test_metrics.py b/authentik/enterprise/tests/test_metrics.py index f056ecc4eddc..0c3307cf3d3b 100644 --- a/authentik/enterprise/tests/test_metrics.py +++ b/authentik/enterprise/tests/test_metrics.py @@ -1,37 +1,20 @@ """Enterprise metrics tests""" -from unittest.mock import MagicMock, patch - from django.test import TestCase from prometheus_client import REGISTRY from authentik.core.models import User from authentik.core.tests.utils import create_test_user -from authentik.enterprise.license import LicenseKey -from authentik.enterprise.models import License -from authentik.enterprise.tests.test_license import expiry_valid -from authentik.lib.generators import generate_id +from authentik.enterprise.tests import enterprise_test from authentik.root.monitoring import monitoring_set class TestEnterpriseMetrics(TestCase): """Enterprise metrics tests""" - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() def test_usage_empty(self): """Test usage (no users)""" - License.objects.create(key=generate_id()) User.objects.all().delete() create_test_user() monitoring_set.send_robust(self) diff --git a/authentik/enterprise/tests/test_read_only.py b/authentik/enterprise/tests/test_read_only.py index df6da957c916..ba4475c742e7 100644 --- a/authentik/enterprise/tests/test_read_only.py +++ b/authentik/enterprise/tests/test_read_only.py @@ -7,14 +7,13 @@ from django.utils.timezone import now from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user -from authentik.enterprise.license import LicenseKey from authentik.enterprise.models import ( THRESHOLD_READ_ONLY_WEEKS, License, LicenseUsage, LicenseUsageStatus, ) -from authentik.enterprise.tests.test_license import expiry_valid +from authentik.enterprise.tests import enterprise_test from authentik.flows.models import ( FlowDesignation, FlowStageBinding, @@ -30,18 +29,7 @@ class TestReadOnly(FlowTestCase): """Test read_only""" - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -56,7 +44,6 @@ class TestReadOnly(FlowTestCase): ) def test_login(self): """Test flow, ensure login is still possible with read only mode""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100, @@ -115,18 +102,7 @@ def test_login(self): response = self.client.post(exec_url, {"password": user.username}, follow=True) self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test(create_key=False) @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -163,18 +139,7 @@ def test_manage_licenses(self): ) self.assertEqual(response.status_code, 200) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -189,7 +154,6 @@ def test_manage_licenses(self): ) def test_manage_flows(self): """Test flow""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100, @@ -216,18 +180,7 @@ def test_manage_flows(self): ) self.assertEqual(response.status_code, 400) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, - ) - ), - ) + @enterprise_test() @patch( "authentik.enterprise.license.LicenseKey.get_internal_user_count", MagicMock(return_value=1000), @@ -242,7 +195,6 @@ def test_manage_flows(self): ) def test_manage_users(self): """Test that managing users is still possible""" - License.objects.create(key=generate_id()) usage = LicenseUsage.objects.create( internal_user_count=100, external_user_count=100,