Skip to content

Commit 1f09e33

Browse files
authored
Fix: avoid channel closed exceptions in message websocket (#429)
Problem: upon closing, the message websocket may attempt to close the RabbitMQ channel. This appears to be caused by the handling of `CancelledError` in the `close()` method of the queue iterator. Solution: replace the `iterator` implementation with a simpler `consume` + callback implementation. This guarantees that the secondary task will never attempt to close the channel. Replace task cancellation by a call to `queue.cancel()` as aiormq handles cancellation exceptions by closing the channel.
1 parent 03e6400 commit 1f09e33

File tree

3 files changed

+111
-70
lines changed

3 files changed

+111
-70
lines changed

src/aleph/web/controllers/messages.py

Lines changed: 91 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import logging
32
from typing import List, Optional, Any, Dict, Iterable
43

@@ -234,102 +233,135 @@ async def view_messages_list(request: web.Request) -> web.Response:
234233
)
235234

236235

237-
async def _message_ws_read_from_queue(
236+
async def _send_history_to_ws(
238237
ws: aiohttp.web_ws.WebSocketResponse,
239-
mq_queue: aio_pika.abc.AbstractQueue,
240-
request: web.Request,
238+
session_factory: DbSessionFactory,
239+
history: int,
240+
message_filters: Dict[str, Any],
241241
) -> None:
242+
243+
with session_factory() as session:
244+
messages = get_matching_messages(
245+
session=session,
246+
pagination=history,
247+
include_confirmations=True,
248+
**message_filters,
249+
)
250+
for message in messages:
251+
await ws.send_str(format_message(message).json())
252+
253+
254+
async def _start_mq_consumer(
255+
ws: aiohttp.web_ws.WebSocketResponse,
256+
mq_queue: aio_pika.abc.AbstractQueue,
257+
session_factory: DbSessionFactory,
258+
message_filters: Dict[str, Any],
259+
) -> aio_pika.abc.ConsumerTag:
242260
"""
243-
Task receiving new aleph.im messages from the processing pipeline to a websocket.
261+
Starts the consumer task responsible for forwarding new aleph.im messages from
262+
the processing pipeline to a websocket.
244263
245264
:param ws: Websocket.
246265
:param mq_queue: Message queue object.
247-
:param request: Websocket HTTP request object.
266+
:param session_factory: DB session factory.
267+
:param message_filters: Filters to apply to select outgoing messages.
248268
"""
249269

250-
query_params = WsMessageQueryParams.parse_obj(request.query)
251-
session_factory = get_session_factory_from_request(request)
252-
253-
find_filters = query_params.dict(exclude_none=True)
254-
history = query_params.history
255-
256-
if history:
270+
async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
271+
item_hash = aleph_json.loads(mq_message.body)["item_hash"]
272+
# A bastardized way to apply the filters on the message as well.
273+
# TODO: put the filter key/values in the RabbitMQ message?
257274
with session_factory() as session:
258-
messages = get_matching_messages(
275+
matching_messages = get_matching_messages(
259276
session=session,
260-
pagination=history,
277+
hashes=[item_hash],
261278
include_confirmations=True,
262-
**find_filters,
279+
**message_filters,
263280
)
264-
for message in messages:
265-
await ws.send_str(format_message(message).json())
266-
267-
try:
268-
async with mq_queue.iterator(no_ack=True) as queue_iter:
269-
async for mq_message in queue_iter:
270-
item_hash = aleph_json.loads(mq_message.body)["item_hash"]
271-
# A bastardized way to apply the filters on the message as well.
272-
# TODO: put the filter key/values in the RabbitMQ message?
273-
with session_factory() as session:
274-
matching_messages = get_matching_messages(
275-
session=session,
276-
hashes=[item_hash],
277-
include_confirmations=True,
278-
**find_filters,
279-
)
280-
for message in matching_messages:
281-
await ws.send_str(format_message(message).json())
282-
283-
except ConnectionResetError:
284-
# We can detect the WS closing in this task in addition to the main one.
285-
# warning. The main task will also detect the close event.
286-
# We ignore this exception to avoid the "task exception was never retrieved"
287-
LOGGER.info("Cannot send messages because the websocket is closed")
288-
pass
289-
290-
except asyncio.CancelledError:
291-
LOGGER.info("MQ -> WS task cancelled")
292-
raise
281+
try:
282+
for message in matching_messages:
283+
await ws.send_str(format_message(message).json())
284+
except ConnectionResetError:
285+
# We can detect the WS closing in this task in addition to the main one.
286+
# The main task will also detect the close event.
287+
# We just ignore this exception to avoid the "task exception was never retrieved"
288+
# warning.
289+
LOGGER.info("Cannot send messages because the websocket is closed")
290+
291+
# Note that we use the consume pattern here instead of using the `queue.iterator()`
292+
# pattern because cancelling the iterator attempts to close the queue and channel.
293+
# See discussion here: https://github.com/mosquito/aio-pika/issues/358
294+
consumer_tag = await mq_queue.consume(_process_message, no_ack=True)
295+
return consumer_tag
293296

294297

295298
async def messages_ws(request: web.Request) -> web.WebSocketResponse:
296299
ws = web.WebSocketResponse()
297300
await ws.prepare(request)
298301

299302
config = get_config_from_request(request)
303+
session_factory = get_session_factory_from_request(request)
300304
mq_channel = get_mq_channel_from_request(request)
301305

306+
try:
307+
query_params = WsMessageQueryParams.parse_obj(request.query)
308+
except ValidationError as e:
309+
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
310+
message_filters = query_params.dict(exclude_none=True)
311+
history = query_params.history
312+
313+
if history:
314+
try:
315+
await _send_history_to_ws(
316+
ws=ws,
317+
session_factory=session_factory,
318+
history=history,
319+
message_filters=message_filters,
320+
)
321+
except ConnectionResetError:
322+
LOGGER.info("Could not send history, aborting message websocket")
323+
return ws
324+
302325
mq_queue = await mq_make_aleph_message_topic_queue(
303326
channel=mq_channel, config=config, routing_key="processed.*"
304327
)
328+
consumer_tag = None
305329

306-
# Start a task to handle outgoing traffic to the websocket.
307-
queue_task = asyncio.create_task(
308-
_message_ws_read_from_queue(
330+
try:
331+
# Start a task to handle outgoing traffic to the websocket.
332+
consumer_tag = await _start_mq_consumer(
309333
ws=ws,
310-
request=request,
311334
mq_queue=mq_queue,
335+
session_factory=session_factory,
336+
message_filters=message_filters,
337+
)
338+
LOGGER.debug(
339+
"Started consuming mq %s for websocket. Consumer tag: %s",
340+
mq_queue.name,
341+
consumer_tag,
312342
)
313-
)
314343

315-
# Wait for the websocket to close.
316-
try:
344+
# Wait for the websocket to close.
317345
while not ws.closed:
318346
# Users can potentially send anything to the websocket. Ignore these messages
319347
# and only handle "close" messages.
320348
ws_msg = await ws.receive()
321-
LOGGER.info("rx ws msg: %s", str(ws_msg))
349+
LOGGER.debug("rx ws msg: %s", str(ws_msg))
322350
if ws_msg.type == WSMsgType.CLOSE:
323351
LOGGER.debug("ws close received")
324352
break
325353

326354
finally:
327-
# Cancel the MQ -> ws task
328-
queue_task.cancel()
329-
await asyncio.wait([queue_task])
330-
331-
# Always delete the queue, auto-delete queues are only deleted once the channel is closed
332-
# and that's not meant to happen for the API.
355+
# In theory, we should cancel the consumer with `mq_queue.cancel()` before deleting the queue.
356+
# In practice, this sometimes leads to an RPC timeout that closes the channel.
357+
# To avoid this situation, we just delete the queue directly.
358+
# Note that even if the queue is in auto-delete mode, it will only be deleted automatically
359+
# once the channel closes. We delete it manually to avoid keeping queues around.
360+
if consumer_tag:
361+
LOGGER.info("Deleting consumer %s (queue: %s)", consumer_tag, mq_queue.name)
362+
await mq_queue.cancel(consumer_tag=consumer_tag)
363+
364+
LOGGER.info("Deleting queue: %s", mq_queue.name)
333365
await mq_queue.delete(if_unused=False, if_empty=False)
334366

335367
return ws

src/aleph/web/controllers/p2p.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,25 @@ async def pub_json(request: web.Request):
142142

143143

144144
async def _mq_read_one_message(
145-
queue: aio_pika.abc.AbstractQueue, timeout: float
145+
mq_queue: aio_pika.abc.AbstractQueue, timeout: float
146146
) -> Optional[aio_pika.abc.AbstractIncomingMessage]:
147147
"""
148-
Believe it or not, this is the only way I found to
149-
:return:
148+
Consume one element from a message queue and then return.
150149
"""
151-
try:
152-
async with queue.iterator(timeout=timeout, no_ack=True) as queue_iter:
153-
async for message in queue_iter:
154-
return message
155150

156-
except asyncio.TimeoutError:
157-
pass
151+
queue: asyncio.Queue = asyncio.Queue()
152+
153+
async def _process_message(message: aio_pika.abc.AbstractMessage):
154+
await queue.put(message)
158155

159-
return None
156+
consumer_tag = await mq_queue.consume(_process_message, no_ack=True)
157+
158+
try:
159+
return await asyncio.wait_for(queue.get(), timeout)
160+
except asyncio.TimeoutError:
161+
return None
162+
finally:
163+
await mq_queue.cancel(consumer_tag)
160164

161165

162166
def _processing_status_to_http_status(status: MessageProcessingStatus) -> int:

src/aleph/web/controllers/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ async def mq_make_aleph_message_topic_queue(
152152
type=aio_pika.ExchangeType.TOPIC,
153153
auto_delete=False,
154154
)
155-
mq_queue = await channel.declare_queue(auto_delete=True)
155+
mq_queue = await channel.declare_queue(
156+
auto_delete=True, exclusive=True,
157+
# Auto-delete the queue after 30 seconds. This guarantees that queues are deleted even
158+
# if a bug makes the consumer crash before cleanup.
159+
arguments={"x-expires": 30000}
160+
)
156161
await mq_queue.bind(mq_message_exchange, routing_key=routing_key)
157162
return mq_queue

0 commit comments

Comments
 (0)