Skip to content

Commit def7a6d

Browse files
committed
Tests + linting
1 parent cc8ed3f commit def7a6d

File tree

4 files changed

+30
-26
lines changed

4 files changed

+30
-26
lines changed

docs/code_examples/http_multipart_async.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
async def main():
1111

12-
transport = HTTPMultipartTransport(
13-
url="http://localhost:8000/graphql"
14-
)
12+
transport = HTTPMultipartTransport(url="http://localhost:8000/graphql")
1513

1614
# Using `async with` on the client will start a connection on the transport
1715
# and provide a `session` variable to execute queries on this connection

gql/transport/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .async_transport import AsyncTransport
2-
from .http_multipart_transport import HTTPMultipartTransport
32
from .transport import Transport
43

5-
__all__ = ["AsyncTransport", "HTTPMultipartTransport", "Transport"]
4+
__all__ = ["AsyncTransport", "Transport"]

gql/transport/http_multipart_transport.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
the multipart subscription protocol as implemented by Apollo GraphOS Router
66
and other compatible servers.
77
8-
Reference: https://www.apollographql.com/docs/graphos/routing/operations/subscriptions/multipart-protocol
9-
Issue: https://github.com/graphql-python/gql/issues/463
8+
Reference:
9+
https://www.apollographql.com/docs/graphos/routing/operations/subscriptions/multipart-protocol
1010
"""
1111

1212
import asyncio
1313
import json
1414
import logging
1515
from ssl import SSLContext
16-
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Union
16+
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union
1717

1818
import aiohttp
1919
from aiohttp.client_reqrep import Fingerprint
@@ -92,7 +92,7 @@ async def connect(self) -> None:
9292
if self.session is not None:
9393
raise TransportAlreadyConnected("Transport is already connected")
9494

95-
client_session_args = {
95+
client_session_args: Dict[str, Any] = {
9696
"cookies": self.cookies,
9797
"headers": self.headers,
9898
"auth": self.auth,
@@ -170,7 +170,7 @@ async def subscribe(
170170
error_text = await response.text()
171171
raise TransportServerError(
172172
f"Server returned {response.status}: {error_text}",
173-
response.status
173+
response.status,
174174
)
175175

176176
content_type = response.headers.get("Content-Type", "")
@@ -183,7 +183,9 @@ async def subscribe(
183183
)
184184

185185
# Parse multipart response
186-
async for result in self._parse_multipart_response(response, content_type):
186+
async for result in self._parse_multipart_response(
187+
response, content_type
188+
):
187189
yield result
188190

189191
except (TransportServerError, TransportProtocolError):
@@ -233,20 +235,24 @@ async def _parse_multipart_response(
233235
break # No complete part yet
234236

235237
# Check if this is the end boundary
236-
if buffer[boundary_pos:boundary_pos + len(end_boundary_bytes)] == end_boundary_bytes:
238+
end_pos = boundary_pos + len(end_boundary_bytes)
239+
if buffer[boundary_pos:end_pos] == end_boundary_bytes:
237240
log.debug("Reached end boundary")
238241
return
239242

240243
# Find the start of the next part (after this boundary)
241244
# Look for either another regular boundary or the end boundary
242-
next_boundary_pos = buffer.find(boundary_bytes, boundary_pos + len(boundary_bytes))
245+
next_boundary_pos = buffer.find(
246+
boundary_bytes, boundary_pos + len(boundary_bytes)
247+
)
243248

244249
if next_boundary_pos == -1:
245250
# No next boundary yet, wait for more data
246251
break
247252

248253
# Extract the part between boundaries
249-
part_data = buffer[boundary_pos + len(boundary_bytes):next_boundary_pos]
254+
start_pos = boundary_pos + len(boundary_bytes)
255+
part_data = buffer[start_pos:next_boundary_pos]
250256

251257
# Parse the part
252258
try:
@@ -270,16 +276,16 @@ def _parse_multipart_part(self, part_data: bytes) -> Optional[ExecutionResult]:
270276
:return: ExecutionResult or None if part is empty/heartbeat
271277
"""
272278
# Split headers and body by double CRLF or double LF
273-
part_str = part_data.decode('utf-8')
279+
part_str = part_data.decode("utf-8")
274280

275281
# Try different separators
276-
if '\r\n\r\n' in part_str:
277-
parts = part_str.split('\r\n\r\n', 1)
278-
elif '\n\n' in part_str:
279-
parts = part_str.split('\n\n', 1)
282+
if "\r\n\r\n" in part_str:
283+
parts = part_str.split("\r\n\r\n", 1)
284+
elif "\n\n" in part_str:
285+
parts = part_str.split("\n\n", 1)
280286
else:
281287
# No headers separator found, treat entire content as body
282-
parts = ['', part_str]
288+
parts = ["", part_str]
283289

284290
if len(parts) < 2:
285291
return None

tests/test_http_multipart_transport.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import json
33
from typing import Mapping
44

5-
import aiohttp
65
import pytest
76

87
from gql import Client, gql
@@ -210,7 +209,7 @@ async def handler(request):
210209
# Transport error has null payload with errors at top level
211210
error_response = {
212211
"payload": None,
213-
"errors": [{"message": "Transport connection failed"}]
212+
"errors": [{"message": "Transport connection failed"}],
214213
}
215214
part = (
216215
f"--graphql\r\n"
@@ -254,7 +253,7 @@ async def handler(request):
254253
response = {
255254
"payload": {
256255
"data": {"book": book1},
257-
"errors": [{"message": "Field deprecated", "path": ["book", "author"]}]
256+
"errors": [{"message": "Field deprecated", "path": ["book", "author"]}],
258257
}
259258
}
260259
part = (
@@ -472,9 +471,9 @@ async def test_http_multipart_accept_header(aiohttp_server):
472471
async def handler(request):
473472
# Verify the Accept header
474473
accept_header = request.headers.get("Accept", "")
475-
assert 'multipart/mixed' in accept_header
474+
assert "multipart/mixed" in accept_header
476475
assert 'subscriptionSpec="1.0"' in accept_header
477-
assert 'application/json' in accept_header
476+
assert "application/json" in accept_header
478477

479478
body = create_multipart_response([book1])
480479
return web.Response(
@@ -617,7 +616,9 @@ async def test_http_multipart_connection_error():
617616
from gql.transport.http_multipart_transport import HTTPMultipartTransport
618617

619618
# Use an invalid URL that will fail to connect
620-
transport = HTTPMultipartTransport(url="http://invalid.local:99999/graphql", timeout=1)
619+
transport = HTTPMultipartTransport(
620+
url="http://invalid.local:99999/graphql", timeout=1
621+
)
621622

622623
async with Client(transport=transport) as session:
623624
query = gql(subscription_str)

0 commit comments

Comments
 (0)