Skip to content

Commit

Permalink
🔨 JWT decoding error raises unauthorized exception
Browse files Browse the repository at this point in the history
  • Loading branch information
migduroli committed Sep 3, 2024
1 parent 8836aa7 commit 5bdfd0b
Show file tree
Hide file tree
Showing 12 changed files with 578 additions and 540 deletions.
2 changes: 1 addition & 1 deletion flama/authentication/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
token = jwt.JWT.decode(encoded_token, self.secret)
except (exceptions.JWTDecodeException, exceptions.JWTValidateException) as e:
raise HTTPException(
status_code=http.HTTPStatus.BAD_REQUEST, detail={"error": e.__class__, "description": str(e)}
status_code=http.HTTPStatus.UNAUTHORIZED, detail={"error": e.__class__, "description": str(e)}
)

return token
22 changes: 13 additions & 9 deletions flama/config/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import functools
import inspect
import json
import logging
import os
Expand All @@ -15,6 +14,7 @@

R = t.TypeVar("R")
Unknown = t.NewType("Unknown", str)
unknown = Unknown("")


class Config:
Expand Down Expand Up @@ -80,7 +80,7 @@ def _get_item_from_environment(self, key: str) -> t.Any:
def _get_item_from_config_file(self, key: str) -> t.Any:
return functools.reduce(lambda x, k: x[k], key.split("."), self.config_file)

def _get_item(self, key: str, default: R = Unknown) -> R:
def _get_item(self, key: str, default: t.Union[R, Unknown] = unknown) -> R:
try:
return self._get_item_from_environment(key)
except KeyError:
Expand All @@ -91,7 +91,7 @@ def _get_item(self, key: str, default: R = Unknown) -> R:
except KeyError:
...

if default is not Unknown:
if default is not unknown:
return t.cast(R, default)

raise KeyError(key)
Expand All @@ -116,27 +116,31 @@ def __call__(self, key: str) -> t.Any:
...

@t.overload
def __call__(self, key: str, *, default: R) -> R:
def __call__(self, key: str, *, default: t.Union[R, Unknown]) -> R:
...

@t.overload
def __call__(self, key: str, *, cast: t.Type[R]) -> R:
...

@t.overload
def __call__(self, key: str, *, default: R, cast: t.Type[R]) -> R:
def __call__(self, key: str, *, default: t.Union[R, Unknown], cast: t.Type[R]) -> R:
...

@t.overload
def __call__(self, key: str, *, cast: t.Callable[[t.Any], R]) -> R:
...

@t.overload
def __call__(self, key: str, *, default: R, cast: t.Callable[[t.Any], R]) -> R:
def __call__(self, key: str, *, default: t.Union[R, Unknown], cast: t.Callable[[t.Any], R]) -> R:
...

def __call__(
self, key: str, default: R = Unknown, cast: t.Optional[t.Union[t.Type[R], t.Callable[[t.Any], R]]] = None
self,
key: str,
*,
default: t.Union[R, Unknown] = unknown,
cast: t.Optional[t.Union[t.Type[R], t.Callable[[t.Any], R]]] = None
) -> R:
"""Get config parameter value.
Expand All @@ -162,8 +166,8 @@ def __call__(
if cast is None:
return value

if dataclasses.is_dataclass(cast) and inspect.isclass(cast):
return self._build_dataclass(data=value, dataclass=cast)
if dataclasses.is_dataclass(cast) and isinstance(cast, type):
return t.cast(R, self._build_dataclass(data=value, dataclass=cast))

try:
return t.cast(t.Callable[[t.Any], R], cast)(value)
Expand Down
2 changes: 1 addition & 1 deletion flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def default(self, o):
return o.__name__
if isinstance(o, BaseException):
return repr(o)
if dataclasses.is_dataclass(o):
if dataclasses.is_dataclass(o) and not isinstance(o, type):
return dataclasses.asdict(o)
return super().default(o)

Expand Down
5 changes: 3 additions & 2 deletions flama/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import inspect
import typing as t

import starlette.middleware.authentication
Expand Down Expand Up @@ -102,7 +101,9 @@ def __call__(self, app: "types.App") -> t.Union["types.MiddlewareClass", "types.
def __repr__(self) -> str:
name = self.__class__.__name__
middleware_name = (
self.middleware.__class__.__name__ if inspect.isclass(self.middleware) else self.middleware.__name__
self.middleware.__class__.__name__
if isinstance(self.middleware, types.MiddlewareClass)
else self.middleware.__name__
)
args = ", ".join([middleware_name] + [f"{key}={value!r}" for key, value in self.kwargs.items()])
return f"{name}({args})"
Expand Down
2 changes: 1 addition & 1 deletion flama/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def _websocket_endpoint(self, scope: types.Scope, receive: types.Receive,
"""
await self.handler(scope, receive, send)

