Skip to content

Commit

Permalink
🐛 Inherited response call methods
Browse files Browse the repository at this point in the history
  • Loading branch information
migduroli committed Oct 10, 2024
1 parent 0ba8173 commit 28a98a1
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import starlette.responses
import starlette.schemas

import flama.types
from flama import schemas, types
from flama.exceptions import HTTPException, SerializationError
from flama import exceptions, schemas, types

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover

Expand Down Expand Up @@ -59,11 +57,17 @@ async def __call__( # type: ignore[override]


class HTMLResponse(starlette.responses.HTMLResponse, Response):
__call__ = Response.__call__ # type: ignore[assignment]
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]


class PlainTextResponse(starlette.responses.PlainTextResponse, Response):
__call__ = Response.__call__ # type: ignore[assignment]
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]


class EnhancedJSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -107,6 +111,11 @@ def default(self, o):


class JSONResponse(starlette.responses.JSONResponse, Response):
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]

def render(self, content: t.Any) -> bytes:
if isinstance(content, types.Schema):
content = dict(content)
Expand All @@ -117,15 +126,24 @@ def render(self, content: t.Any) -> bytes:


class RedirectResponse(starlette.responses.RedirectResponse, Response):
__call__ = Response.__call__ # type: ignore[assignment]
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]


class StreamingResponse(starlette.responses.StreamingResponse, Response):
__call__ = Response.__call__ # type: ignore[assignment]
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]


class FileResponse(starlette.responses.FileResponse, Response):
__call__ = Response.__call__ # type: ignore[assignment]
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]


class APIResponse(JSONResponse):
Expand All @@ -140,7 +158,7 @@ def render(self, content: t.Any):
try:
content = schemas.Schema.from_type(self.schema).dump(content)
except schemas.SchemaValidationError as e:
raise SerializationError(status_code=500, detail=e.errors)
raise exceptions.SerializationError(status_code=500, detail=e.errors)

if not content:
return b""
Expand Down Expand Up @@ -179,7 +197,7 @@ def __init__(self, path: str, *args, **kwargs):
with open(path) as f:
content = f.read()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise exceptions.HTTPException(status_code=500, detail=str(e))

super().__init__(content, *args, **kwargs)

Expand Down Expand Up @@ -213,7 +231,7 @@ def __init__(self, *args, **kwargs):

self.filters["safe_json"] = self.safe_json

def _escape(self, value: flama.types.JSONField) -> flama.types.JSONField:
def _escape(self, value: types.JSONField) -> types.JSONField:
if isinstance(value, (list, tuple)):
return [self._escape(x) for x in value]

Expand All @@ -225,7 +243,7 @@ def _escape(self, value: flama.types.JSONField) -> flama.types.JSONField:

return value

def safe_json(self, value: flama.types.JSONField):
def safe_json(self, value: types.JSONField):
return json.dumps(self._escape(value)).replace('"', '\\"')


Expand All @@ -237,7 +255,12 @@ class _ReactTemplateResponse(HTMLTemplateResponse):
)


class OpenAPIResponse(Response, starlette.schemas.OpenAPIResponse):
class OpenAPIResponse(starlette.schemas.OpenAPIResponse, Response):
async def __call__( # type: ignore[override]
self, scope: types.Scope, receive: types.Receive, send: types.Send
) -> None:
await super().__call__(scope, receive, send) # type: ignore[arg-type]

def render(self, content: t.Any) -> bytes:
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."

Expand Down

0 comments on commit 28a98a1

Please sign in to comment.