diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 17d6b3e41..4e9f1b70a 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: "actions/checkout@v4" diff --git a/requirements.txt b/requirements.txt index 01c2016c1..a6078076f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ types-PyYAML==6.0.12.20250402 types-dataclasses==0.6.6 pytest==8.3.5 trio==0.30.0 +exceptiongroup; python_version<'3.11' # Documentation black==25.1.0 diff --git a/starlette/testclient.py b/starlette/testclient.py index d54025e52..ba10a72e7 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -68,6 +68,11 @@ class _AsyncBackend(typing.TypedDict): backend_options: dict[str, typing.Any] +class _AsyncUpgrade(Exception): + def __init__(self, session: AsyncWebSocketTestSession) -> None: + self.session = session + + class _Upgrade(Exception): def __init__(self, session: WebSocketTestSession) -> None: self.session = session @@ -729,3 +734,639 @@ async def receive() -> typing.Any: ) if message["type"] == "lifespan.shutdown.failed": await receive() + + +class AsyncWebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.extra_headers = None + + async def __aenter__(self) -> AsyncWebSocketTestSession: + async with contextlib.AsyncExitStack() as stack: + task_group = await stack.enter_async_context(anyio.create_task_group()) + self.done = anyio.Event() + + async def run(*, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + await self._run(task_status=task_status) + self.done.set() + + await task_group.start(run) + stack.push_async_callback(self.done.wait) + stack.callback(task_group.cancel_scope.cancel) + await self.send({"type": "websocket.connect"}) + message = await self.receive() + await self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.push_async_callback(self.aclose, 1000) + self.exit_stack = stack.pop_all() + return self + + async def __aexit__(self, *args: typing.Any) -> bool | None: + return await self.exit_stack.__aexit__(*args) + + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + send_tx, send_rx = send + receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + receive_tx, receive_rx = receive + with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: + self._receive_tx = receive_tx + self._send_rx = send_rx + task_status.started(cs) + await self.app(self.scope, receive_rx.receive, send_tx.send) + + # wait for cs.cancel to be called before closing streams + await anyio.sleep_forever() + + async def _raise_on_close(self, message: Message) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) + elif message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = await self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) + + async def send(self, message: Message) -> None: + await self._receive_tx.send(message) + + async def send_text(self, data: str) -> None: + await self.send({"type": "websocket.receive", "text": data}) + + async def send_bytes(self, data: bytes) -> None: + await self.send({"type": "websocket.receive", "bytes": data}) + + async def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + await self.send({"type": "websocket.receive", "text": text}) + else: + await self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + async def aclose(self, code: int = 1000, reason: str | None = None) -> None: + await self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) + + async def receive(self) -> Message: + return await self._send_rx.receive() + + async def receive_text(self) -> str: + message = await self.receive() + await self._raise_on_close(message) + return typing.cast(str, message["text"]) + + async def receive_bytes(self) -> bytes: + message = await self.receive() + await self._raise_on_close(message) + return typing.cast(bytes, message["bytes"]) + + async def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: + message = await self.receive() + await self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + +class _AsyncTestClientTransport(httpx.AsyncBaseTransport): + def __init__( + self, + app: ASGI3App, + raise_server_exceptions: bool = True, + root_path: str = "", + *, + client: tuple[str, int], + app_state: dict[str, typing.Any], + ) -> None: + self.app = app + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + self.app_state = app_state + self.client = client + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = request.url.netloc.decode(encoding="ascii") + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers: list[tuple[bytes, bytes]] = [] + elif port == default_port: # pragma: no cover + headers = [(b"host", host.encode())] + else: # pragma: no cover + headers = [(b"host", (f"{host}:{port}").encode())] + + # Include other request headers. + headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] + + scope: dict[str, typing.Any] + + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols: typing.Sequence[str] = [] + else: + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "subprotocols": subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + session = AsyncWebSocketTestSession(self.app, scope) + raise _AsyncUpgrade(session) + + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "extensions": {"http.response.debug": {}}, + "state": self.app_state.copy(), + } + + request_complete = False + response_started = False + response_complete: anyio.Event + raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} + template = None + context = None + + async def receive() -> Message: + nonlocal request_complete + + if request_complete: + if not response_complete.is_set(): + await response_complete.wait() + return {"type": "http.disconnect"} + + body = request.read() + if isinstance(body, str): + body_bytes: bytes = body.encode("utf-8") # pragma: no cover + elif body is None: + body_bytes = b"" # pragma: no cover + elif isinstance(body, GeneratorType): + try: # pragma: no cover + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: # pragma: no cover + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, template, context + + if message["type"] == "http.response.start": + assert not response_started, 'Received multiple "http.response.start" messages.' + raw_kwargs["status_code"] = message["status"] + raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] + response_started = True + elif message["type"] == "http.response.body": + assert response_started, 'Received "http.response.body" without "http.response.start".' + assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["stream"].write(body) + if not more_body: + raw_kwargs["stream"].seek(0) + response_complete.set() + elif message["type"] == "http.response.debug": + template = message["info"]["template"] + context = message["info"]["context"] + + try: + response_complete = anyio.Event() + await self.app(scope, receive, send) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "status_code": 500, + "headers": [], + "stream": io.BytesIO(), + } + + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) + if template is not None: + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] + return response + + +class AsyncTestClient(httpx.AsyncClient): + __test__ = False + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> None: + if _is_asgi3(app): + asgi_app = app + else: + app = typing.cast(ASGI2App, app) # type: ignore[assignment] + asgi_app = _WrapASGI2(app) # type: ignore[arg-type] + self.app = asgi_app + self.app_state: dict[str, typing.Any] = {} + transport = _AsyncTestClientTransport( + self.app, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + app_state=self.app_state, + client=client, + ) + if headers is None: + headers = {} + headers.setdefault("user-agent", "testclient") + super().__init__( + base_url=base_url, + headers=headers, + transport=transport, + follow_redirects=follow_redirects, + cookies=cookies, + ) + + async def request( # type: ignore[override] + self, + method: str, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + if timeout is not httpx.USE_CLIENT_DEFAULT: + warnings.warn( + "You should not use the 'timeout' argument with the TestClient. " + "See https://github.com/encode/starlette/issues/1108 for more information.", + DeprecationWarning, + ) + url = self._merge_url(url) + return await super().request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().head( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().post( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().put( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().patch( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def websocket_connect( + self, + url: str, + subprotocols: typing.Sequence[str] | None = None, + **kwargs: typing.Any, + ) -> AsyncWebSocketTestSession: + url = urljoin("ws://testserver", url) + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + kwargs["headers"] = headers + try: + await super().request("GET", url, **kwargs) + except _AsyncUpgrade as exc: + session = exc.session + else: + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + return session + + async def __aenter__(self) -> AsyncTestClient: + async with contextlib.AsyncExitStack() as stack: + task_group = await stack.enter_async_context(anyio.create_task_group()) + send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( + anyio.create_memory_object_stream(math.inf) + ) + receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( + anyio.create_memory_object_stream(math.inf) + ) + for channel in (*send, *receive): + stack.push_async_callback(channel.aclose) + self.stream_send = StapledObjectStream(*send) + self.stream_receive = StapledObjectStream(*receive) + self.task_done = anyio.Event() + + async def lifespan() -> None: + await self.lifespan() + self.task_done.set() + + task_group.start_soon(lifespan) + await self.wait_startup() + + @stack.push_async_callback + async def wait_shutdown() -> None: + await self.wait_shutdown() + + self.exit_stack = stack.pop_all() + + return self + + async def __aexit__(self, *args: typing.Any) -> None: + await self.exit_stack.aclose() + + async def lifespan(self) -> None: + scope = {"type": "lifespan", "state": self.app_state} + try: + await self.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + try: + await self.stream_send.send(None) + except anyio.ClosedResourceError: + pass + + async def wait_startup(self) -> None: + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + await self.task_done.wait() + return message + + message = await receive() + assert message["type"] in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ) + if message["type"] == "lifespan.startup.failed": + await receive() + + async def wait_shutdown(self) -> None: + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + await self.task_done.wait() + return message + + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/tests/conftest.py b/tests/conftest.py index 4db3ae018..87412e66e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,15 @@ import pytest -from starlette.testclient import TestClient -from tests.types import TestClientFactory +from starlette.testclient import AsyncTestClient, TestClient +from tests.types import AsyncTestClientFactory, TestClientFactory + + +@pytest.fixture +def async_test_client_factory() -> AsyncTestClientFactory: + return functools.partial( + AsyncTestClient, + ) @pytest.fixture diff --git a/tests/test_asynctestclient.py b/tests/test_asynctestclient.py new file mode 100644 index 000000000..e5d7b210b --- /dev/null +++ b/tests/test_asynctestclient.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import itertools +import sys +from asyncio import Task, current_task as asyncio_current_task +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel +import pytest +import sniffio +import trio.lowlevel + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.routing import Route +from starlette.testclient import ASGIInstance, AsyncTestClient +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketDisconnect +from tests.types import AsyncTestClientFactory + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +def mock_service_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"mock": "example"}) + + +mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)]) + + +def current_task() -> Task[Any] | trio.lowlevel.Task: + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() + + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover + + +def startup() -> None: + raise RuntimeError() + + +@pytest.mark.anyio +async def test_use_testclient_in_endpoint(async_test_client_factory: AsyncTestClientFactory) -> None: + """ + We should be able to use the test client within applications. + + This is useful if we need to mock out other services, + during tests or in development. + """ + + async def homepage(request: Request) -> JSONResponse: + client = async_test_client_factory(mock_service) + response = await client.get("/") + return JSONResponse(response.json()) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.json() == {"mock": "example"} + + +def test_testclient_headers_behavior() -> None: + """ + We should be able to use the test client with user defined headers. + + This is useful if we need to set custom headers for authentication + during tests or in development. + """ + + client = AsyncTestClient(mock_service) + assert client.headers.get("user-agent") == "testclient" + + client = AsyncTestClient(mock_service, headers={"user-agent": "non-default-agent"}) + assert client.headers.get("user-agent") == "non-default-agent" + + client = AsyncTestClient(mock_service, headers={"Authentication": "Bearer 123"}) + assert client.headers.get("user-agent") == "testclient" + assert client.headers.get("Authentication") == "Bearer 123" + + +async def test_use_testclient_as_contextmanager( + async_test_client_factory: AsyncTestClientFactory, anyio_backend_name: str +) -> None: + """ + This test asserts a number of properties that are important for an + app level task_group + """ + counter = itertools.count() + identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") + + def get_identity() -> int: + try: + return identity_runvar.get() + except LookupError: + token = next(counter) + identity_runvar.set(token) + return token + + startup_task = object() + startup_loop = None + shutdown_task = object() + shutdown_loop = None + + @asynccontextmanager + async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]: + nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop + + startup_task = current_task() + startup_loop = get_identity() + async with anyio.create_task_group(): + yield + shutdown_task = current_task() + shutdown_loop = get_identity() + + async def loop_id(request: Request) -> JSONResponse: + return JSONResponse(get_identity()) + + app = Starlette( + lifespan=lifespan_context, + routes=[Route("/loop_id", endpoint=loop_id)], + ) + + client = async_test_client_factory(app) + + async with client: + # within a TestClient context every async request runs in the same thread + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + # that thread is also the same as the lifespan thread + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert startup_task is shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # event loops in new threads + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + first_task = startup_task + + async with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events still run in the same task, with the context but... + assert startup_task is shutdown_task + + # ... the second TestClient context creates a new lifespan task. + assert first_task is not startup_task + + +@pytest.mark.anyio +async def test_error_on_startup(async_test_client_factory: AsyncTestClientFactory) -> None: + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): + startup_error_app = Starlette(on_startup=[startup]) + + with pytest.raises(ExceptionGroup) as excinfo: + async with async_test_client_factory(startup_error_app): + pass # pragma: no cover + + assert excinfo.group_contains(RuntimeError) + + +@pytest.mark.anyio +async def test_exception_in_middleware(async_test_client_factory: AsyncTestClientFactory) -> None: + class MiddlewareException(Exception): + pass + + class BrokenMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + raise MiddlewareException() + + broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) + + with pytest.raises(ExceptionGroup) as excinfo: + async with async_test_client_factory(broken_middleware): + pass # pragma: no cover + + assert excinfo.group_contains(MiddlewareException) + + +@pytest.mark.anyio +async def test_testclient_asgi2(async_test_client_factory: AsyncTestClientFactory) -> None: + def app(scope: Scope) -> ASGIInstance: + async def inner(receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + return inner + + client = async_test_client_factory(app) # type: ignore + response = await client.get("/") + assert response.text == "Hello, world!" + + +@pytest.mark.anyio +async def test_testclient_asgi3(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.text == "Hello, world!" + + +@pytest.mark.anyio +async def test_websocket_blocking_receive(async_test_client_factory: AsyncTestClientFactory) -> None: + def app(scope: Scope) -> ASGIInstance: + async def respond(websocket: WebSocket) -> None: + await websocket.send_json({"message": "test"}) + + async def asgi(receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async with anyio.create_task_group() as task_group: + task_group.start_soon(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass + + return asgi + + client = async_test_client_factory(app) # type: ignore + async with await client.websocket_connect("/") as websocket: + data = await websocket.receive_json() + assert data == {"message": "test"} + + +@pytest.mark.anyio +async def test_websocket_not_block_on_close(async_test_client_factory: AsyncTestClientFactory) -> None: + cancelled = False + + def app(scope: Scope) -> ASGIInstance: + async def asgi(receive: Receive, send: Send) -> None: + nonlocal cancelled + try: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await anyio.sleep_forever() + except anyio.get_cancelled_exc_class(): + cancelled = True + raise + + return asgi + + client = async_test_client_factory(app) # type: ignore + async with await client.websocket_connect("/"): + ... + assert cancelled + + +@pytest.mark.anyio +async def test_client(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + client = scope.get("client") + assert client is not None + host, port = client + response = JSONResponse({"host": host, "port": port}) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.json() == {"host": "testclient", "port": 50000} + + +@pytest.mark.anyio +async def test_client_custom_client(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + client = scope.get("client") + assert client is not None + host, port = client + response = JSONResponse({"host": host, "port": port}) + await response(scope, receive, send) + + client = async_test_client_factory(app, client=("192.168.0.1", 3000)) + response = await client.get("/") + assert response.json() == {"host": "192.168.0.1", "port": 3000} + + +@pytest.mark.anyio +@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) +async def test_query_params(async_test_client_factory: AsyncTestClientFactory, param: str) -> None: + def homepage(request: Request) -> Response: + return Response(request.query_params["param"]) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + client = async_test_client_factory(app) + response = await client.get("/", params={"param": param}) + assert response.text == param + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "domain, ok", + [ + pytest.param( + "testserver", + True, + marks=[ + pytest.mark.xfail( + sys.version_info < (3, 11), + reason="Fails due to domain handling in http.cookiejar module (see #2152)", + ), + ], + ), + ("testserver.local", True), + ("localhost", False), + ("example.com", False), + ], +) +async def test_domain_restricted_cookies( + async_test_client_factory: AsyncTestClientFactory, domain: str, ok: bool +) -> None: + """ + Test that test client discards domain restricted cookies which do not match the + base_url of the testclient (`http://testserver` by default). + + The domain `testserver.local` works because the Python http.cookiejar module derives + the "effective domain" by appending `.local` to non-dotted request domains + in accordance with RFC 2965. + """ + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response("Hello, world!", media_type="text/plain") + response.set_cookie( + "mycookie", + "myvalue", + path="/", + domain=domain, + ) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/") + cookie_set = len(response.cookies) == 1 + assert cookie_set == ok + + +@pytest.mark.anyio +async def test_forward_follow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + if "/ok" in scope["path"]: + response = Response("ok") + else: + response = RedirectResponse("/ok") + await response(scope, receive, send) + + client = async_test_client_factory(app, follow_redirects=True) + response = await client.get("/") + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_forward_nofollow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = RedirectResponse("/ok") + await response(scope, receive, send) + + client = async_test_client_factory(app, follow_redirects=False) + response = await client.get("/") + assert response.status_code == 307 + + +@pytest.mark.anyio +async def test_with_duplicate_headers(async_test_client_factory: AsyncTestClientFactory) -> None: + def homepage(request: Request) -> JSONResponse: + return JSONResponse({"x-token": request.headers.getlist("x-token")}) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + client = async_test_client_factory(app) + response = await client.get("/", headers=[("x-token", "foo"), ("x-token", "bar")]) + assert response.json() == {"x-token": ["foo", "bar"]} + + +@pytest.mark.anyio +async def test_merge_url(async_test_client_factory: AsyncTestClientFactory) -> None: + def homepage(request: Request) -> Response: + return Response(request.url.path) + + app = Starlette(routes=[Route("/api/v1/bar", endpoint=homepage)]) + client = async_test_client_factory(app, base_url="http://testserver/api/v1/") + response = await client.get("/bar") + assert response.text == "/api/v1/bar" + + +@pytest.mark.anyio +async def test_raw_path_with_querystring(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response(scope.get("raw_path")) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/hello-world", params={"foo": "bar"}) + assert response.content == b"/hello-world" + + +@pytest.mark.anyio +async def test_websocket_raw_path_without_params(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + raw_path = scope.get("raw_path") + assert raw_path is not None + await websocket.send_bytes(raw_path) + + client = async_test_client_factory(app) + async with await client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket: + data = await websocket.receive_bytes() + assert data == b"/hello-world" + + +@pytest.mark.anyio +async def test_timeout_deprecation() -> None: + with pytest.deprecated_call(match="You should not use the 'timeout' argument with the TestClient."): + client = AsyncTestClient(mock_service) + await client.get("/", timeout=1) diff --git a/tests/types.py b/tests/types.py index e4769d308..f0a3b44ed 100644 --- a/tests/types.py +++ b/tests/types.py @@ -4,11 +4,24 @@ import httpx -from starlette.testclient import TestClient +from starlette.testclient import AsyncTestClient, TestClient from starlette.types import ASGIApp if TYPE_CHECKING: + class AsyncTestClientFactory(Protocol): # pragma: no cover + def __call__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> AsyncTestClient: ... + class TestClientFactory(Protocol): # pragma: no cover def __call__( self, @@ -23,5 +36,8 @@ def __call__( ) -> TestClient: ... else: # pragma: no cover + class AsyncTestClientFactory: + __test__ = False + class TestClientFactory: __test__ = False