def _build_api_response(self, handler: t.Callable, response: http.Response) -> http.Response:
def _build_api_response(self, handler: t.Callable, response: t.Union[http.Response, None]) -> http.Response:
"""Build an API response given a handler and the current response.
It infers the output schema from the handler signature or just wraps the response in a APIResponse object.
Expand Down
2 changes: 1 addition & 1 deletion flama/schemas/_libs/marshmallow/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def name(self, schema: t.Union[Schema, t.Type[Schema]], *, prefix: t.Optional[st
def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, Field]) -> JSONSchema:
json_schema: t.Dict[str, t.Any]
try:
plugin = MarshmallowPlugin(schema_name_resolver=lambda x: resolve_schema_cls(x).__name__)
plugin = MarshmallowPlugin(schema_name_resolver=lambda x: t.cast(type, resolve_schema_cls(x)).__name__)
APISpec("", "", "3.1.0", [plugin])
converter: "OpenAPIConverter" = t.cast("OpenAPIConverter", plugin.converter)

Expand Down
12 changes: 5 additions & 7 deletions flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from typing_extensions import TypeGuard

t.TypeGuard = TypeGuard # type: ignore
UnionType = [t.Union]
else:
UnionType = [t.Union, type(int | str)]

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover

Expand Down Expand Up @@ -87,12 +90,7 @@ def is_http_valid_type(cls, type_: t.Type) -> bool:

return (
(type_ in types.PARAMETERS_TYPES)
or (
origin in (t.Union, type(int | str))
and len(args) == 2
and args[0] in types.PARAMETERS_TYPES
and args[1] is NoneType
)
or (origin in UnionType and len(args) == 2 and args[0] in types.PARAMETERS_TYPES and args[1] is NoneType)
or (origin is list and args[0] in types.PARAMETERS_TYPES)
)

Expand Down Expand Up @@ -191,7 +189,7 @@ def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]:
if schemas.adapter.is_schema(schema):
return [schemas.adapter.unique_schema(schema)]

if t.get_origin(schema) in (t.Union, type(int | str)):
if t.get_origin(schema) in UnionType:
return [x for field in t.get_args(schema) for x in self.nested_schemas(field)]

if isinstance(schema, (list, tuple, set)):
Expand Down
4 changes: 2 additions & 2 deletions flama/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def add_link(self, name: str, item: t.Union[Link, Reference]):
def add_callback(self, name: str, item: t.Union[Callback, Reference]):
self.spec.components.callbacks[name] = item

def asdict(self, obj: t.Any = None) -> t.Any:
def asdict(self, obj=None) -> t.Any:
if obj is None:
return self.asdict(dataclasses.asdict(self.spec))

Expand All @@ -322,7 +322,7 @@ def asdict(self, obj: t.Any = None) -> t.Any:
if isinstance(obj, dict):
return {{"ref": "$ref", "in_": "in"}.get(k, k): self.asdict(v) for k, v in obj.items() if v is not None}

if dataclasses.is_dataclass(obj):
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
return self.asdict(dataclasses.asdict(obj))

return obj
2 changes: 1 addition & 1 deletion flama/types/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
Body = t.NewType("Body", bytes)
PathParams = t.NewType("PathParams", t.Dict[str, str])
PathParam = t.NewType("PathParam", str)
RequestData = t.NewType("RequestData", t.Any)
RequestData = t.NewType("RequestData", t.Dict[str, t.Any])
Headers = starlette.datastructures.Headers
MutableHeaders = starlette.datastructures.MutableHeaders
Cookies = t.NewType("Cookies", t.Dict[str, t.Dict[str, str]])
Expand Down
Loading

0 comments on commit 5bdfd0b

Please sign in to comment.