Skip to content
10 changes: 6 additions & 4 deletions httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Type definitions for type checking purposes.
"""

import enum
import ssl
from http.cookiejar import CookieJar
from typing import (
Expand Down Expand Up @@ -30,7 +31,6 @@
from ._models import Cookies, Headers, Request # noqa: F401
from ._urls import URL, QueryParams # noqa: F401


PrimitiveData = Optional[Union[str, int, float, bool]]

RawURL = NamedTuple(
Expand All @@ -45,11 +45,13 @@

URLTypes = Union["URL", str]

QueryPrimitiveData = Optional[Union[str, int, float, bool, None, enum.Enum]]

QueryParamTypes = Union[
"QueryParams",
Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]],
List[Tuple[str, PrimitiveData]],
Tuple[Tuple[str, PrimitiveData], ...],
Mapping[str, Union[QueryPrimitiveData, Sequence[QueryPrimitiveData]]],
List[Tuple[str, QueryPrimitiveData]],
Tuple[Tuple[str, QueryPrimitiveData], ...],
str,
bytes,
]
Expand Down
8 changes: 6 additions & 2 deletions httpx/_urlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ipaddress
import re
import typing
from urllib.parse import quote_from_bytes

import idna

Expand Down Expand Up @@ -423,10 +424,13 @@ def PERCENT(string: str) -> str:
return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")])


def percent_encoded(string: str, safe: str = "/") -> str:
def percent_encoded(string: str | bytes, safe: str = "/") -> str:
"""
Use percent-encoding to quote a string.
"""
if isinstance(string, bytes):
return quote_from_bytes(string)

NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe

# Fast path for strings that don't need escaping.
Expand Down Expand Up @@ -471,7 +475,7 @@ def quote(string: str, safe: str = "/") -> str:
return "".join(parts)


def urlencode(items: list[tuple[str, str]]) -> str:
def urlencode(items: list[tuple[str, str | bytes]]) -> str:
"""
We can use a much simpler version of the stdlib urlencode here because
we don't need to handle a bunch of different typing cases, such as bytes vs str.
Expand Down
52 changes: 36 additions & 16 deletions httpx/_urls.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import copy
import typing
from urllib.parse import parse_qs, unquote

import idna

from ._types import QueryParamTypes, RawURL, URLTypes
from ._urlparse import urlencode, urlparse
from ._utils import primitive_value_to_str
from ._utils import encode_query_value

__all__ = ["URL", "QueryParams"]

Expand Down Expand Up @@ -417,20 +418,30 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({url!r})"


class QueryParams(typing.Mapping[str, str]):
class QueryParams(typing.Mapping[str, typing.Union[str, bytes]]):
"""
URL query parameters, as a multi-dict.
"""

def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
_dict: dict[str, list[str | bytes]]

__slots__ = ("_dict",)

@typing.overload
def __init__(self, qs: QueryParamTypes | None, /) -> None: ...

@typing.overload
def __init__(self, /, **kwargs: typing.Any) -> None: ...

