Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions authentik/enterprise/providers/scim/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from base64 import b64encode
from datetime import timedelta
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import PropertyMock, patch

from django.urls import reverse
from django.utils.timezone import now
Expand All @@ -12,9 +12,7 @@
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.enterprise.tests import enterprise_test
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
Expand Down Expand Up @@ -146,20 +144,8 @@ def test_user_create(self, mock: Mocker):
},
)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
def test_api_create(self):
License.objects.create(key=generate_id())
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
Expand Down
55 changes: 55 additions & 0 deletions authentik/enterprise/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from collections.abc import Callable
from datetime import timedelta
from functools import wraps
from time import mktime
from unittest.mock import MagicMock, patch

from django.utils.timezone import now

from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import THRESHOLD_READ_ONLY_WEEKS, License
from authentik.lib.generators import generate_id

# Valid license expiry
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
)


def enterprise_test(
expiry: int = expiry_valid,
internal_users: int = 100,
external_users: int = 100,
create_key=True,
):
"""Install testing enterprise license"""

def wrapper_outer(func: Callable):

@wraps(func)
def wrapper(*args, **kwargs):
with patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry,
name=generate_id(),
internal_users=internal_users,
external_users=external_users,
)
),
):
if create_key:
License.objects.create(key=generate_id())
return func(*args, **kwargs)

return wrapper

return wrapper_outer
128 changes: 15 additions & 113 deletions authentik/enterprise/tests/test_license.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Enterprise license tests"""

from datetime import timedelta
from time import mktime
from unittest.mock import MagicMock, patch

from django.test import TestCase
Expand All @@ -18,35 +17,20 @@
LicenseUsage,
LicenseUsageStatus,
)
from authentik.lib.generators import generate_id

# Valid license expiry
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
from authentik.enterprise.tests import (
enterprise_test,
expiry_expired,
expiry_expired_read_only,
expiry_soon,
expiry_valid,
)
from authentik.lib.generators import generate_id


class TestEnterpriseLicense(TestCase):
"""Enterprise license tests"""

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
def test_valid(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
Expand All @@ -58,18 +42,7 @@ def test_invalid(self):
with self.assertRaises(ValidationError):
License.objects.create(key=generate_id())

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test(create_key=False)
def test_valid_multiple(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id(), expiry=expiry_valid)
Expand All @@ -82,18 +55,7 @@ def test_valid_multiple(self):
self.assertEqual(total.exp, expiry_valid)
self.assertTrue(total.status().is_valid)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
Expand All @@ -108,7 +70,6 @@ def test_valid_multiple(self):
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
Expand All @@ -118,18 +79,7 @@ def test_limit_exceeded_read_only(self):
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
Expand All @@ -144,7 +94,6 @@ def test_limit_exceeded_read_only(self):
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
Expand All @@ -156,18 +105,7 @@ def test_limit_exceeded_user_warning(self):
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
Expand All @@ -182,7 +120,6 @@ def test_limit_exceeded_user_warning(self):
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
Expand All @@ -194,39 +131,16 @@ def test_limit_exceeded_admin_warning(self):
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test(expiry=expiry_expired_read_only)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test(expiry=expiry_expired)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
Expand All @@ -238,23 +152,11 @@ def test_expiry_expired(self):
License.objects.create(key=generate_id(), expiry=expiry_expired)
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test(expiry=expiry_soon)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)
21 changes: 2 additions & 19 deletions authentik/enterprise/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,20 @@
"""Enterprise metrics tests"""

from unittest.mock import MagicMock, patch

from django.test import TestCase
from prometheus_client import REGISTRY

from authentik.core.models import User
from authentik.core.tests.utils import create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.lib.generators import generate_id
from authentik.enterprise.tests import enterprise_test
from authentik.root.monitoring import monitoring_set


class TestEnterpriseMetrics(TestCase):
"""Enterprise metrics tests"""

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@enterprise_test()
def test_usage_empty(self):
"""Test usage (no users)"""
License.objects.create(key=generate_id())
User.objects.all().delete()
create_test_user()
monitoring_set.send_robust(self)
Expand Down
Loading
Loading