Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve JWTComponent #151

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions flama/authentication/components.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
import http
import logging
import typing as t

from flama import Component
from flama.authentication import exceptions, jwt
from flama.authentication import exceptions, jwt, types
from flama.exceptions import HTTPException
from flama.types import Headers
from flama.types.http import Cookies

logger = logging.getLogger(__name__)

__all__ = ["JWTComponent"]
__all__ = ["AccessTokenComponent", "RefreshTokenComponent"]


class JWTComponent(Component):
def __init__(
self,
secret: bytes,
*,
header_key: str = "Authorization",
header_prefix: str = "Bearer",
cookie_key: str = "flama_authentication",
):
class BaseTokenComponent(Component):
def __init__(self, secret: bytes, *, header_key: str, header_prefix: str, cookie_key: str):
self.secret = secret
self.header_key = header_key
self.header_prefix = header_prefix
self.cookie_key = cookie_key

def _token_from_cookies(self, cookies: Cookies) -> bytes:
print(f"ERROR: {cookies}")
try:
token = cookies[self.cookie_key]["value"]
except KeyError:
Expand All @@ -36,6 +31,7 @@ def _token_from_cookies(self, cookies: Cookies) -> bytes:
return token.encode()

def _token_from_header(self, headers: Headers) -> bytes:
print(f"ERROR: {headers}")
try:
header_prefix, token = headers[self.header_key].split()
except KeyError:
Expand All @@ -55,7 +51,7 @@ def _token_from_header(self, headers: Headers) -> bytes:

return token.encode()

def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
def _resolve_token(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
try:
try:
encoded_token = self._token_from_header(headers)
Expand All @@ -76,3 +72,33 @@ def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
)

return token


class AccessTokenComponent(BaseTokenComponent):
def __init__(
self,
secret: bytes,
*,
header_prefix: str = "Bearer",
header_key: str = "access_token",
cookie_key: str = "access_token",
):
super().__init__(secret, header_prefix=header_prefix, header_key=header_key, cookie_key=cookie_key)

def resolve(self, headers: Headers, cookies: Cookies) -> types.AccessToken:
return t.cast(types.AccessToken, self._resolve_token(headers, cookies))


class RefreshTokenComponent(BaseTokenComponent):
def __init__(
self,
secret: bytes,
*,
header_prefix: str = "Bearer",
header_key: str = "refresh_token",
cookie_key: str = "refresh_token",
):
super().__init__(secret, header_prefix=header_prefix, header_key=header_key, cookie_key=cookie_key)

def resolve(self, headers: Headers, cookies: Cookies) -> types.RefreshToken:
return t.cast(types.RefreshToken, self._resolve_token(headers, cookies))
6 changes: 3 additions & 3 deletions flama/authentication/jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from flama.authentication.jwt.jwt import JWT

__all__ = ["JWT"]
from flama.authentication.components import * # noqa
from flama.authentication.jwt.jwt import JWT # noqa
from flama.authentication.types import * # noqa
6 changes: 4 additions & 2 deletions flama/authentication/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import typing as t

from flama.authentication.jwt.jwt import JWT
from flama import authentication
from flama.exceptions import HTTPException
from flama.http import APIErrorResponse, Request

Expand Down Expand Up @@ -44,7 +44,9 @@ async def _get_response(self, scope: "types.Scope", receive: "types.Receive") ->
return self.app

try:
token: JWT = await app.injector.resolve(JWT).value({"request": Request(scope, receive=receive)})
token: authentication.AccessToken = await app.injector.resolve(authentication.AccessToken).value(
{"request": Request(scope, receive=receive)}
)
except HTTPException as e:
logger.debug("JWT error: %s", e.detail)
return APIErrorResponse(status_code=e.status_code, detail=e.detail)
Expand Down
8 changes: 8 additions & 0 deletions flama/authentication/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typing as t

from flama.authentication.jwt import JWT

__all__ = ["AccessToken", "RefreshToken"]

AccessToken = t.NewType("AccessToken", JWT)
RefreshToken = t.NewType("RefreshToken", JWT)
136 changes: 107 additions & 29 deletions tests/authentication/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,135 @@

import pytest

from flama import Flama
from flama.authentication.components import JWTComponent
from flama.authentication.jwt.jwt import JWT
from flama import Flama, authentication

TOKEN = (
b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyIn0sICJpYXQiOiAwfQ==.J3zdedMZSFNOimstjJat0V"
b"28rM_b1UU62XCp9dg_5kg="
)


class TestCaseJWTComponent:
@pytest.fixture(scope="function")
def secret(self):
return uuid.UUID(int=0)
@pytest.fixture(scope="function")
def secret():
return uuid.UUID(int=0)

