diff --git a/flama/http.py b/flama/http.py index 649dbe47..79705154 100644 --- a/flama/http.py +++ b/flama/http.py @@ -15,9 +15,7 @@ import starlette.responses import starlette.schemas -import flama.types -from flama import schemas, types -from flama.exceptions import HTTPException, SerializationError +from flama import exceptions, schemas, types if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover @@ -59,11 +57,17 @@ async def __call__( # type: ignore[override] class HTMLResponse(starlette.responses.HTMLResponse, Response): - __call__ = Response.__call__ # type: ignore[assignment] + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] class PlainTextResponse(starlette.responses.PlainTextResponse, Response): - __call__ = Response.__call__ # type: ignore[assignment] + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] class EnhancedJSONEncoder(json.JSONEncoder): @@ -107,6 +111,11 @@ def default(self, o): class JSONResponse(starlette.responses.JSONResponse, Response): + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] + def render(self, content: t.Any) -> bytes: if isinstance(content, types.Schema): content = dict(content) @@ -117,15 +126,24 @@ def render(self, content: t.Any) -> bytes: class RedirectResponse(starlette.responses.RedirectResponse, Response): - __call__ = Response.__call__ # type: ignore[assignment] + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] class StreamingResponse(starlette.responses.StreamingResponse, Response): - __call__ = Response.__call__ # type: ignore[assignment] + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] class FileResponse(starlette.responses.FileResponse, Response): - __call__ = Response.__call__ # type: ignore[assignment] + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] class APIResponse(JSONResponse): @@ -140,7 +158,7 @@ def render(self, content: t.Any): try: content = schemas.Schema.from_type(self.schema).dump(content) except schemas.SchemaValidationError as e: - raise SerializationError(status_code=500, detail=e.errors) + raise exceptions.SerializationError(status_code=500, detail=e.errors) if not content: return b"" @@ -179,7 +197,7 @@ def __init__(self, path: str, *args, **kwargs): with open(path) as f: content = f.read() except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise exceptions.HTTPException(status_code=500, detail=str(e)) super().__init__(content, *args, **kwargs) @@ -213,7 +231,7 @@ def __init__(self, *args, **kwargs): self.filters["safe_json"] = self.safe_json - def _escape(self, value: flama.types.JSONField) -> flama.types.JSONField: + def _escape(self, value: types.JSONField) -> types.JSONField: if isinstance(value, (list, tuple)): return [self._escape(x) for x in value] @@ -225,7 +243,7 @@ def _escape(self, value: flama.types.JSONField) -> flama.types.JSONField: return value - def safe_json(self, value: flama.types.JSONField): + def safe_json(self, value: types.JSONField): return json.dumps(self._escape(value)).replace('"', '\\"') @@ -237,7 +255,12 @@ class _ReactTemplateResponse(HTMLTemplateResponse): ) -class OpenAPIResponse(Response, starlette.schemas.OpenAPIResponse): +class OpenAPIResponse(starlette.schemas.OpenAPIResponse, Response): + async def __call__( # type: ignore[override] + self, scope: types.Scope, receive: types.Receive, send: types.Send + ) -> None: + await super().__call__(scope, receive, send) # type: ignore[arg-type] + def render(self, content: t.Any) -> bytes: assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."