Skip to content

Commit b957a74

Browse files
robsdedudeMaxAake
andauthored
Fix coroutine not awaited warning (#1129)
When a certain deadline passed, the driver will not even attempt certain IO operations. By not calling the coroutine function to then ignore it under these circumstances (causing the warning), but instead deferring creation of the coroutine until it is clear we want to await it, we avoid the warning. Functionally, this changes nothing, but makes the driver less noise (and probably a better async citizen 🏅). Co-authored-by: MaxAake <[email protected]>
1 parent 2a1a772 commit b957a74

File tree

2 files changed

+169
-7
lines changed

2 files changed

+169
-7
lines changed

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, reader, protocol, writer):
9898
self._timeout = None
9999
self._deadline = None
100100

101-
async def _wait_for_io(self, io_fut):
101+
async def _wait_for_io(self, io_async_fn, *args, **kwargs):
102102
timeout = self._timeout
103103
to_raise = SocketTimeout
104104
if self._deadline is not None:
@@ -109,6 +109,7 @@ async def _wait_for_io(self, io_fut):
109109
timeout = deadline_timeout
110110
to_raise = SocketDeadlineExceededError
111111

112+
io_fut = io_async_fn(*args, **kwargs)
112113
if timeout is not None and timeout <= 0:
113114
# give the io-operation time for one loop cycle to do its thing
114115
io_fut = asyncio.create_task(io_fut)
@@ -157,20 +158,17 @@ def settimeout(self, timeout):
157158
self._timeout = timeout
158159

159160
async def recv(self, n):
160-
io_fut = self._reader.read(n)
161-
return await self._wait_for_io(io_fut)
161+
return await self._wait_for_io(self._reader.read, n)
162162

163163
async def recv_into(self, buffer, nbytes):
164164
# FIXME: not particularly memory or time efficient
165-
io_fut = self._reader.read(nbytes)
166-
res = await self._wait_for_io(io_fut)
165+
res = await self._wait_for_io(self._reader.read, nbytes)
167166
buffer[: len(res)] = res
168167
return len(res)
169168

170169
async def sendall(self, data):
171170
self._writer.write(data)
172-
io_fut = self._writer.drain()
173-
return await self._wait_for_io(io_fut)
171+
return await self._wait_for_io(self._writer.drain)
174172

175173
async def close(self):
176174
self._writer.close()
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from __future__ import annotations
18+
19+
import asyncio
20+
import socket
21+
import typing as t
22+
23+
import freezegun
24+
import pytest
25+
26+
from neo4j._async_compat.network import AsyncBoltSocket
27+
from neo4j._exceptions import SocketDeadlineExceededError
28+
29+
from ...._async_compat.mark_decorator import mark_async_test
30+
31+
32+
if t.TYPE_CHECKING:
33+
import typing_extensions as te
34+
from freezegun.api import (
35+
FrozenDateTimeFactory,
36+
StepTickTimeFactory,
37+
TickingDateTimeFactory,
38+
)
39+
40+
TFreezeTime: te.TypeAlias = (
41+
StepTickTimeFactory | TickingDateTimeFactory | FrozenDateTimeFactory
42+
)
43+
44+
45+
@pytest.fixture
46+
def reader_factory(mocker):
47+
def factory():
48+
return mocker.create_autospec(asyncio.StreamReader)
49+
50+
return factory
51+
52+
53+
@pytest.fixture
54+
def writer_factory(mocker):
55+
def factory():
56+
return mocker.create_autospec(asyncio.StreamWriter)
57+
58+
return factory
59+
60+
61+
@pytest.fixture
62+
def socket_factory(reader_factory, writer_factory):
63+
def factory():
64+
protocol = None
65+
return AsyncBoltSocket(reader_factory(), protocol, writer_factory())
66+
67+
return factory
68+
69+
70+
def reader(s: AsyncBoltSocket):
71+
return s._reader
72+
73+
74+
def writer(s: AsyncBoltSocket):
75+
return s._writer
76+
77+
78+
@pytest.mark.parametrize(
79+
("timeout", "deadline", "pre_tick", "tick", "exception"),
80+
(
81+
(None, None, 60 * 60 * 10, 60 * 60 * 10, None),
82+
# test timeout
83+
(5, None, 0, 4, None),
84+
# timeout is not affected by time passed before the call
85+
(5, None, 7, 4, None),
86+
(5, None, 0, 6, socket.timeout),
87+
# test deadline
88+
(None, 5, 0, 4, None),
89+
(None, 5, 2, 2, None),
90+
# deadline is affected by time passed before the call
91+
(None, 5, 2, 4, SocketDeadlineExceededError),
92+
(None, 5, 6, 0, SocketDeadlineExceededError),
93+
(None, 5, 0, 6, SocketDeadlineExceededError),
94+
# test combination
95+
(5, 5, 0, 4, None),
96+
(5, 5, 2, 2, None),
97+
# deadline triggered by time passed before
98+
(5, 5, 2, 4, SocketDeadlineExceededError),
99+
# the shorter one determines the error
100+
(4, 5, 0, 6, socket.timeout),
101+
(5, 4, 0, 6, SocketDeadlineExceededError),
102+
),
103+
)
104+
@pytest.mark.parametrize("method", ("recv", "recv_into", "sendall"))
105+
@mark_async_test
106+
async def test_async_bolt_socket_read_timeout(
107+
socket_factory, timeout, deadline, pre_tick, tick, exception, method
108+
):
109+
def make_read_side_effect(freeze_time: TFreezeTime):
110+
async def read_side_effect(n):
111+
assert n == 1
112+
freeze_time.tick(tick)
113+
for _ in range(10):
114+
await asyncio.sleep(0)
115+
return b"y"
116+
117+
return read_side_effect
118+
119+
def make_drain_side_effect(freeze_time: TFreezeTime):
120+
async def drain_side_effect():
121+
freeze_time.tick(tick)
122+
for _ in range(10):
123+
await asyncio.sleep(0)
124+
125+
return drain_side_effect
126+
127+
async def call_method(s: AsyncBoltSocket):
128+
if method == "recv":
129+
res = await s.recv(1)
130+
assert res == b"y"
131+
elif method == "recv_into":
132+
b = bytearray(1)
133+
await s.recv_into(b, 1)
134+
assert b == b"y"
135+
elif method == "sendall":
136+
await s.sendall(b"y")
137+
else:
138+
raise NotImplementedError(f"method: {method}")
139+
140+
with freezegun.freeze_time("1970-01-01T00:00:00") as frozen_time:
141+
socket = socket_factory()
142+
if timeout is not None:
143+
socket.settimeout(timeout)
144+
if deadline is not None:
145+
socket.set_deadline(deadline)
146+
if pre_tick:
147+
frozen_time.tick(pre_tick)
148+
149+
if method in {"recv", "recv_into"}:
150+
reader(socket).read.side_effect = make_read_side_effect(
151+
frozen_time
152+
)
153+
elif method == "sendall":
154+
writer(socket).drain.side_effect = make_drain_side_effect(
155+
frozen_time
156+
)
157+
else:
158+
raise NotImplementedError(f"method: {method}")
159+
160+
if exception:
161+
with pytest.raises(exception):
162+
await call_method(socket)
163+
else:
164+
await call_method(socket)

0 commit comments

Comments
 (0)