@pytest.fixture(scope="function")
def app(self, secret):
return Flama(
schema=None,
docs=None,
components=[
JWTComponent(
secret=secret.bytes,
header_key="Authorization",
header_prefix="Bearer",
cookie_key="flama_authentication",
)
],
)

@pytest.fixture(scope="function")
def app(secret):
return Flama(
schema=None,
docs=None,
components=[
authentication.AccessTokenComponent(secret=secret.bytes),
authentication.RefreshTokenComponent(secret=secret.bytes),
],
)


class TestCaseAccessTokenComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
@app.get("/")
def access_token(token: authentication.AccessToken):
return token.asdict()

@pytest.mark.parametrize(
["params", "status_code", "result"],
(
pytest.param(
{"headers": {"access_token": f"Bearer {TOKEN.decode()}"}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="headers",
),
pytest.param(
{"headers": {"access_token": "token"}},
400,
{
"detail": {
"description": "Authentication header must be 'access_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
"status_code": 400,
},
id="header_wrong_format",
),
pytest.param(
{"headers": {"access_token": "Foo token"}},
400,
{
"detail": {
"description": "Authentication header must be 'access_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
"status_code": 400,
},
id="header_wrong_prefix",
),
pytest.param(
{"cookies": {"access_token": TOKEN.decode()}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="cookies",
),
pytest.param(
{},
401,
{"detail": "Unauthorized", "error": "HTTPException", "status_code": 401},
id="unauthorized",
),
pytest.param(
{
"cookies": {
"access_token": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"n0sICJpYXQiOiAwfQ==.0000",
}
},
401,
{
"detail": {
"description": "Signature verification failed for token 'eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVC"
"J9.eyJkYXRhIjogeyJmb28iOiAiYmFyIn0sICJpYXQiOiAwfQ==.0000'",
"error": "JWTValidateException",
},
"error": "HTTPException",
"status_code": 401,
},
id="invalid_token",
),
),
)
async def test_injection(self, client, status_code, result, params):
response = await client.request("get", "/", **params)

assert response.status_code == status_code
assert response.json() == result


class TestCaseRefreshTokenComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
@app.get("/")
def jwt(token: JWT):
def refresh_token(token: authentication.RefreshToken):
return token.asdict()

@pytest.mark.parametrize(
["params", "status_code", "result"],
(
pytest.param(
{"headers": {"Authorization": f"Bearer {TOKEN.decode()}"}},
{"headers": {"refresh_token": f"Bearer {TOKEN.decode()}"}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="headers",
),
pytest.param(
{"headers": {"Authorization": "token"}},
{"headers": {"refresh_token": "token"}},
400,
{
"detail": {
"description": "Authentication header must be 'Authorization: Bearer <token>'",
"description": "Authentication header must be 'refresh_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
Expand All @@ -61,11 +139,11 @@ def jwt(token: JWT):
id="header_wrong_format",
),
pytest.param(
{"headers": {"Authorization": "Foo token"}},
{"headers": {"refresh_token": "Foo token"}},
400,
{
"detail": {
"description": "Authentication header must be 'Authorization: Bearer <token>'",
"description": "Authentication header must be 'refresh_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
Expand All @@ -74,7 +152,7 @@ def jwt(token: JWT):
id="header_wrong_prefix",
),
pytest.param(
{"cookies": {"flama_authentication": TOKEN.decode()}},
{"cookies": {"refresh_token": TOKEN.decode()}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="cookies",
Expand All @@ -88,7 +166,7 @@ def jwt(token: JWT):
pytest.param(
{
"cookies": {
"flama_authentication": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"refresh_token": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"n0sICJpYXQiOiAwfQ==.0000",
}
},
Expand Down
8 changes: 4 additions & 4 deletions tests/authentication/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from flama import Flama
from flama.authentication.components import JWTComponent
from flama.authentication.components import AccessTokenComponent
from flama.authentication.middlewares import AuthenticationMiddleware
from flama.middleware import Middleware

Expand All @@ -27,7 +27,7 @@ def app(self, secret):
return Flama(
schema=None,
docs=None,
components=[JWTComponent(secret=secret.bytes)],
components=[AccessTokenComponent(secret=secret.bytes)],
middleware=[Middleware(AuthenticationMiddleware)],
)

Expand All @@ -47,7 +47,7 @@ def headers(self, request):
return None

try:
return {"Authorization": f"Bearer {TOKENS[request.param].decode()}"}
return {"access_token": f"Bearer {TOKENS[request.param].decode()}"}
except KeyError:
raise ValueError(f"Invalid token {request.param}")

Expand All @@ -57,7 +57,7 @@ def cookies(self, request):
return None

try:
return {"flama_authentication": TOKENS[request.param].decode()}
return {"access_token": TOKENS[request.param].decode()}
except KeyError:
raise ValueError(f"Invalid token {request.param}")

Expand Down
Loading