diff --git a/flama/debug/middleware.py b/flama/debug/middleware.py index 0d3876ef..7a80d539 100644 --- a/flama/debug/middleware.py +++ b/flama/debug/middleware.py @@ -82,7 +82,7 @@ def debug_handler( accept = request.headers.get("accept", "") if "text/html" in accept: - return http._ReactTemplateResponse( + return http._FlamaTemplateResponse( "debug/error_500.html", context=dataclasses.asdict(ErrorContext.build(request, exc)), status_code=500 ) return http.PlainTextResponse("Internal Server Error", status_code=500) @@ -100,10 +100,10 @@ class ExceptionMiddleware(BaseErrorMiddleware): def __init__(self, app: types.App, handlers: t.Optional[t.Mapping[t.Any, "Handler"]] = None, debug: bool = False): super().__init__(app, debug) handlers = handlers or {} - self._status_handlers: dict[int, "Handler"] = { + self._status_handlers: dict[int, Handler] = { status_code: handler for status_code, handler in handlers.items() if isinstance(status_code, int) } - self._exception_handlers: dict[type[Exception], "Handler"] = { + self._exception_handlers: dict[type[Exception], Handler] = { **{e: handler for e, handler in handlers.items() if inspect.isclass(e) and issubclass(e, Exception)}, exceptions.NotFoundException: self.not_found_handler, exceptions.MethodNotAllowedException: self.method_not_allowed_handler, @@ -164,7 +164,7 @@ def http_exception_handler( accept = request.headers.get("accept", "") if self.debug and exc.status_code == 404 and "text/html" in accept: - return http._ReactTemplateResponse( + return http._FlamaTemplateResponse( template="debug/error_404.html", context=dataclasses.asdict(NotFoundContext.build(request, scope["app"])), status_code=404, diff --git a/flama/http.py b/flama/http.py index 281351f0..c6ade138 100644 --- a/flama/http.py +++ b/flama/http.py @@ -2,12 +2,14 @@ import datetime import enum import html +import importlib.util import inspect import json import os +import pathlib import typing as t import uuid -from pathlib import Path +import warnings import jinja2 import starlette.requests @@ -29,6 +31,7 @@ "APIResponse", "APIErrorResponse", "HTMLFileResponse", + "HTMLTemplatesEnvironment", "HTMLTemplateResponse", "OpenAPIResponse", ] @@ -62,7 +65,7 @@ async def __call__( # type: ignore[override] class EnhancedJSONEncoder(json.JSONEncoder): def default(self, o): - if isinstance(o, (Path, os.PathLike, uuid.UUID)): + if isinstance(o, (pathlib.Path, os.PathLike, uuid.UUID)): return str(o) if isinstance(o, (bytes, bytearray)): return o.decode("utf-8") @@ -191,21 +194,7 @@ def __init__(self, path: str, *args, **kwargs): super().__init__(content, *args, **kwargs) -class HTMLTemplateResponse(HTMLResponse): - templates = jinja2.Environment( - loader=jinja2.ChoiceLoader( - [jinja2.FileSystemLoader(Path(os.curdir) / "templates"), jinja2.PackageLoader("flama", "templates")] - ) - ) - - def __init__(self, template: str, context: t.Optional[dict[str, t.Any]] = None, *args, **kwargs): - if context is None: - context = {} - - super().__init__(self.templates.get_template(template).render(**context), *args, **kwargs) - - -class _ReactTemplatesEnvironment(jinja2.Environment): +class HTMLTemplatesEnvironment(jinja2.Environment): def __init__(self, *args, **kwargs): super().__init__( *args, @@ -236,12 +225,32 @@ def safe_json(self, value: types.JSONField): return json.dumps(self._escape(value)).replace('"', '\\"') -class _ReactTemplateResponse(HTMLTemplateResponse): - templates = _ReactTemplatesEnvironment( - loader=jinja2.ChoiceLoader( - [jinja2.FileSystemLoader(Path(os.curdir) / "templates"), jinja2.PackageLoader("flama", "templates")] - ) - ) +class HTMLTemplateResponse(HTMLResponse): + templates = HTMLTemplatesEnvironment(loader=jinja2.FileSystemLoader(pathlib.Path(os.curdir) / "templates")) + + def __init__(self, template: str, context: t.Optional[dict[str, t.Any]] = None, *args, **kwargs): + if context is None: + context = {} + + super().__init__(self.templates.get_template(template).render(**context), *args, **kwargs) + + +class _FlamaLoader(jinja2.PackageLoader): + def __init__(self): + spec = importlib.util.find_spec("flama") + if spec is None or spec.origin is None: + raise exceptions.ApplicationError("Flama package not found.") + + templates_path = pathlib.Path(spec.origin).parent.joinpath("templates") + if not templates_path.exists(): + warnings.warn("Templates folder not found in the Flama package") + templates_path.mkdir(exist_ok=True) + + super().__init__(package_name="flama", package_path="templates") + + +class _FlamaTemplateResponse(HTMLTemplateResponse): + templates = HTMLTemplatesEnvironment(loader=_FlamaLoader()) class OpenAPIResponse(starlette.schemas.OpenAPIResponse, Response): diff --git a/flama/schemas/modules.py b/flama/schemas/modules.py index e25adbdf..e803f943 100644 --- a/flama/schemas/modules.py +++ b/flama/schemas/modules.py @@ -88,6 +88,6 @@ def schema_view(self) -> http.OpenAPIResponse: return http.OpenAPIResponse(self.schema) def docs_view(self) -> http.HTMLResponse: - return http._ReactTemplateResponse( + return http._FlamaTemplateResponse( "schemas/docs.html", {"title": self.title, "schema_url": self.schema_path, "docs_url": self.docs_path} ) diff --git a/tests/debug/test_middleware.py b/tests/debug/test_middleware.py index 0e04ca85..f0c7c386 100644 --- a/tests/debug/test_middleware.py +++ b/tests/debug/test_middleware.py @@ -119,12 +119,12 @@ def test_debug_response_html(self, middleware, asgi_scope, asgi_receive, asgi_se with patch( "flama.debug.middleware.dataclasses.asdict", return_value=context_mock ) as dataclasses_dict, patch.object(ErrorContext, "build", return_value=error_context_mock), patch.object( - http._ReactTemplateResponse, "__init__", return_value=None + http._FlamaTemplateResponse, "__init__", return_value=None ) as response_mock: response = middleware.debug_handler(asgi_scope, asgi_receive, asgi_send, exc) assert ErrorContext.build.call_count == 1 assert dataclasses_dict.call_args_list == [call(error_context_mock)] - assert isinstance(response, http._ReactTemplateResponse) + assert isinstance(response, http._FlamaTemplateResponse) assert response_mock.call_args_list == [call("debug/error_500.html", context=context_mock, status_code=500)] def test_debug_response_text(self, middleware, asgi_scope, asgi_receive, asgi_send): @@ -290,7 +290,7 @@ async def test_process_exception( True, b"text/html", exceptions.HTTPException(404, "Foo"), - http._ReactTemplateResponse, + http._FlamaTemplateResponse, {"template": "debug/error_404.html", "context": {}, "status_code": 404}, id="debug_404", ), diff --git a/tests/test_http.py b/tests/test_http.py index b13d5ea4..1a608486 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -170,43 +170,10 @@ def test_init_error(self): assert exc.detail == error_detail -class TestCaseOpenAPIResponse: - @pytest.mark.parametrize( - "test_input,expected,exception", - ( - pytest.param({"foo": "bar"}, {"foo": "bar"}, None, id="success"), - pytest.param("foo", None, AssertionError, id="wrong_content"), - ), - indirect=("exception",), - ) - def test_render(self, test_input, expected, exception): - with exception: - response = http.OpenAPIResponse(test_input) - - assert json.loads(response.body.decode()) == expected - - -class TestCaseHTMLTemplateResponse: - @pytest.mark.parametrize( - ["context"], (pytest.param({"foo": "bar"}, id="context"), pytest.param(None, id="no_context")) - ) - def test_init(self, context): - template_mock = MagicMock() - template_mock.render.return_value = "foo" - environment_mock = MagicMock(spec=jinja2.Environment) - environment_mock.get_template.return_value = template_mock - with patch.object(http.HTMLTemplateResponse, "templates", new=environment_mock), patch.object( - http.HTMLResponse, "__init__", return_value=None - ) as super_mock: - http.HTMLTemplateResponse("foo.html", context) - - assert super_mock.call_args_list == [call("foo")] - - -class TestCaseReactTemplatesEnvironment: +class TestCaseHTMLTemplatesEnvironment: @pytest.fixture def environment(self): - return http._ReactTemplatesEnvironment() + return http.HTMLTemplatesEnvironment() @pytest.mark.parametrize( ["value", "result"], @@ -294,3 +261,36 @@ def test_escape(self, environment, value, result): ) def test_safe_json(self, environment, value, result): assert environment.safe_json(value) == result + + +class TestCaseHTMLTemplateResponse: + @pytest.mark.parametrize( + ["context"], (pytest.param({"foo": "bar"}, id="context"), pytest.param(None, id="no_context")) + ) + def test_init(self, context): + template_mock = MagicMock() + template_mock.render.return_value = "foo" + environment_mock = MagicMock(spec=jinja2.Environment) + environment_mock.get_template.return_value = template_mock + with patch.object(http.HTMLTemplateResponse, "templates", new=environment_mock), patch.object( + http.HTMLResponse, "__init__", return_value=None + ) as super_mock: + http.HTMLTemplateResponse("foo.html", context) + + assert super_mock.call_args_list == [call("foo")] + + +class TestCaseOpenAPIResponse: + @pytest.mark.parametrize( + "test_input,expected,exception", + ( + pytest.param({"foo": "bar"}, {"foo": "bar"}, None, id="success"), + pytest.param("foo", None, AssertionError, id="wrong_content"), + ), + indirect=("exception",), + ) + def test_render(self, test_input, expected, exception): + with exception: + response = http.OpenAPIResponse(test_input) + + assert json.loads(response.body.decode()) == expected