Skip to content

Commit 4d6f8da

Browse files
authoredMar 20, 2024··
feat: check that token is valid upon instantiation of config object (#28)
* feat: check that token is valid upon instantiation of config object * linter * fix coverage
1 parent 92f8775 commit 4d6f8da

File tree

4 files changed

+43
-24
lines changed

4 files changed

+43
-24
lines changed
 

‎src/bssclient/client/bssclient.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from yarl import URL
1212

1313
from bssclient.client.config import BasicAuthBssConfig, BssConfig, OAuthBssConfig
14-
from bssclient.client.oauth import _OAuthHttpClient
14+
from bssclient.client.oauth import _OAuthHttpClient, token_is_valid
1515
from bssclient.models.aufgabe import AufgabeStats
1616
from bssclient.models.ermittlungsauftrag import Ermittlungsauftrag, _ListOfErmittlungsauftraege
1717

@@ -210,7 +210,7 @@ async def _get_session(self) -> ClientSession:
210210
async with self._session_lock:
211211
if self._bearer_token is None:
212212
self._bearer_token = await self._get_oauth_token()
213-
elif not self._token_is_valid(self._bearer_token):
213+
elif not token_is_valid(self._bearer_token):
214214
await self.close_session()
215215
if self._session is None or self._session.closed:
216216
_logger.info("creating new session")

‎src/bssclient/client/config.py

+13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator, model_validator
66
from yarl import URL
77

8+
from .oauth import token_is_valid
9+
810

911
class BssConfig(BaseModel):
1012
"""
@@ -102,3 +104,14 @@ def check_secret_or_token_is_present(cls, values): # pylint:disable=no-self-arg
102104
# pylint:disable=line-too-long
103105
"You must provide either client id and secret or a bearer token, but not None of both"
104106
)
107+
108+
@field_validator("bearer_token")
109+
def validate_bearer_token(cls, value):
110+
"""
111+
check that the value is a string
112+
"""
113+
if value is not None and len(value.strip()) > 0:
114+
_token_is_valid = token_is_valid(value)
115+
if not _token_is_valid:
116+
raise ValueError("Invalid bearer token")
117+
return value

‎src/bssclient/client/oauth.py

+22-21
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,27 @@
1515
_logger = logging.getLogger(__name__)
1616

1717

18+
def token_is_valid(token) -> bool:
19+
"""
20+
returns true iff the token expiration date is far enough in the future. By "enough" I mean:
21+
more than 1 minute (because the clients' request using the token shouldn't take longer than that)
22+
"""
23+
try:
24+
decoded_token = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False})
25+
expiration_timestamp = decoded_token.get("exp")
26+
expiration_datetime = datetime.fromtimestamp(expiration_timestamp)
27+
_logger.debug("Token is valid until %s", expiration_datetime.isoformat())
28+
current_datetime = datetime.utcnow()
29+
token_is_valid_one_minute_into_the_future = expiration_datetime > current_datetime + timedelta(minutes=1)
30+
return token_is_valid_one_minute_into_the_future
31+
except jwt.ExpiredSignatureError:
32+
_logger.info("The token is expired", exc_info=True)
33+
return False
34+
except jwt.InvalidTokenError:
35+
_logger.info("The token is invalid", exc_info=True)
36+
return False
37+
38+
1839
class _ValidateTokenMixin: # pylint:disable=too-few-public-methods
1940
"""
2041
Mixin for classes which need to validate tokens
@@ -23,26 +44,6 @@ class _ValidateTokenMixin: # pylint:disable=too-few-public-methods
2344
def __init__(self):
2445
self._session_lock = asyncio.Lock()
2546

26-
def _token_is_valid(self, token) -> bool:
27-
"""
28-
returns true iff the token expiration date is far enough in the future. By "enough" I mean:
29-
more than 1 minute (because the clients' request using the token shouldn't take longer than that)
30-
"""
31-
try:
32-
decoded_token = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False})
33-
expiration_timestamp = decoded_token.get("exp")
34-
expiration_datetime = datetime.fromtimestamp(expiration_timestamp)
35-
_logger.debug("Token is valid until %s", expiration_datetime.isoformat())
36-
current_datetime = datetime.utcnow()
37-
token_is_valid_one_minute_into_the_future = expiration_datetime > current_datetime + timedelta(minutes=1)
38-
return token_is_valid_one_minute_into_the_future
39-
except jwt.ExpiredSignatureError:
40-
_logger.info("The token is expired", exc_info=True)
41-
return False
42-
except jwt.InvalidTokenError:
43-
_logger.info("The token is invalid", exc_info=True)
44-
return False
45-
4647

4748
class _OAuthHttpClient(_ValidateTokenMixin, ABC): # pylint:disable=too-few-public-methods
4849
"""
@@ -86,7 +87,7 @@ async def _get_oauth_token(self) -> str:
8687
if self._token is None:
8788
_logger.info("Initially retrieving a new token")
8889
self._token = await self._get_new_token()
89-
elif not self._token_is_valid(self._token):
90+
elif not token_is_valid(self._token):
9091
_logger.info("Token is not valid anymore, retrieving a new token")
9192
self._token = await self._get_new_token()
9293
else:

‎unittests/test_bss_client.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from yarl import URL
33

44
from bssclient.client.bssclient import BasicAuthBssClient
5-
from bssclient.client.config import BasicAuthBssConfig
5+
from bssclient.client.config import BasicAuthBssConfig, OAuthBssConfig
66

77

88
@pytest.mark.parametrize(
@@ -22,3 +22,8 @@ def test_get_tld(actual_url: URL, expected_tld: URL):
2222
client = BasicAuthBssClient(config)
2323
actual = client.get_top_level_domain()
2424
assert actual == expected_tld
25+
26+
27+
def test_oauth_config():
28+
with pytest.raises(ValueError):
29+
OAuthBssConfig(server_url=URL("https://bss.example.com"), bearer_token="something-which-is-definittly no token")

0 commit comments

Comments
 (0)
Please sign in to comment.