Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testsuite: introduce asyncio socket #889

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions samples/tcp_full_duplex_service/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,3 @@ def service_non_http_health_checks(
checks.tcp.append(net.HostPort(host='localhost', port=tcp_service_port))
return checks
# /// [service_non_http_health_checker]


@pytest.fixture(name='asyncio_loop')
async def _asyncio_loop():
return asyncio.get_running_loop()
76 changes: 38 additions & 38 deletions samples/tcp_full_duplex_service/tests/test_echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ def monitor_port(service_port) -> int:
return service_port


async def send_all_data(sock, loop):
async def send_all_data(sock):
for data in DATA:
await loop.sock_sendall(sock, data)
await sock.sendall(data)


async def recv_all_data(sock, loop):
async def recv_all_data(sock):
answer = b''
while len(answer) < DATA_LENGTH:
answer += await loop.sock_recv(sock, DATA_LENGTH - len(answer))
answer += await sock.recv(DATA_LENGTH - len(answer))

assert answer == b''.join(DATA)


async def test_basic(service_client, asyncio_loop, monitor_client, tcp_service_port):
async def test_basic(service_client, asyncio_socket, monitor_client, tcp_service_port):
await service_client.reset_metrics()

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, ('localhost', tcp_service_port))
sock = asyncio_socket.tcp()
await sock.connect(('localhost', tcp_service_port))

send_task = asyncio.create_task(send_all_data(sock, asyncio_loop))
await recv_all_data(sock, asyncio_loop)
send_task = asyncio.create_task(send_all_data(sock))
await recv_all_data(sock)
await send_task
metrics = await monitor_client.metrics(prefix='tcp-echo.')
assert metrics.value_at('tcp-echo.sockets.opened') == 1
Expand All @@ -62,19 +62,19 @@ async def _gate(tcp_service_port):
yield proxy


async def test_delay_recv(service_client, asyncio_loop, monitor_client, gate):
async def test_delay_recv(service_client, asyncio_socket, monitor_client, gate):
await service_client.reset_metrics()
timeout = 10.0
timeout = 2.0

# respond with delay in TIMEOUT seconds
gate.to_client_delay(timeout)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, gate.get_sockname_for_clients())
sock = asyncio_socket.tcp()
await sock.connect(gate.get_sockname_for_clients())
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

recv_task = asyncio.create_task(recv_all_data(sock, asyncio_loop))
await send_all_data(sock, asyncio_loop)
recv_task = asyncio.create_task(recv_all_data(sock))
await send_all_data(sock)

done, _ = await asyncio.wait(
[recv_task],
Expand All @@ -92,16 +92,16 @@ async def test_delay_recv(service_client, asyncio_loop, monitor_client, gate):
assert metrics.value_at('tcp-echo.bytes.read') == DATA_LENGTH


async def test_data_combine(service_client, asyncio_loop, monitor_client, gate):
async def test_data_combine(asyncio_socket, service_client, monitor_client, gate):
await service_client.reset_metrics()
gate.to_client_concat_packets(DATA_LENGTH)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, gate.get_sockname_for_clients())
sock = asyncio_socket.socket()
await sock.connect(gate.get_sockname_for_clients())
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

send_task = asyncio.create_task(send_all_data(sock, asyncio_loop))
await recv_all_data(sock, asyncio_loop)
send_task = asyncio.create_task(send_all_data(sock))
await recv_all_data(sock)
await send_task

gate.to_client_pass()
Expand All @@ -112,18 +112,18 @@ async def test_data_combine(service_client, asyncio_loop, monitor_client, gate):
assert metrics.value_at('tcp-echo.bytes.read') == DATA_LENGTH


async def test_down_pending_recv(service_client, asyncio_loop, gate):
async def test_down_pending_recv(service_client, asyncio_socket, gate):
gate.to_client_noop()

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, gate.get_sockname_for_clients())
sock = asyncio_socket.tcp()
await sock.connect(gate.get_sockname_for_clients())
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

async def _recv_no_data():
answer = b''
try:
while True:
answer += await asyncio_loop.sock_recv(sock, 2)
answer += await sock.recv(2)
assert False
except Exception: # pylint: disable=broad-except
pass
Expand All @@ -132,7 +132,7 @@ async def _recv_no_data():

