Skip to content

Commit

Permalink
More resiliant is_readable testing (#311)
Browse files Browse the repository at this point in the history
* More resiliant is_readable testing
  • Loading branch information
tomchristie authored Apr 28, 2021
1 parent 3acf913 commit 8b2053e
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def aclose(self) -> None:

def is_readable(self) -> bool:
sock = self.stream.extra(SocketAttribute.raw_socket)
return is_socket_readable(sock.fileno())
return is_socket_readable(sock)


class Lock(AsyncLock):
Expand Down
7 changes: 2 additions & 5 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ssl import SSLContext
from typing import Optional

from .. import _utils
from .._exceptions import (
ConnectError,
ConnectTimeout,
Expand All @@ -14,6 +13,7 @@
map_exceptions,
)
from .._types import TimeoutDict
from .._utils import is_socket_readable
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

SSL_MONKEY_PATCH_APPLIED = False
Expand Down Expand Up @@ -201,10 +201,7 @@ async def aclose(self) -> None:
def is_readable(self) -> bool:
transport = self.stream_reader._transport # type: ignore
sock: Optional[socket.socket] = transport.get_extra_info("socket")
# If socket was detached from the transport, most likely connection was reset.
# Hence make it readable to notify users to poll the socket.
# We'd expect the read operation to return `b""` indicating the socket closure.
return sock is None or _utils.is_socket_readable(sock.fileno())
return is_socket_readable(sock)


class Lock(AsyncLock):
Expand Down
2 changes: 1 addition & 1 deletion httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def aclose(self) -> None:
await self.socket.close()

def is_readable(self) -> bool:
return is_socket_readable(self.socket.fileno())
return is_socket_readable(self.socket)


class CurioBackend(AsyncBackend):
Expand Down
2 changes: 1 addition & 1 deletion httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def close(self) -> None:
pass

def is_readable(self) -> bool:
return is_socket_readable(self.sock.fileno())
return is_socket_readable(self.sock)


class SyncLock:
Expand Down
10 changes: 9 additions & 1 deletion httpcore/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import select
import socket
import sys
import typing

Expand Down Expand Up @@ -73,7 +74,7 @@ def exponential_backoff(factor: float) -> typing.Iterator[float]:
yield factor * (2 ** (n - 2))


def is_socket_readable(sock_fd: int) -> bool:
def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
"""
Return whether a socket, as identifed by its file descriptor, is readable.
Expand All @@ -83,6 +84,13 @@ def is_socket_readable(sock_fd: int) -> bool:
# NOTE: we want check for readability without actually attempting to read, because
# we don't want to block forever if it's not readable.

# In the case that the socket no longer exists, or cannot return a file
# descriptor, we treat it as being readable, as if it the next read operation
# on it is ready to return the terminating `b""`.
sock_fd = None if sock is None else sock.fileno()
if sock_fd is None or sock_fd < 0:
return True

# The implementation below was stolen from:
# https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478
# See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316
Expand Down
7 changes: 6 additions & 1 deletion tests/backend_tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from httpcore._backends.asyncio import SocketStream


class MockSocket:
def fileno(self):
return 1


class TestSocketStream:
class TestIsReadable:
@pytest.mark.asyncio
Expand All @@ -18,7 +23,7 @@ async def test_returns_true_when_transport_has_no_socket(self):
@pytest.mark.asyncio
async def test_returns_true_when_socket_is_readable(self):
stream_reader = MagicMock()
stream_reader._transport.get_extra_info.return_value = MagicMock()
stream_reader._transport.get_extra_info.return_value = MockSocket()
sock_stream = SocketStream(stream_reader, MagicMock())

with patch(
Expand Down

0 comments on commit 8b2053e

Please sign in to comment.