Skip to content

Commit 81b6b1f

Browse files
Add asyncio backend. Add integration testing
1 parent da86ca4 commit 81b6b1f

13 files changed

+627
-51
lines changed

docs/exceptions.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The following exceptions may be raised when sending a request:
99
* `httpcore.WriteTimeout`
1010
* `httpcore.NetworkError`
1111
* `httpcore.ConnectError`
12+
* `httpcore.BrokenSocketError`
1213
* `httpcore.ReadError`
1314
* `httpcore.WriteError`
1415
* `httpcore.ProtocolError`

httpcore/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
AsyncHTTPProxy,
99
AsyncSOCKSProxy,
1010
)
11+
from ._backends.asyncio import AsyncioBackend
12+
from ._backends.auto import AutoBackend
1113
from ._backends.base import (
1214
SOCKET_OPTION,
1315
AsyncNetworkBackend,
@@ -18,6 +20,7 @@
1820
from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream
1921
from ._backends.sync import SyncBackend
2022
from ._exceptions import (
23+
BrokenSocketError,
2124
ConnectError,
2225
ConnectionNotAvailable,
2326
ConnectTimeout,
@@ -97,6 +100,8 @@ def __init__(self, *args, **kwargs): # type: ignore
97100
"SOCKSProxy",
98101
# network backends, implementations
99102
"SyncBackend",
103+
"AutoBackend",
104+
"AsyncioBackend",
100105
"AnyIOBackend",
101106
"TrioBackend",
102107
# network backends, mock implementations
@@ -126,6 +131,7 @@ def __init__(self, *args, **kwargs): # type: ignore
126131
"WriteTimeout",
127132
"NetworkError",
128133
"ConnectError",
134+
"BrokenSocketError",
129135
"ReadError",
130136
"WriteError",
131137
]

httpcore/_backends/anyio.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import anyio
55

66
from .._exceptions import (
7+
BrokenSocketError,
78
ConnectError,
89
ConnectTimeout,
910
ReadError,
@@ -82,6 +83,9 @@ async def start_tls(
8283
return AnyIOStream(ssl_stream)
8384

8485
def get_extra_info(self, info: str) -> typing.Any:
86+
if info == "is_readable":
87+
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
88+
return is_socket_readable(sock)
8589
if info == "ssl_object":
8690
return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
8791
if info == "client_addr":
@@ -90,9 +94,6 @@ def get_extra_info(self, info: str) -> typing.Any:
9094
return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
9195
if info == "socket":
9296
return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
93-
if info == "is_readable":
94-
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
95-
return is_socket_readable(sock)
9697
return None
9798

9899

@@ -105,8 +106,6 @@ async def connect_tcp(
105106
local_address: typing.Optional[str] = None,
106107
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
107108
) -> AsyncNetworkStream:
108-
if socket_options is None:
109-
socket_options = [] # pragma: no cover
110109
exc_map = {
111110
TimeoutError: ConnectTimeout,
112111
OSError: ConnectError,
@@ -120,18 +119,15 @@ async def connect_tcp(
120119
local_host=local_address,
121120
)
122121
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
123-
for option in socket_options:
124-
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
122+
self._set_socket_options(stream, socket_options)
125123
return AnyIOStream(stream)
126124

127125
async def connect_unix_socket(
128126
self,
129127
path: str,
130128
timeout: typing.Optional[float] = None,
131129
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
132-
) -> AsyncNetworkStream: # pragma: nocover
133-
if socket_options is None:
134-
socket_options = []
130+
) -> AsyncNetworkStream:
135131
exc_map = {
136132
TimeoutError: ConnectTimeout,
137133
OSError: ConnectError,
@@ -140,9 +136,23 @@ async def connect_unix_socket(
140136
with map_exceptions(exc_map):
141137
with anyio.fail_after(timeout):
142138
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
143-
for option in socket_options:
144-
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
139+
self._set_socket_options(stream, socket_options)
145140
return AnyIOStream(stream)
146141

147142
async def sleep(self, seconds: float) -> None:
148143
await anyio.sleep(seconds) # pragma: nocover
144+
145+
def _set_socket_options(
146+
self,
147+
stream: anyio.abc.ByteStream,
148+
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
149+
) -> None:
150+
if not socket_options:
151+
return
152+
153+
sock = stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
154+
if sock is None:
155+
raise BrokenSocketError() # pragma: nocover
156+
157+
for option in socket_options:
158+
sock.setsockopt(*option)

httpcore/_backends/asyncio.py

+227
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import asyncio
2+
import socket
3+
import ssl
4+
from typing import Any, Dict, Iterable, Optional, Type
5+
6+
from .._exceptions import (
7+
BrokenSocketError,
8+
ConnectError,
9+
ConnectTimeout,
10+
ReadError,
11+
ReadTimeout,
12+
WriteError,
13+
WriteTimeout,
14+
map_exceptions,
15+
)
16+
from .._utils import is_socket_readable
17+
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
18+
19+
20+
class AsyncIOStream(AsyncNetworkStream):
21+
def __init__(
22+
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
23+
):
24+
self._stream_reader = stream_reader
25+
self._stream_writer = stream_writer
26+
self._read_lock = asyncio.Lock()
27+
self._write_lock = asyncio.Lock()
28+
self._inner: Optional[AsyncIOStream] = None
29+
30+
async def start_tls(
31+
self,
32+
ssl_context: ssl.SSLContext,
33+
server_hostname: Optional[str] = None,
34+
timeout: Optional[float] = None,
35+
) -> AsyncNetworkStream:
36+
loop = asyncio.get_event_loop()
37+
38+
stream_reader = asyncio.StreamReader()
39+
protocol = asyncio.StreamReaderProtocol(stream_reader)
40+
41+
exc_map: Dict[Type[Exception], Type[Exception]] = {
42+
asyncio.TimeoutError: ConnectTimeout,
43+
OSError: ConnectError,
44+
}
45+
with map_exceptions(exc_map):
46+
transport_ssl = await asyncio.wait_for(
47+
loop.start_tls(
48+
self._stream_writer.transport,
49+
protocol,
50+
ssl_context,
51+
server_hostname=server_hostname,
52+
),
53+
timeout,
54+
)
55+
if transport_ssl is None:
56+
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.start_tls
57+
raise ConnectError("Transport closed while starting TLS") # pragma: nocover
58+
59+
# Initialize the protocol, so it is made aware of being tied to
60+
# a TLS connection.
61+
# See: https://github.com/encode/httpx/issues/859
62+
protocol.connection_made(transport_ssl)
63+
64+
stream_writer = asyncio.StreamWriter(
65+
transport=transport_ssl, protocol=protocol, reader=stream_reader, loop=loop
66+
)
67+
68+
ssl_stream = AsyncIOStream(stream_reader, stream_writer)
69+
# When we return a new SocketStream with new StreamReader/StreamWriter instances
70+
# we need to keep references to the old StreamReader/StreamWriter so that they
71+
# are not garbage collected and closed while we're still using them.
72+
ssl_stream._inner = self
73+
return ssl_stream
74+
75+
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
76+
exc_map: Dict[Type[Exception], Type[Exception]] = {
77+
asyncio.TimeoutError: ReadTimeout,
78+
OSError: ReadError,
79+
}
80+
async with self._read_lock:
81+
with map_exceptions(exc_map):
82+
try:
83+
return await asyncio.wait_for(
84+
self._stream_reader.read(max_bytes), timeout
85+
)
86+
except AttributeError as exc: # pragma: nocover
87+
if "resume_reading" in str(exc):
88+
# Python's asyncio has a bug that can occur when a
89+
# connection has been closed, while it is paused.
90+
# See: https://github.com/encode/httpx/issues/1213
91+
#
92+
# Returning an empty byte-string to indicate connection
93+
# close will eventually raise an httpcore.RemoteProtocolError
94+
# to the user when this goes through our HTTP parsing layer.
95+
return b""
96+
raise
97+
98+
async def write(self, data: bytes, timeout: Optional[float] = None) -> None:
99+
if not data:
100+
return
101+
102+
exc_map: Dict[Type[Exception], Type[Exception]] = {
103+
asyncio.TimeoutError: WriteTimeout,
104+
OSError: WriteError,
105+
}
106+
async with self._write_lock:
107+
with map_exceptions(exc_map):
108+
self._stream_writer.write(data)
109+
return await asyncio.wait_for(self._stream_writer.drain(), timeout)
110+
111+
async def aclose(self) -> None:
112+
# SSL connections should issue the close and then abort, rather than
113+
# waiting for the remote end of the connection to signal the EOF.
114+
#
115+
# See:
116+
#
117+
# * https://bugs.python.org/issue39758
118+
# * https://github.com/python-trio/trio/blob/
119+
# 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829
120+
#
121+
# And related issues caused if we simply omit the 'wait_closed' call,
122+
# without first using `.abort()`
123+
#
124+
# * https://github.com/encode/httpx/issues/825
125+
# * https://github.com/encode/httpx/issues/914
126+
is_ssl = self._sslobj is not None
127+
128+
async with self._write_lock:
129+
try:
130+
self._stream_writer.close()
131+
if is_ssl:
132+
# Give the connection a chance to write any data in the buffer,
133+
# and then forcibly tear down the SSL connection.
134+
await asyncio.sleep(0)
135+
self._stream_writer.transport.abort()
136+
await self._stream_writer.wait_closed()
137+
except OSError: # pragma: nocover
138+
pass
139+
140+
def get_extra_info(self, info: str) -> Any:
141+
if info == "is_readable":
142+
return is_socket_readable(self._raw_socket)
143+
if info == "ssl_object":
144+
return self._sslobj
145+
if info in ("client_addr", "server_addr"):
146+
sock = self._raw_socket
147+
if sock is None:
148+
raise BrokenSocketError() # pragma: nocover
149+
return sock.getsockname() if info == "client_addr" else sock.getpeername()
150+
if info == "socket":
151+
return self._raw_socket
152+
return None
153+
154+
@property
155+
def _raw_socket(self) -> Optional[socket.socket]:
156+
transport = self._stream_writer.transport
157+
sock: Optional[socket.socket] = transport.get_extra_info("socket")
158+
return sock
159+
160+
@property
161+
def _sslobj(self) -> Optional[ssl.SSLObject]:
162+
transport = self._stream_writer.transport
163+
sslobj: Optional[ssl.SSLObject] = transport.get_extra_info("ssl_object")
164+
return sslobj
165+
166+
167+
class AsyncioBackend(AsyncNetworkBackend):
168+
async def connect_tcp(
169+
self,
170+
host: str,
171+
port: int,
172+
timeout: Optional[float] = None,
173+
local_address: Optional[str] = None,
174+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
175+
) -> AsyncNetworkStream:
176+
local_addr = None if local_address is None else (local_address, 0)
177+
178+
exc_map: Dict[Type[Exception], Type[Exception]] = {
179+
asyncio.TimeoutError: ConnectTimeout,
180+
OSError: ConnectError,
181+
}
182+
with map_exceptions(exc_map):
183+
stream_reader, stream_writer = await asyncio.wait_for(
184+
asyncio.open_connection(host, port, local_addr=local_addr),
185+
timeout,
186+
)
187+
self._set_socket_options(stream_writer, socket_options)
188+
return AsyncIOStream(
189+
stream_reader=stream_reader, stream_writer=stream_writer
190+
)
191+
192+
async def connect_unix_socket(
193+
self,
194+
path: str,
195+
timeout: Optional[float] = None,
196+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
197+
) -> AsyncNetworkStream:
198+
exc_map: Dict[Type[Exception], Type[Exception]] = {
199+
asyncio.TimeoutError: ConnectTimeout,
200+
OSError: ConnectError,
201+
}
202+
with map_exceptions(exc_map):
203+
stream_reader, stream_writer = await asyncio.wait_for(
204+
asyncio.open_unix_connection(path), timeout
205+
)
206+
self._set_socket_options(stream_writer, socket_options)
207+
return AsyncIOStream(
208+
stream_reader=stream_reader, stream_writer=stream_writer
209+
)
210+
211+
async def sleep(self, seconds: float) -> None:
212+
await asyncio.sleep(seconds) # pragma: nocover
213+
214+
def _set_socket_options(
215+
self,
216+
stream: asyncio.StreamWriter,
217+
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
218+
) -> None:
219+
if not socket_options:
220+
return
221+
222+
sock = stream.get_extra_info("socket")
223+
if sock is None:
224+
raise BrokenSocketError() # pragma: nocover
225+
226+
for option in socket_options:
227+
sock.setsockopt(*option)

0 commit comments

Comments
 (0)