Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use less private imports on the testclient module #2709

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 81 additions & 75 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
import warnings
from concurrent.futures import Future
from functools import cached_property
from http.cookiejar import CookieJar
from types import GeneratorType
from urllib.parse import unquote, urljoin

@@ -25,9 +26,9 @@
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
from typing import TypeAlias, TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard
from typing_extensions import TypeAlias, TypeGuard

try:
import httpx
@@ -37,16 +38,29 @@
"You can install this with:\n"
" $ pip install httpx\n"
)


Auth: TypeAlias = "httpx.Auth | tuple[str | bytes, str | bytes]"
QueryParams: TypeAlias = "httpx.QueryParams | typing.Mapping[str, str] | None"
Cookies: TypeAlias = "httpx.Cookies | CookieJar | dict[str, str] | list[tuple[str, str]] | None"
URL: TypeAlias = "httpx.URL | str"
Timeout: TypeAlias = "float | httpx.Timeout | tuple[float | None, float | None, float | None, float | None] | None"
Headers = typing.Union[
httpx.Headers,
typing.Mapping[bytes, bytes],
typing.Mapping[str, str],
typing.Sequence[typing.Tuple[bytes, bytes]],
typing.Sequence[typing.Tuple[str, str]],
None,
]
Comment on lines +43 to +55
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pretty much have to do what httpx is doing on the _client module because I think we want to support the type possibilities that are there...


_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]

ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]


def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
if inspect.isclass(app):
return hasattr(app, "__await__")
@@ -76,23 +90,15 @@ def __init__(self, session: WebSocketTestSession) -> None:
self.session = session


class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect): # type: ignore[misc]
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""


class WebSocketTestSession:
def __init__(
self,
app: ASGI3App,
scope: Scope,
portal_factory: _PortalFactoryType,
) -> None:
def __init__(self, app: ASGI3App, scope: Scope, portal_factory: _PortalFactoryType) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
@@ -468,19 +474,19 @@ def _choose_redirect_arg(
def request( # type: ignore[override]
self,
method: str,
url: httpx._types.URLTypes,
url: URL,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
content: typing.Any = None,
data: typing.Any = None,
files: typing.Any = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly...

UseClientDefault = type(httpx.USE_CLIENT_DEFAULT)

...

auth: Auth | UseClientDefault = USE_CLIENT_DEFAULT

(I don't much like this style and I'd enjoy cleaning it up, tho would that be sufficient for now?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it works for the type checkers.. Does it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't work.

follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
url = self._merge_url(url)
@@ -503,15 +509,15 @@ def request( # type: ignore[override]

def get( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -528,15 +534,15 @@ def get( # type: ignore[override]

def options( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -553,15 +559,15 @@ def options( # type: ignore[override]

def head( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -578,19 +584,19 @@ def head( # type: ignore[override]

def post( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
content: typing.Any = None,
data: typing.Any = None,
files: typing.Any = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -611,19 +617,19 @@ def post( # type: ignore[override]

def put( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
content: typing.Any = None,
data: typing.Any = None,
files: typing.Any = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -644,19 +650,19 @@ def put( # type: ignore[override]

def patch( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
content: typing.Any = None,
data: typing.Any = None,
files: typing.Any = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -677,15 +683,15 @@ def patch( # type: ignore[override]

def delete( # type: ignore[override]
self,
url: httpx._types.URLTypes,
url: URL | str,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
params: QueryParams | None = None,
headers: Headers | None = None,
cookies: Cookies | None = None,
auth: Auth | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: Timeout | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
54 changes: 13 additions & 41 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
@@ -275,7 +275,7 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
"/",
data=(
# data
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
b'Content-Disposition: form-data; name="field0"\r\n\r\n'
b"value0\r\n"
# file
@@ -309,7 +309,7 @@ def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_f
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
b"Content-Type: text/plain\r\n\r\n"
b"<file content>\r\n"
@@ -333,7 +333,7 @@ def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_clien
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
b"Content-Type: image/jpeg\r\n\r\n"
b"<file content>\r\n"
@@ -356,7 +356,7 @@ def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory:
response = client.post(
"/",
data=(
b"--20b303e711c4ab8c443184ac833ab00f\r\n" # type: ignore
b"--20b303e711c4ab8c443184ac833ab00f\r\n"
b"Content-Disposition: form-data; "
b'name="value"\r\n\r\n'
b"Transf\xc3\xa9rer\r\n"
@@ -431,7 +431,7 @@ def test_missing_boundary_parameter(
"/",
data=(
# file
b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore
b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
b"Content-Type: text/plain\r\n\r\n"
b"<file content>\r\n"
),
@@ -459,7 +459,7 @@ def test_missing_name_parameter_on_content_disposition(
"/",
data=(
# data
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
b'Content-Disposition: form-data; ="field0"\r\n\r\n'
b"value0\r\n"
),
@@ -487,11 +487,7 @@ def test_too_many_fields_raise(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many fields. Maximum number of fields is 1000."

@@ -514,11 +510,7 @@ def test_too_many_files_raise(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."

@@ -543,11 +535,7 @@ def test_too_many_files_single_field_raise(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."

@@ -571,11 +559,7 @@ def test_too_many_files_and_fields_raise(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."

@@ -601,11 +585,7 @@ def test_max_fields_is_customizable_low_raises(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many fields. Maximum number of fields is 1."

@@ -631,11 +611,7 @@ def test_max_files_is_customizable_low_raises(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1."

@@ -650,11 +626,7 @@ def test_max_fields_is_customizable_high(
fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
data += b"--B--\r\n"
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
res = client.post("/", data=data, headers={"Content-Type": ("multipart/form-data; boundary=B")})
assert res.status_code == 200
res_data = res.json()
assert res_data["N1999"] == ""
10 changes: 5 additions & 5 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -93,7 +93,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}

response = client.post("/", data="abc") # type: ignore
response = client.post("/", data="abc")
assert response.json() == {"body": "abc"}


@@ -114,7 +114,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}

response = client.post("/", data="abc") # type: ignore
response = client.post("/", data="abc")
assert response.json() == {"body": "abc"}


@@ -156,7 +156,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:

client = test_client_factory(app)

response = client.post("/", data="abc") # type: ignore
response = client.post("/", data="abc")
assert response.json() == {"body": "abc", "stream": "abc"}


@@ -175,7 +175,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:

client = test_client_factory(app)

response = client.post("/", data="abc") # type: ignore
response = client.post("/", data="abc")
assert response.json() == {"body": "<stream consumed>", "stream": "abc"}


@@ -462,7 +462,7 @@ def post_body() -> Iterator[bytes]:
yield b"foo"
yield b"bar"

response = client.post("/", data=post_body()) # type: ignore
response = client.post("/", data=post_body())
assert response.json() == {"body": "foobar"}