Skip to content

Commit 495467e

Browse files
Add asyncio backend. Add integration testing
1 parent da86ca4 commit 495467e

10 files changed

+513
-36
lines changed

httpcore/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
AsyncHTTPProxy,
99
AsyncSOCKSProxy,
1010
)
11+
from ._backends.asyncio import AsyncioBackend
12+
from ._backends.auto import AutoBackend
1113
from ._backends.base import (
1214
SOCKET_OPTION,
1315
AsyncNetworkBackend,
@@ -97,6 +99,8 @@ def __init__(self, *args, **kwargs): # type: ignore
9799
"SOCKSProxy",
98100
# network backends, implementations
99101
"SyncBackend",
102+
"AutoBackend",
103+
"AsyncioBackend",
100104
"AnyIOBackend",
101105
"TrioBackend",
102106
# network backends, mock implementations

httpcore/_backends/anyio.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ async def connect_tcp(
105105
local_address: typing.Optional[str] = None,
106106
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
107107
) -> AsyncNetworkStream:
108-
if socket_options is None:
109-
socket_options = [] # pragma: no cover
110108
exc_map = {
111109
TimeoutError: ConnectTimeout,
112110
OSError: ConnectError,
@@ -120,8 +118,7 @@ async def connect_tcp(
120118
local_host=local_address,
121119
)
122120
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
123-
for option in socket_options:
124-
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
121+
self._set_socket_options(stream, socket_options)
125122
return AnyIOStream(stream)
126123

127124
async def connect_unix_socket(
@@ -130,8 +127,6 @@ async def connect_unix_socket(
130127
timeout: typing.Optional[float] = None,
131128
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
132129
) -> AsyncNetworkStream: # pragma: nocover
133-
if socket_options is None:
134-
socket_options = []
135130
exc_map = {
136131
TimeoutError: ConnectTimeout,
137132
OSError: ConnectError,
@@ -140,9 +135,23 @@ async def connect_unix_socket(
140135
with map_exceptions(exc_map):
141136
with anyio.fail_after(timeout):
142137
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
143-
for option in socket_options:
144-
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
138+
self._set_socket_options(stream, socket_options)
145139
return AnyIOStream(stream)
146140

147141
async def sleep(self, seconds: float) -> None:
148142
await anyio.sleep(seconds) # pragma: nocover
143+
144+
def _set_socket_options(
145+
self,
146+
stream: anyio.abc.ByteStream,
147+
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
148+
) -> None:
149+
if not socket_options:
150+
return
151+
152+
sock = stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
153+
if sock is None:
154+
raise ConnectError("Broken socket for setsockopt") # pragma: nocover
155+
156+
for option in socket_options:
157+
sock.setsockopt(*option)

httpcore/_backends/asyncio.py

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import asyncio
2+
import socket
3+
import ssl
4+
from typing import Any, Dict, Iterable, Optional, Type
5+
6+
from .._exceptions import (
7+
ConnectError,
8+
ConnectTimeout,
9+
ReadError,
10+
ReadTimeout,
11+
WriteError,
12+
WriteTimeout,
13+
map_exceptions,
14+
)
15+
from .._utils import is_socket_readable
16+
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
17+
18+
19+
class AsyncIOStream(AsyncNetworkStream):
20+
def __init__(
21+
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
22+
):
23+
self._stream_reader = stream_reader
24+
self._stream_writer = stream_writer
25+
self._read_lock = asyncio.Lock()
26+
self._write_lock = asyncio.Lock()
27+
self._inner: Optional[AsyncIOStream] = None
28+
29+
async def start_tls(
30+
self,
31+
ssl_context: ssl.SSLContext,
32+
server_hostname: Optional[str] = None,
33+
timeout: Optional[float] = None,
34+
) -> AsyncNetworkStream:
35+
loop = asyncio.get_event_loop()
36+
37+
stream_reader = asyncio.StreamReader()
38+
protocol = asyncio.StreamReaderProtocol(stream_reader)
39+
40+
exc_map: Dict[Type[Exception], Type[Exception]] = {
41+
asyncio.TimeoutError: ConnectTimeout,
42+
OSError: ConnectError,
43+
}
44+
with map_exceptions(exc_map):
45+
transport_ssl = await asyncio.wait_for(
46+
loop.start_tls(
47+
self._stream_writer.transport,
48+
protocol,
49+
ssl_context,
50+
server_hostname=server_hostname,
51+
),
52+
timeout,
53+
)
54+
if transport_ssl is None:
55+
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.start_tls
56+
raise ConnectError("Transport closed while starting TLS") # pragma: nocover
57+
58+
# Initialize the protocol, so it is made aware of being tied to
59+
# a TLS connection.
60+
# See: https://github.com/encode/httpx/issues/859
61+
protocol.connection_made(transport_ssl)
62+
63+
stream_writer = asyncio.StreamWriter(
64+
transport=transport_ssl, protocol=protocol, reader=stream_reader, loop=loop
65+
)
66+
67+
ssl_stream = AsyncIOStream(stream_reader, stream_writer)
68+
# When we return a new SocketStream with new StreamReader/StreamWriter instances
69+
# we need to keep references to the old StreamReader/StreamWriter so that they
70+
# are not garbage collected and closed while we're still using them.
71+
ssl_stream._inner = self
72+
return ssl_stream
73+
74+
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
75+
exc_map: Dict[Type[Exception], Type[Exception]] = {
76+
asyncio.TimeoutError: ReadTimeout,
77+
OSError: ReadError,
78+
}
79+
async with self._read_lock:
80+
with map_exceptions(exc_map):
81+
try:
82+
return await asyncio.wait_for(
83+
self._stream_reader.read(max_bytes), timeout
84+
)
85+
except AttributeError as exc: # pragma: nocover
86+
if "resume_reading" in str(exc):
87+
# Python's asyncio has a bug that can occur when a
88+
# connection has been closed, while it is paused.
89+
# See: https://github.com/encode/httpx/issues/1213
90+
#
91+
# Returning an empty byte-string to indicate connection
92+
# close will eventually raise an httpcore.RemoteProtocolError
93+
# to the user when this goes through our HTTP parsing layer.
94+
return b""
95+
raise
96+
97+
async def write(self, data: bytes, timeout: Optional[float] = None) -> None:
98+
if not data:
99+
return
100+
101+
exc_map: Dict[Type[Exception], Type[Exception]] = {
102+
asyncio.TimeoutError: WriteTimeout,
103+
OSError: WriteError,
104+
}
105+
async with self._write_lock:
106+
with map_exceptions(exc_map):
107+
self._stream_writer.write(data)
108+
return await asyncio.wait_for(self._stream_writer.drain(), timeout)
109+
110+
async def aclose(self) -> None:
111+
# SSL connections should issue the close and then abort, rather than
112+
# waiting for the remote end of the connection to signal the EOF.
113+
#
114+
# See:
115+
#
116+
# * https://bugs.python.org/issue39758
117+
# * https://github.com/python-trio/trio/blob/
118+
# 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829
119+
#
120+
# And related issues caused if we simply omit the 'wait_closed' call,
121+
# without first using `.abort()`
122+
#
123+
# * https://github.com/encode/httpx/issues/825
124+
# * https://github.com/encode/httpx/issues/914
125+
is_ssl = self._sslobj is not None
126+
127+
async with self._write_lock:
128+
try:
129+
self._stream_writer.close()
130+
if is_ssl:
131+
# Give the connection a chance to write any data in the buffer,
132+
# and then forcibly tear down the SSL connection.
133+
await asyncio.sleep(0)
134+
self._stream_writer.transport.abort()
135+
await self._stream_writer.wait_closed()
136+
except OSError: # pragma: nocover
137+
pass
138+
139+
def get_extra_info(self, info: str) -> Any:
140+
if info == "ssl_object":
141+
return self._sslobj
142+
if info == "client_addr":
143+
return self._raw_socket.getsockname()
144+
if info == "server_addr":
145+
return self._raw_socket.getpeername()
146+
if info == "socket":
147+
return self._raw_socket
148+
if info == "is_readable":
149+
return is_socket_readable(self._raw_socket)
150+
return None
151+
152+
@property
153+
def _raw_socket(self) -> socket.socket:
154+
transport = self._stream_writer.transport
155+
sock: socket.socket = transport.get_extra_info("socket")
156+
return sock
157+
158+
@property
159+
def _sslobj(self) -> Optional[ssl.SSLObject]:
160+
transport = self._stream_writer.transport
161+
sslobj: Optional[ssl.SSLObject] = transport.get_extra_info("ssl_object")
162+
return sslobj
163+
164+
165+
class AsyncioBackend(AsyncNetworkBackend):
166+
async def connect_tcp(
167+
self,
168+
host: str,
169+
port: int,
170+
timeout: Optional[float] = None,
171+
local_address: Optional[str] = None,
172+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
173+
) -> AsyncNetworkStream:
174+
local_addr = None if local_address is None else (local_address, 0)
175+
176+
exc_map: Dict[Type[Exception], Type[Exception]] = {
177+
asyncio.TimeoutError: ConnectTimeout,
178+
OSError: ConnectError,
179+
}
180+
with map_exceptions(exc_map):
181+
stream_reader, stream_writer = await asyncio.wait_for(
182+
asyncio.open_connection(host, port, local_addr=local_addr),
183+
timeout,
184+
)
185+
self._set_socket_options(stream_writer, socket_options)
186+
return AsyncIOStream(
187+
stream_reader=stream_reader, stream_writer=stream_writer
188+
)
189+
190+
async def connect_unix_socket(
191+
self,
192+
path: str,
193+
timeout: Optional[float] = None,
194+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
195+
) -> AsyncNetworkStream: # pragma: nocover
196+
exc_map: Dict[Type[Exception], Type[Exception]] = {
197+
asyncio.TimeoutError: ConnectTimeout,
198+
OSError: ConnectError,
199+
}
200+
with map_exceptions(exc_map):
201+
stream_reader, stream_writer = await asyncio.wait_for(
202+
asyncio.open_unix_connection(path), timeout
203+
)
204+
self._set_socket_options(stream_writer, socket_options)
205+
return AsyncIOStream(
206+
stream_reader=stream_reader, stream_writer=stream_writer
207+
)
208+
209+
async def sleep(self, seconds: float) -> None:
210+
await asyncio.sleep(seconds) # pragma: nocover
211+
212+
def _set_socket_options(
213+
self,
214+
stream: asyncio.StreamWriter,
215+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
216+
) -> None:
217+
if not socket_options:
218+
return
219+
220+
sock = stream.get_extra_info("socket")
221+
if sock is None:
222+
raise ConnectError("Broken socket for setsockopt") # pragma: nocover
223+
224+
for option in socket_options:
225+
sock.setsockopt(*option)

httpcore/_backends/auto.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,41 @@
11
import typing
2-
from typing import Optional
2+
from importlib.util import find_spec
3+
from typing import Optional, Type
34

4-
from .._synchronization import current_async_library
5+
from .._synchronization import current_async_backend
56
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
67

8+
HAS_ANYIO = find_spec("anyio") is not None
9+
710

811
class AutoBackend(AsyncNetworkBackend):
12+
@staticmethod
13+
def set_default_backend(backend_class: Optional[Type[AsyncNetworkBackend]]) -> None:
14+
setattr(AutoBackend, "_default_backend_class", backend_class)
15+
916
async def _init_backend(self) -> None:
10-
if not (hasattr(self, "_backend")):
11-
backend = current_async_library()
12-
if backend == "trio":
13-
from .trio import TrioBackend
17+
if hasattr(self, "_backend"):
18+
return
19+
20+
default_backend_class: Optional[Type[AsyncNetworkBackend]] = getattr(
21+
AutoBackend, "_default_backend_class", None
22+
)
23+
if default_backend_class is not None:
24+
self._backend = default_backend_class()
25+
return
26+
27+
if current_async_backend() == "trio":
28+
from .trio import TrioBackend
29+
30+
self._backend = TrioBackend()
31+
elif HAS_ANYIO:
32+
from .anyio import AnyIOBackend
1433

15-
self._backend: AsyncNetworkBackend = TrioBackend()
16-
else:
17-
from .anyio import AnyIOBackend
34+
self._backend = AnyIOBackend()
35+
else:
36+
from .asyncio import AsyncioBackend
1837

19-
self._backend = AnyIOBackend()
38+
self._backend = AsyncioBackend()
2039

2140
async def connect_tcp(
2241
self,

0 commit comments

Comments
 (0)