Skip to content

Commit 76968d1

Browse files
committed
Add uds arg to BaseClient and select tcp or uds in HttpConnection
1 parent 1e40664 commit 76968d1

File tree

7 files changed

+63
-3
lines changed

7 files changed

+63
-3
lines changed

httpx/client.py

+2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
app: typing.Callable = None,
7575
backend: ConcurrencyBackend = None,
7676
trust_env: bool = True,
77+
uds: str = None,
7778
):
7879
if backend is None:
7980
backend = AsyncioBackend()
@@ -99,6 +100,7 @@ def __init__(
99100
pool_limits=pool_limits,
100101
backend=backend,
101102
trust_env=self.trust_env,
103+
uds=uds,
102104
)
103105
elif isinstance(dispatch, Dispatcher):
104106
async_dispatch = ThreadedDispatcher(dispatch, backend)

httpx/concurrency/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ async def open_uds_stream(
130130
hostname: typing.Optional[str],
131131
ssl_context: typing.Optional[ssl.SSLContext],
132132
timeout: TimeoutConfig,
133-
) -> BaseTCPStream:
133+
) -> BaseSocketStream:
134134
raise NotImplementedError() # pragma: no cover
135135

136136
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:

httpx/dispatch/connection.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ def __init__(
3838
http_versions: HTTPVersionTypes = None,
3939
backend: ConcurrencyBackend = None,
4040
release_func: typing.Optional[ReleaseCallback] = None,
41+
uds: typing.Optional[str] = None,
4142
):
4243
self.origin = Origin(origin) if isinstance(origin, str) else origin
4344
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
4445
self.timeout = TimeoutConfig(timeout)
4546
self.http_versions = HTTPVersionConfig(http_versions)
4647
self.backend = AsyncioBackend() if backend is None else backend
4748
self.release_func = release_func
49+
self.uds = uds
4850
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
4951
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
5052

@@ -84,8 +86,21 @@ async def connect(
8486
else:
8587
on_release = functools.partial(self.release_func, self)
8688

87-
logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
88-
stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
89+
if self.uds is None:
90+
logger.trace(
91+
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
92+
)
93+
stream = await self.backend.open_tcp_stream(
94+
host, port, ssl_context, timeout
95+
)
96+
else:
97+
logger.trace(
98+
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
99+
)
100+
stream = await self.backend.open_uds_stream(
101+
self.uds, host, ssl_context, timeout
102+
)
103+
89104
http_version = stream.get_http_version()
90105
logger.trace(f"connected http_version={http_version!r}")
91106

httpx/dispatch/connection_pool.py

+3
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
9090
http_versions: HTTPVersionTypes = None,
9191
backend: ConcurrencyBackend = None,
92+
uds: typing.Optional[str] = None,
9293
):
9394
self.verify = verify
9495
self.cert = cert
@@ -97,6 +98,7 @@ def __init__(
9798
self.http_versions = http_versions
9899
self.is_closed = False
99100
self.trust_env = trust_env
101+
self.uds = uds
100102

101103
self.keepalive_connections = ConnectionStore()
102104
self.active_connections = ConnectionStore()
@@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection:
142144
backend=self.backend,
143145
release_func=self.release_connection,
144146
trust_env=self.trust_env,
147+
uds=self.uds,
145148
)
146149
logger.trace(f"new_connection connection={connection!r}")
147150
else:

tests/client/test_async_client.py

+14
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,17 @@ async def test_100_continue(server, backend):
146146

147147
assert response.status_code == 200
148148
assert response.content == data
149+
150+
151+
async def test_uds(uds_server, backend):
152+
url = uds_server.url
153+
uds = uds_server.config.uds
154+
assert uds is not None
155+
async with httpx.AsyncClient(backend=backend, uds=uds) as client:
156+
response = await client.get(url)
157+
assert response.status_code == 200
158+
assert response.text == "Hello, world!"
159+
assert response.http_version == "HTTP/1.1"
160+
assert response.headers
161+
assert repr(response) == "<Response [200 OK]>"
162+
assert response.elapsed > timedelta(seconds=0)

tests/client/test_client.py

+19
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,25 @@ def test_base_url(server):
138138
assert response.url == base_url
139139

140140

141+
def test_uds(uds_server):
142+
url = uds_server.url
143+
uds = uds_server.config.uds
144+
assert uds is not None
145+
with httpx.Client(uds=uds) as http:
146+
response = http.get(url)
147+
assert response.status_code == 200
148+
assert response.url == url
149+
assert response.content == b"Hello, world!"
150+
assert response.text == "Hello, world!"
151+
assert response.http_version == "HTTP/1.1"
152+
assert response.encoding == "iso-8859-1"
153+
assert response.request.url == url
154+
assert response.headers
155+
assert response.is_redirect is False
156+
assert repr(response) == "<Response [200 OK]>"
157+
assert response.elapsed > timedelta(0)
158+
159+
141160
def test_merge_url():
142161
client = httpx.Client(base_url="https://www.paypal.com/")
143162
url = client.merge_url("http://www.paypal.com")

tests/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,13 @@ def server():
288288
yield from serve_in_thread(server)
289289

290290

291+
@pytest.fixture(scope=SERVER_SCOPE)
292+
def uds_server():
293+
config = Config(app=app, lifespan="off", loop="asyncio", uds="test_server.sock")
294+
server = TestServer(config=config)
295+
yield from serve_in_thread(server)
296+
297+
291298
@pytest.fixture(scope=SERVER_SCOPE)
292299
def https_server(cert_pem_file, cert_private_key_file):
293300
config = Config(

0 commit comments

Comments
 (0)