recv_task = asyncio.create_task(_recv_no_data())

await send_all_data(sock, asyncio_loop)
await send_all_data(sock)

await asyncio.wait(
[recv_task],
Expand All @@ -145,17 +145,17 @@ async def _recv_no_data():

gate.to_client_pass()

sock2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock2.connect(gate.get_sockname_for_clients())
await asyncio_loop.sock_sendall(sock2, b'hi')
hello = await asyncio_loop.sock_recv(sock2, 2)
sock2 = asyncio_socket.tcp()
await sock2.connect(gate.get_sockname_for_clients())
await sock2.sendall(b'hi')
hello = await sock2.recv(2)
assert hello == b'hi'
assert gate.connections_count() == 1


async def test_multiple_socks(
asyncio_socket,
service_client,
asyncio_loop,
monitor_client,
tcp_service_port,
):
Expand All @@ -164,11 +164,11 @@ async def test_multiple_socks(

tasks = []
for _ in range(sockets_count):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, ('localhost', tcp_service_port))
sock = asyncio_socket.tcp()
await sock.connect(('localhost', tcp_service_port))
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
tasks.append(asyncio.create_task(send_all_data(sock, asyncio_loop)))
tasks.append(asyncio.create_task(recv_all_data(sock, asyncio_loop)))
tasks.append(asyncio.create_task(send_all_data(sock)))
tasks.append(asyncio.create_task(recv_all_data(sock)))
await asyncio.gather(*tasks)

metrics = await monitor_client.metrics(prefix='tcp-echo.')
Expand All @@ -177,8 +177,8 @@ async def test_multiple_socks(


async def test_multiple_send_only(
asyncio_socket,
service_client,
asyncio_loop,
monitor_client,
tcp_service_port,
):
Expand All @@ -187,10 +187,10 @@ async def test_multiple_send_only(

tasks = []
for _ in range(sockets_count):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio_loop.sock_connect(sock, ('localhost', tcp_service_port))
sock = asyncio_socket.tcp()
await sock.connect(('localhost', tcp_service_port))
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
tasks.append(asyncio.create_task(send_all_data(sock, asyncio_loop)))
tasks.append(asyncio.create_task(send_all_data(sock)))
await asyncio.gather(*tasks)


Expand Down
5 changes: 0 additions & 5 deletions samples/tcp_service/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,3 @@ def service_non_http_health_checks(
checks.tcp.append(net.HostPort(host='localhost', port=tcp_service_port))
return checks
# /// [service_non_http_health_checker]


@pytest.fixture(name='asyncio_loop')
async def _asyncio_loop():
return asyncio.get_running_loop()
25 changes: 11 additions & 14 deletions samples/tcp_service/tests/test_chaos.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import socket

import pytest
from pytest_userver import chaos

_ALL_DATA = 512


@pytest.fixture(name='gate', scope='function')
async def _gate(tcp_service_port):
gate_config = chaos.GateRoute(
Expand All @@ -17,29 +14,29 @@ async def _gate(tcp_service_port):
yield proxy


async def test_chaos_concat_packets(service_client, asyncio_loop, gate):
async def test_chaos_concat_packets(asyncio_socket, service_client, gate):
gate.to_client_concat_packets(10)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(gate.get_sockname_for_clients())
sock = asyncio_socket.tcp()
await sock.connect(gate.get_sockname_for_clients())

await asyncio_loop.sock_sendall(sock, b'hi')
await asyncio_loop.sock_sendall(sock, b'hi')
await sock.sendall(b'hi')
await sock.sendall(b'hi')

hello = await asyncio_loop.sock_recv(sock, _ALL_DATA)
hello = await sock.recv(_ALL_DATA)
assert hello == b'hellohello'
assert gate.connections_count() == 1
gate.to_client_pass()
sock.close()


async def test_chaos_close_on_data(service_client, asyncio_loop, gate):
async def test_chaos_close_on_data(service_client, asyncio_socket, gate):
gate.to_client_close_on_data()

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(gate.get_sockname_for_clients())
sock = asyncio_socket.tcp()
await sock.connect(gate.get_sockname_for_clients())

await asyncio_loop.sock_sendall(sock, b'hi')
hello = await asyncio_loop.sock_recv(sock, _ALL_DATA)
await sock.sendall(b'hi')
hello = await sock.recv(_ALL_DATA)
assert not hello
sock.close()
16 changes: 7 additions & 9 deletions samples/tcp_service/tests/test_tcp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# /// [Functional test]
import socket

import pytest


async def test_basic(service_client, asyncio_loop, tcp_service_port):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', tcp_service_port))
async def test_basic(service_client, asyncio_socket, tcp_service_port):
sock = asyncio_socket.tcp()
await sock.connect(('localhost', tcp_service_port))

await asyncio_loop.sock_sendall(sock, b'hi')
hello = await asyncio_loop.sock_recv(sock, 5)
await sock.sendall(b'hi')
hello = await sock.recv(5)
assert hello == b'hello'

await asyncio_loop.sock_sendall(sock, b'whats up?')
await sock.sendall(b'whats up?')
with pytest.raises(ConnectionResetError):
await asyncio_loop.sock_recv(sock, 1)
await sock.recv(1)
# /// [Functional test]
121 changes: 121 additions & 0 deletions testsuite/pytest_plugins/pytest_userver/asyncio_socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# TODO: move to testsuite
import asyncio
import select
import socket
import typing

DEFAULT_TIMEOUT = 10.0


class AsyncioSocket:
def __init__(
self,
sock: socket.socket,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
timeout=DEFAULT_TIMEOUT,
):
if loop is None:
loop = asyncio.get_running_loop()
self._loop: asyncio.AbstractEventLoop = loop
self._sock: socket.socket = sock
self._default_timeout = timeout
sock.setblocking(False)

def __repr__(self):
return f'<AsyncioSocket for {self._sock}>'

@property
def type(self):
return self._sock.type

def fileno(self):
return self._sock.fileno()

async def connect(self, address, *, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_connect(self._sock, address)

async def send(self, data, *, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_send(self._sock, data)

async def sendto(self, *args, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_sendto(self._sock, *args)

async def sendall(self, data, *, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_sendall(self._sock, data)

async def recv(self, size, *, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_recv(self._sock, size)

async def recvfrom(self, *args, timeout=None):
async with self._timeout(timeout):
return await self._loop.sock_recvfrom(self._sock, *args)

async def accept(self, *, timeout=None):
async with self._timeout(timeout):
conn, address = await self._loop.sock_accept(self._sock)
return from_socket(conn), address

def bind(self, address):
return self._sock.bind(address)

def listen(self, *args):
return self._sock.listen(*args)

def getsockname(self):
return self._sock.getsockname()

def setsockopt(self, *args, **kwargs):
self._sock.setsockopt(*args, **kwargs)

def close(self):
self._sock.close()

def has_data(self) -> bool:
rlist, _, _ = select.select([self._sock], [], [], 0)
return bool(rlist)

def _timeout(self, timeout=None):
if timeout is None:
timeout = self._default_timeout
return asyncio.timeout(timeout)


class AsyncioSocketsFactory:
def __init__(self, loop=None):
if loop is None:
loop = asyncio.get_running_loop()
self._loop = loop

def socket(self, *args, timeout=DEFAULT_TIMEOUT):
sock = socket.socket(*args)
return from_socket(sock, loop=self._loop, timeout=timeout)

def tcp(self, *, timeout=DEFAULT_TIMEOUT):
return self.socket(socket.AF_INET, socket.SOCK_STREAM, timeout=timeout)

def udp(self, *, timeout=DEFAULT_TIMEOUT):
return self.socket(socket.AF_INET, socket.SOCK_DGRAM, timeout=timeout)


def from_socket(
sock: typing.Union[socket.socket, AsyncioSocket], *, loop=None, timeout=DEFAULT_TIMEOUT,
) -> AsyncioSocket:
if isinstance(sock, AsyncioSocket):
return sock
return AsyncioSocket(sock, loop=loop, timeout=timeout)


def create_socket(*args, timeout=DEFAULT_TIMEOUT):
return AsyncioSocketsFactory().socket(*args, timeout=timeout)

def create_tcp_socket(*args, timeout=DEFAULT_TIMEOUT):
return AsyncioSocketsFactory().tcp(timeout=timeout)


def create_udp_socket(timeout=DEFAULT_TIMEOUT):
return AsyncioSocketsFactory().udp(timeout=timeout)
Loading
Loading