diff --git a/src/shipchain_common/authentication.py b/src/shipchain_common/authentication.py index 6edd7e6..1836a50 100644 --- a/src/shipchain_common/authentication.py +++ b/src/shipchain_common/authentication.py @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. """ - +from datetime import timedelta from django.conf import settings -from django.core.cache import cache +from django.utils.functional import cached_property from rest_framework.exceptions import AuthenticationFailed from rest_framework.permissions import BasePermission from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication +from rest_framework_simplejwt.exceptions import TokenError from rest_framework_simplejwt.models import TokenUser +from rest_framework_simplejwt.utils import aware_utcnow from .utils import parse_dn @@ -89,43 +91,18 @@ def set_password(self, raw_password): def check_password(self, raw_password): raise NotImplementedError('Token users have no DB representation') - def _get_permission_cache_key(self): - """ - Build a unique cache key for this specific JWT. - If no `jti`, `at_hash`, or `sub` and `exp`, then return None - """ - unique_key = self.token.get('jti') - - if not unique_key: - unique_key = self.token.get('at_hash') - - if not unique_key: - sub = self.token.get("sub") - exp = self.token.get("exp") - if sub and exp: - unique_key = f'{sub}.{exp}' - - return unique_key - - def _get_permission_cache_life(self): - """ - Determine cache life from JWT. If exp or iat are not present, or - if calculation results in an invalid life, return the fallback_life - """ - fallback_life = 300 - - exp = self.token.get("exp") - iat = self.token.get("iat") + @cached_property + def _permissions(self): + features = self.token.get('features') + if not features: + return [] - if not exp or not iat: - return fallback_life + permissions = [] + for feature in features: + for permission in features[feature]: + permissions.append(f'{feature}.{permission}') - life = exp - iat - - if not life or life <= 0: - return fallback_life - - return life + return permissions def get_all_permissions(self, obj=None): """ @@ -134,26 +111,16 @@ def get_all_permissions(self, obj=None): This prevents re-parsing the permissions over the lifetime of this token as they will not change until a new token is received """ - permissions = None - unique_key = self._get_permission_cache_key() - - if unique_key: - permissions = cache.get(unique_key) - - if not permissions: - features = self.token.get('features') - if not features: - return [] + try: + # Check token expiration, invalidate cache if expired + self.token.check_exp(current_time=aware_utcnow() + timedelta(seconds=30)) + except TokenError: + try: + del self._permissions + except AttributeError: + pass - permissions = [] - for feature in features: - for permission in features[feature]: - permissions.append(f'{feature}.{permission}') - - if unique_key: - cache.set(unique_key, permissions, self._get_permission_cache_life()) - - return permissions + return self._permissions def has_perm(self, perm, obj=None): """ @@ -166,3 +133,14 @@ def has_perms(self, perm_list, obj=None): Validate perm_list is in token feature permissions """ return all(self.has_perm(perm, obj) for perm in perm_list) + + @property + def limits(self): + return self.token.get('limits', {}) + + def get_limit(self, entity, name): + limit = None + entity = self.limits.get(entity) + if entity: + limit = entity.get(name) + return limit diff --git a/src/shipchain_common/exceptions.py b/src/shipchain_common/exceptions.py index 36f52ae..a68d379 100644 --- a/src/shipchain_common/exceptions.py +++ b/src/shipchain_common/exceptions.py @@ -214,3 +214,9 @@ class URLShortenerError(Custom500Error): status_code = status.HTTP_500_INTERNAL_SERVER_ERROR default_detail = 'URL Shortener Error.' default_code = 'server_error' + + +class AccountLimitReached(APIException): + status_code = status.HTTP_402_PAYMENT_REQUIRED + default_detail = 'Request denied due to the restrictions of your current billing tier.' + default_code = 'account_limit_reached' diff --git a/src/shipchain_common/test_utils/json_asserter.py b/src/shipchain_common/test_utils/json_asserter.py index b120ab4..b00b8c7 100644 --- a/src/shipchain_common/test_utils/json_asserter.py +++ b/src/shipchain_common/test_utils/json_asserter.py @@ -416,6 +416,10 @@ def assert_401(response, error='Authentication credentials were not provided', v assert response.status_code == status.HTTP_401_UNAUTHORIZED, f'status_code {response.status_code} != 401' response_has_error(response, error, vnd=vnd) +def assert_402(response, error='Request denied due to the restrictions of your current billing tier.', vnd=True): + assert response is not None + assert response.status_code == status.HTTP_402_PAYMENT_REQUIRED, f'status_code {response.status_code} != 402' + response_has_error(response, error, vnd=vnd) def assert_403(response, error='You do not have permission to perform this action', vnd=True): assert response is not None @@ -459,6 +463,7 @@ class AssertionHelper: HTTP_400 = assert_400 HTTP_401 = assert_401 + HTTP_402 = assert_402 HTTP_403 = assert_403 HTTP_404 = assert_404 HTTP_405 = assert_405 diff --git a/tests/test_auth.py b/tests/test_auth.py index 22b1dab..8e6593f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -152,54 +152,6 @@ def test_lambda_auth_requires_header(lambda_request): assert lambda_request.has_permission(request, {}) -def test_token_user_jti_cache_key(): - """By default, the jti is included in get_jwt and is used as cache key""" - jwt = get_jwt() - token = UntypedToken(jwt) - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_key() == token_user.token.get('jti') - - -def test_token_user_at_hash_cache_key(): - """If no jti is included in get_jwt then use at_hash as cache key if exists""" - jwt = get_jwt(jti=0, at_hash=uuid4().hex) - token = UntypedToken(jwt) - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_key() == token_user.token.get('at_hash') - - -def test_token_user_sub_exp_cache_key(): - """If no jti or at_hash is included in get_jwt then use {sub}.{exp} as cache key""" - jwt = get_jwt(jti=0, sub=uuid4().hex) - token = UntypedToken(jwt) - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_key() == f'{token_user.token.get("sub")}.{token_user.token.get("exp")}' - - -def test_token_user_cache_life(): - jwt = get_jwt() - token = UntypedToken(jwt) - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_life() == 300 - - -def test_token_user_cache_calculated_life(): - iat = datetime_to_epoch(aware_utcnow()) - jwt = get_jwt(exp=iat+15, iat=iat) - token = UntypedToken(jwt) - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_life() == 15 - - -def test_token_user_cache_fallback_life(): - iat = datetime_to_epoch(aware_utcnow()) - jwt = get_jwt(exp=iat+15, iat=iat) - token = UntypedToken(jwt) - token.payload['iat'] = None - token_user = PermissionedTokenUser(token) - assert token_user._get_permission_cache_life() == 300 - - @pytest.fixture def one_feature(): """Returns feature object response in token, and list of feature permissions"""