@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from typing import List, Optional, Any, Dict, Iterable
 
@@ -234,102 +233,135 @@ async def view_messages_list(request: web.Request) -> web.Response:
     )
 
 
-async def _message_ws_read_from_queue(
+async def _send_history_to_ws(
     ws: aiohttp.web_ws.WebSocketResponse,
-    mq_queue: aio_pika.abc.AbstractQueue,
-    request: web.Request,
+    session_factory: DbSessionFactory,
+    history: int,
+    message_filters: Dict[str, Any],
 ) -> None:
+    """Send the `history` latest messages matching the filters to the websocket."""
+    with session_factory() as session:
+        messages = get_matching_messages(
+            session=session,
+            pagination=history,
+            include_confirmations=True,
+            **message_filters,
+        )
+        for message in messages:
+            await ws.send_str(format_message(message).json())
+
+
+async def _start_mq_consumer(
+    ws: aiohttp.web_ws.WebSocketResponse,
+    mq_queue: aio_pika.abc.AbstractQueue,
+    session_factory: DbSessionFactory,
+    message_filters: Dict[str, Any],
+) -> aio_pika.abc.ConsumerTag:
     """
-    Task receiving new aleph.im messages from the processing pipeline to a websocket.
+    Starts the consumer task responsible for forwarding new aleph.im messages from
+    the processing pipeline to a websocket.
 
     :param ws: Websocket.
     :param mq_queue: Message queue object.
-    :param request: Websocket HTTP request object.
+    :param session_factory: DB session factory.
+    :param message_filters: Filters to apply to select outgoing messages.
     """
 
-    query_params = WsMessageQueryParams.parse_obj(request.query)
-    session_factory = get_session_factory_from_request(request)
-
-    find_filters = query_params.dict(exclude_none=True)
-    history = query_params.history
-
-    if history:
+    async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
+        item_hash = aleph_json.loads(mq_message.body)["item_hash"]
+        # A bastardized way to apply the filters on the message as well.
+        # TODO: put the filter key/values in the RabbitMQ message?
         with session_factory() as session:
-            messages = get_matching_messages(
+            matching_messages = get_matching_messages(
                 session=session,
-                pagination=history,
+                hashes=[item_hash],
                 include_confirmations=True,
-                **find_filters,
+                **message_filters,
             )
-            for message in messages:
-                await ws.send_str(format_message(message).json())
-
-    try:
-        async with mq_queue.iterator(no_ack=True) as queue_iter:
-            async for mq_message in queue_iter:
-                item_hash = aleph_json.loads(mq_message.body)["item_hash"]
-                # A bastardized way to apply the filters on the message as well.
-                # TODO: put the filter key/values in the RabbitMQ message?
-                with session_factory() as session:
-                    matching_messages = get_matching_messages(
-                        session=session,
-                        hashes=[item_hash],
-                        include_confirmations=True,
-                        **find_filters,
-                    )
-                    for message in matching_messages:
-                        await ws.send_str(format_message(message).json())
-
-    except ConnectionResetError:
-        # We can detect the WS closing in this task in addition to the main one.
-        # warning. The main task will also detect the close event.
-        # We ignore this exception to avoid the "task exception was never retrieved"
-        LOGGER.info("Cannot send messages because the websocket is closed")
-        pass
-
-    except asyncio.CancelledError:
-        LOGGER.info("MQ -> WS task cancelled")
-        raise
+            try:
+                for message in matching_messages:
+                    await ws.send_str(format_message(message).json())
+            except ConnectionResetError:
+                # The websocket closing is detected both in this callback and in
+                # the main task, and the main task handles the close event. We
+                # just ignore the exception here to avoid the "task exception was
+                # never retrieved" warning.
+                LOGGER.info("Cannot send messages because the websocket is closed")
+
+    # Note that we use the consume pattern here instead of the `queue.iterator()`
+    # pattern because cancelling the iterator attempts to close the queue and channel.
+    # See the discussion here: https://github.com/mosquito/aio-pika/issues/358
+    consumer_tag = await mq_queue.consume(_process_message, no_ack=True)
+    return consumer_tag
 
 
 async def messages_ws(request: web.Request) -> web.WebSocketResponse:
     ws = web.WebSocketResponse()
     await ws.prepare(request)
 
     config = get_config_from_request(request)
+    session_factory = get_session_factory_from_request(request)
     mq_channel = get_mq_channel_from_request(request)
 
+    try:
+        query_params = WsMessageQueryParams.parse_obj(request.query)
+    except ValidationError as e:
+        raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
+    message_filters = query_params.dict(exclude_none=True)
+    history = query_params.history
+
+    if history:
+        try:
+            await _send_history_to_ws(
+                ws=ws,
+                session_factory=session_factory,
+                history=history,
+                message_filters=message_filters,
+            )
+        except ConnectionResetError:
+            LOGGER.info("Could not send history, aborting message websocket")
+            return ws
+
     mq_queue = await mq_make_aleph_message_topic_queue(
         channel=mq_channel, config=config, routing_key="processed.*"
     )
+    consumer_tag = None
 
-    # Start a task to handle outgoing traffic to the websocket.
-    queue_task = asyncio.create_task(
-        _message_ws_read_from_queue(
+    try:
+        # Start a task to handle outgoing traffic to the websocket.
+        consumer_tag = await _start_mq_consumer(
             ws=ws,
-            request=request,
             mq_queue=mq_queue,
+            session_factory=session_factory,
+            message_filters=message_filters,
+        )
+        LOGGER.debug(
+            "Started consuming mq %s for websocket. Consumer tag: %s",
+            mq_queue.name,
+            consumer_tag,
         )
-    )
 
-    # Wait for the websocket to close.
-    try:
+        # Wait for the websocket to close.
         while not ws.closed:
             # Users can potentially send anything to the websocket. Ignore these messages
             # and only handle "close" messages.
             ws_msg = await ws.receive()
-            LOGGER.info("rx ws msg: %s", str(ws_msg))
+            LOGGER.debug("rx ws msg: %s", str(ws_msg))
            if ws_msg.type == WSMsgType.CLOSE:
                 LOGGER.debug("ws close received")
                 break
 
     finally:
-        # Cancel the MQ -> ws task
-        queue_task.cancel()
-        await asyncio.wait([queue_task])
-
-        # Always delete the queue, auto-delete queues are only deleted once the channel is closed
-        # and that's not meant to happen for the API.
+        # Note that even if the queue is in auto-delete mode, it is only deleted
+        # automatically once the channel closes, and the channel is not meant to
+        # close for the API.
+        # We cancel the consumer (if it was started) and then delete the queue
+        # explicitly to avoid keeping unused queues around.
+        if consumer_tag:
+            LOGGER.info("Deleting consumer %s (queue: %s)", consumer_tag, mq_queue.name)
+            await mq_queue.cancel(consumer_tag=consumer_tag)
+
+        LOGGER.info("Deleting queue: %s", mq_queue.name)
         await mq_queue.delete(if_unused=False, if_empty=False)
 
     return ws
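
Note on the consume pattern: the comment in `_start_mq_consumer` points at aio-pika issue #358, where cancelling a task iterating over `queue.iterator()` ends up closing the queue and channel. Below is a minimal, self-contained sketch of the `consume()`/`cancel()` pattern this diff adopts; the broker URL, queue name, and aio-pika 8.x-style calls are illustrative assumptions, not part of the diff.

```python
import asyncio

import aio_pika
import aio_pika.abc


async def main() -> None:
    # Illustrative broker URL and queue name (assumptions, not from the diff).
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        queue = await channel.declare_queue("demo-queue", auto_delete=True)

        async def on_message(message: aio_pika.abc.AbstractIncomingMessage) -> None:
            # With no_ack=True the broker does not wait for an acknowledgement.
            print("received:", message.body)

        # Register a callback and keep the consumer tag for later cancellation.
        consumer_tag = await queue.consume(on_message, no_ack=True)
        await asyncio.sleep(5)

        # Cancelling by tag stops deliveries without tearing down the channel,
        # unlike cancelling a `queue.iterator()` task.
        await queue.cancel(consumer_tag)
        await queue.delete(if_unused=False, if_empty=False)


asyncio.run(main())
```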
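The query-string handling in `messages_ws` is the usual pydantic v1 flow: `parse_obj` on the raw query dict, with the `ValidationError` JSON payload returned as the 422 body. A sketch with a hypothetical stand-in model (the real `WsMessageQueryParams` fields are not shown in this diff):

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class WsParams(BaseModel):
    # Hypothetical field set; only `history` is confirmed by the diff.
    history: Optional[int] = None


try:
    WsParams.parse_obj({"history": "not-a-number"})
except ValidationError as e:
    print(e.json(indent=4))  # what the HTTP 422 response body would contain

params = WsParams.parse_obj({"history": "25"})  # pydantic v1 coerces "25" -> 25
print(params.dict(exclude_none=True))  # {'history': 25}
```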
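From the client side, the handler replays up to `history` matching messages and then streams newly processed ones until either end closes the socket. A usage sketch; the node URL and the `/api/ws0/messages` path are assumptions, since the route registration is not part of this diff:

```python
import asyncio

import aiohttp


async def watch_messages() -> None:
    # Assumed node URL and endpoint path; adjust for a real deployment.
    url = "https://api2.aleph.im/api/ws0/messages?history=10"
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    print(msg.data)  # one JSON-encoded aleph.im message per frame
                elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                    break


asyncio.run(watch_messages())
```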