Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename Stream to TCPStream #339

Merged
merged 1 commit into from
Sep 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseStream,
BaseTCPStream,
ConcurrencyBackend,
)
from .config import (
Expand Down Expand Up @@ -107,7 +107,7 @@
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
"BaseStream",
"BaseTCPStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
Expand Down
26 changes: 8 additions & 18 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
"""
The `Stream` class here provides a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.

Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.

These classes help encapsulate the timeout logic, make it easier to unit-test
protocols, and help keep the rest of the package more `async`/`await`
based, and less strictly `asyncio`-specific.
"""
import asyncio
import functools
import ssl
Expand All @@ -21,7 +11,7 @@
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseStream,
BaseTCPStream,
ConcurrencyBackend,
TimeoutFlag,
)
Expand Down Expand Up @@ -50,7 +40,7 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore
MonkeyPatch.write = _fixed_write


class Stream(BaseStream):
class TCPStream(BaseTCPStream):
def __init__(
self,
stream_reader: asyncio.StreamReader,
Expand Down Expand Up @@ -176,13 +166,13 @@ def loop(self) -> asyncio.AbstractEventLoop:
self._loop = asyncio.new_event_loop()
return self._loop

async def connect(
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseStream:
) -> BaseTCPStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
Expand All @@ -191,25 +181,25 @@ async def connect(
except asyncio.TimeoutError:
raise ConnectTimeout()

return Stream(
return TCPStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def start_tls(
self,
stream: BaseStream,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:
) -> BaseTCPStream:

loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)

assert isinstance(stream, Stream)
assert isinstance(stream, TCPStream)

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
Expand Down
12 changes: 6 additions & 6 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def set_write_timeouts(self) -> None:
self.raise_on_write_timeout = True


class BaseStream:
class BaseTCPStream:
"""
A stream with read/write operations. Abstracts away any asyncio-specific
A TCP stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""
Expand Down Expand Up @@ -110,22 +110,22 @@ def release(self) -> None:


class ConcurrencyBackend:
async def connect(
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseStream:
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover

async def start_tls(
self,
stream: BaseStream,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def connect(
on_release = functools.partial(self.release_func, self)

logger.debug(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
stream = await self.backend.connect(host, port, ssl_context, timeout)
stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
http_version = stream.get_http_version()
logger.debug(f"connected http_version={http_version!r}")

Expand Down
4 changes: 2 additions & 2 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h11

from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger
Expand Down Expand Up @@ -31,7 +31,7 @@ class HTTP11Connection:

def __init__(
self,
stream: BaseStream,
stream: BaseTCPStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
Expand Down
4 changes: 2 additions & 2 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import h2.connection
import h2.events

from ..concurrency.base import BaseEvent, BaseStream, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger
Expand All @@ -17,7 +17,7 @@ class HTTP2Connection:

def __init__(
self,
stream: BaseStream,
stream: BaseTCPStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
Expand Down
10 changes: 5 additions & 5 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import h2.connection
import h2.events

from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig
from httpx import AsyncioBackend, BaseTCPStream, Request, TimeoutConfig
from tests.concurrency import sleep


Expand All @@ -15,13 +15,13 @@ def __init__(self, app, backend=None):
self.backend = AsyncioBackend() if backend is None else backend
self.server = None

async def connect(
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseStream:
) -> BaseTCPStream:
self.server = MockHTTP2Server(self.app, backend=self.backend)
return self.server

Expand All @@ -30,7 +30,7 @@ def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)


class MockHTTP2Server(BaseStream):
class MockHTTP2Server(BaseTCPStream):
def __init__(self, app, backend):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
Expand All @@ -42,7 +42,7 @@ def __init__(self, app, backend):
self.return_data = {}
self.returning = {}

# Stream interface
# TCP stream interface

def get_http_version(self) -> str:
return "HTTP/2"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_start_tls_on_socket_stream(https_server):
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)

stream = await backend.connect(
stream = await backend.open_tcp_stream(
https_server.url.host, https_server.url.port, None, timeout
)

Expand Down