Skip to content

Commit

Permalink
Support AnyIO
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Oct 31, 2024
1 parent 30e3189 commit b2760e6
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 385 deletions.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = ["cffi; implementation_name == 'pypy'"]
dependencies = [
"cffi; implementation_name == 'pypy'",
"anyioutils >=0.4.2"
]
description = "Python bindings for 0MQ"
readme = "README.md"

Expand Down Expand Up @@ -144,7 +147,7 @@ search = '__version__: str = "{current_version}"'
[tool.cibuildwheel]
build-verbosity = "1"
free-threaded-support = true
test-requires = ["pytest>=6", "importlib_metadata"]
test-requires = ["pytest>=6", "importlib_metadata", "exceptiongroup;python_version<'3.11'"]
test-command = "pytest -vsx {package}/tools/test_wheel.py"

[tool.cibuildwheel.linux]
Expand Down
220 changes: 108 additions & 112 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
from multiprocessing import Process

import pytest
from anyio import create_task_group, move_on_after, sleep
from anyioutils import CancelledError, create_task
from pytest import mark

import zmq
import zmq.asyncio as zaio

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup


pytestmark = pytest.mark.anyio


@pytest.fixture
def Context(event_loop):
Expand Down Expand Up @@ -46,23 +54,17 @@ def test_instance_subclass_second(context):
async def test_recv_multipart(context, create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b"hi"]
assert await f == [b"hi"]


async def test_recv(create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b"hi"
assert recvd == b"there"
assert await f1 == b"hi"
assert await f2 == b"there"


@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
Expand All @@ -72,82 +74,70 @@ async def test_recv_timeout(push_pull):
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
with pytest.raises(ExceptionGroup) as excinfo:
await f1
assert excinfo.group_contains(zmq.Again)
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b"hi", b"there"]


@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
async def test_send_timeout(socket):
s = socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
with pytest.raises(ExceptionGroup) as excinfo:
await s.send(b"not going anywhere")
assert excinfo.group_contains(zmq.Again)


async def test_recv_string(push_pull):
a, b = push_pull
f = b.recv_string()
assert not f.done()
msg = "πøøπ"
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg


async def test_recv_json(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj


async def test_recv_json_cancelled(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.recv_json(), tg)
f.cancel(raise_exception=False)
# cycle eventloop to allow cancel events to fire
await sleep(0)
obj = dict(a=5)
await a.send_json(obj)
recvd = await f.wait()
assert f.cancelled()
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
with move_on_after(5):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
assert recvd == obj


async def test_recv_pyobj(push_pull):
a, b = push_pull
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj


Expand Down Expand Up @@ -206,85 +196,90 @@ async def test_custom_serialize_error(dealer_router):
async def test_recv_dontwait(push_pull):
push, pull = push_pull
f = pull.recv(zmq.DONTWAIT)
with pytest.raises(zmq.Again):
with pytest.raises(BaseExceptionGroup) as excinfo:
await f
assert excinfo.group_contains(zmq.Again)
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
msg = await pull.recv(zmq.DONTWAIT)
assert msg == b"ping"


async def test_recv_cancel(push_pull):
a, b = push_pull
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
a, b = push_pull
f1 = create_task(b.recv(), tg)
f2 = create_task(b.recv_multipart(), tg)
f1.cancel(raise_exception=False)
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2.wait()
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]


async def test_poll(push_pull):
a, b = push_pull
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.poll(timeout=0), tg)
await sleep(0.01)
assert f.result() == 0

f = b.poll(timeout=1)
assert not f.done()
evt = await f
f = create_task(b.poll(timeout=1), tg)
assert not f.done()
evt = await f.wait()

assert evt == 0
assert evt == 0

f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
f = create_task(b.poll(timeout=1000), tg)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_base_socket(sockets):
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = create_task(poller.poll(timeout=1000), tg)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_on_closed_socket(push_pull):
a, b = push_pull
with pytest.raises(BaseExceptionGroup) as excinfo:
async with create_task_group() as tg:
a, b = push_pull

f = b.poll(timeout=1)
b.close()
f = create_task(b.poll(timeout=1), tg)
b.close()

# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await sleep(0)
if f.cancelled():
break
assert f.done()
assert excinfo.group_contains(zmq.error.ZMQError)


@pytest.mark.skipif(
Expand Down Expand Up @@ -344,16 +339,17 @@ def test_shadow():


async def test_poll_leak():
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
f.cancel()
await asyncio.sleep(0)
# one more sleep allows further chained cleanup
await asyncio.sleep(0.1)
assert len(s._recv_futures) == 0
async with create_task_group() as tg:
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = create_task(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN), tg)
f.cancel(raise_exception=False)
await sleep(0)
# one more sleep allows further chained cleanup
await sleep(0.1)
assert len(s._recv_futures) == 0


class ProcessForTeardownTest(Process):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_ioloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
_tornado = True


def setup():
if not _tornado:
pytest.skip("requires tornado")
if not _tornado:
pytest.skip("requires tornado", allow_module_level=True)


def test_ioloop():
Expand Down
Loading

0 comments on commit b2760e6

Please sign in to comment.