def __init__(self, /, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
assert len(args) < 2, "Too many arguments."
assert not (args and kwargs), "Cannot mix named and unnamed arguments."

value = args[0] if args else kwargs

if value is None or isinstance(value, (str, bytes)):
value = value.decode("ascii") if isinstance(value, bytes) else value
self._dict = parse_qs(value, keep_blank_values=True)
self._dict = parse_qs(value, keep_blank_values=True) # type: ignore[assignment]
elif isinstance(value, QueryParams):
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
Expand All @@ -456,7 +467,7 @@ def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
# We coerce values `True` and `False` to JSON-like "true" and "false"
# representations, and coerce `None` values to the empty string.
self._dict = {
str(k): [primitive_value_to_str(item) for item in v]
str(k): [encode_query_value(item) for item in v]
for k, v in dict_value.items()
}

Expand All @@ -471,7 +482,7 @@ def keys(self) -> typing.KeysView[str]:
"""
return self._dict.keys()

def values(self) -> typing.ValuesView[str]:
def values(self) -> typing.ValuesView[str | bytes]:
"""
Return all the values in the query params. If a key occurs more than once
only the first item for that key is returned.
Expand All @@ -483,7 +494,7 @@ def values(self) -> typing.ValuesView[str]:
"""
return {k: v[0] for k, v in self._dict.items()}.values()

def items(self) -> typing.ItemsView[str, str]:
def items(self) -> typing.ItemsView[str, str | bytes]:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.
Expand All @@ -495,7 +506,7 @@ def items(self) -> typing.ItemsView[str, str]:
"""
return {k: v[0] for k, v in self._dict.items()}.items()

def multi_items(self) -> list[tuple[str, str]]:
def multi_items(self) -> list[tuple[str, str | bytes]]:
"""
Return all items in the query params. Allow duplicate keys to occur.

Expand All @@ -504,7 +515,7 @@ def multi_items(self) -> list[tuple[str, str]]:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: list[tuple[str, str]] = []
multi_items: list[tuple[str, str | bytes]] = []
for k, v in self._dict.items():
multi_items.extend([(k, i) for i in v])
return multi_items
Expand All @@ -523,7 +534,7 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
return self._dict[str(key)][0]
return default

def get_list(self, key: str) -> list[str]:
def get_list(self, key: str) -> list[str | bytes]:
"""
Get all values from the query param for a given key.

Expand All @@ -545,8 +556,8 @@ def set(self, key: str, value: typing.Any = None) -> QueryParams:
assert q == httpx.QueryParams("a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = [primitive_value_to_str(value)]
q._dict = copy.deepcopy(self._dict)
q._dict[str(key)] = [encode_query_value(value)]
return q

def add(self, key: str, value: typing.Any = None) -> QueryParams:
Expand All @@ -560,8 +571,8 @@ def add(self, key: str, value: typing.Any = None) -> QueryParams:
assert q == httpx.QueryParams("a=123&a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
q._dict = copy.deepcopy(self._dict)
q._dict[str(key)] = q.get_list(key) + [encode_query_value(value)]
return q

def remove(self, key: str) -> QueryParams:
Expand All @@ -575,7 +586,7 @@ def remove(self, key: str) -> QueryParams:
assert q == httpx.QueryParams("")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict = copy.deepcopy(self._dict)
q._dict.pop(str(key), None)
return q

Expand All @@ -597,7 +608,7 @@ def merge(self, params: QueryParamTypes | None = None) -> QueryParams:
q._dict = {**self._dict, **q._dict}
return q

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: typing.Any) -> str | bytes:
return self._dict[key][0]

def __contains__(self, key: typing.Any) -> bool:
Expand Down Expand Up @@ -646,3 +657,12 @@ def __setitem__(self, key: str, value: str) -> None:
"QueryParams are immutable since 0.18.0. "
"Use `q = q.set(key, value)` to create an updated copy."
)


if typing.TYPE_CHECKING: # pragma: no cover
# assert typing error
QueryParams("q=a", {"q": "a"}) # type: ignore[call-overload]
QueryParams({"a": 1}, {"q": "a"}) # type: ignore[call-overload]
QueryParams("q=a")
QueryParams({"q": "a"})
QueryParams(q="a")
18 changes: 17 additions & 1 deletion httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import codecs
import email.message
import enum
import ipaddress
import mimetypes
import os
Expand All @@ -18,7 +19,6 @@
if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL


_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
Expand Down Expand Up @@ -68,6 +68,22 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return str(value)


def encode_query_value(value: typing.Any) -> str | bytes:
if isinstance(value, (str, bytes)):
return value
if value is True:
return "true"
if value is False:
return "false"
if value is None:
return ""
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, enum.Enum):
return encode_query_value(value.value)
raise TypeError(f"can't use {type(value)!r} as query value")


def is_known_encoding(encoding: str) -> bool:
"""
Return `True` if `encoding` is a known codec.
Expand Down
21 changes: 21 additions & 0 deletions tests/models/test_queryparams.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import enum
import re
import types

import pytest

import httpx
Expand Down Expand Up @@ -75,6 +79,13 @@ def test_queryparam_types():
q = httpx.QueryParams({"a": [1, 2]})
assert str(q) == "a=1&a=2"

class E(enum.Enum):
v1 = 1
v2 = "v2"

q = httpx.QueryParams({"a": [E.v1, E.v2]})
assert str(q) == "a=1&a=v2"


def test_empty_query_params():
q = httpx.QueryParams({"a": ""})
Expand Down Expand Up @@ -134,3 +145,13 @@ def test_queryparams_are_hashable():
)

assert len(set(params)) == 2


def test_queryparams_bytes():
q = httpx.QueryParams({"q": bytes.fromhex("E1EE0E2734986F5419BB6C")})
assert str(q) == "q=%E1%EE%0E%274%98oT%19%BBl"


def test_queryparams_error():
with pytest.raises(TypeError, match=re.compile(r"can't use .* as query value")):
httpx.QueryParams({"q": types.SimpleNamespace()}) # type: ignore
8 changes: 4 additions & 4 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def test_urlencoded_content():

@pytest.mark.anyio
async def test_urlencoded_boolean():
request = httpx.Request(method, url, data={"example": True})
request = httpx.Request(method, url, data={"example": True, "e2": False})
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for coverage

assert isinstance(request.stream, typing.Iterable)
assert isinstance(request.stream, typing.AsyncIterable)

Expand All @@ -209,11 +209,11 @@ async def test_urlencoded_boolean():

assert request.headers == {
"Host": "www.example.com",
"Content-Length": "12",
"Content-Length": "21",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"example=true"
assert async_content == b"example=true"
assert sync_content == b"example=true&e2=false"
assert async_content == b"example=true&e2=false"


@pytest.mark.anyio
Expand Down