diff --git a/docs/reference/kombu.transport.rediscluster.rst b/docs/reference/kombu.transport.rediscluster.rst new file mode 100644 index 000000000..e05c0d34a --- /dev/null +++ b/docs/reference/kombu.transport.rediscluster.rst @@ -0,0 +1,24 @@ +======================================================== + Redis Cluster Transport - ``kombu.transport.rediscluster`` +======================================================== + +.. currentmodule:: kombu.transport.rediscluster + +.. automodule:: kombu.transport.rediscluster + + .. contents:: + :local: + + Transport + --------- + + .. autoclass:: Transport + :members: + :undoc-members: + + Channel + ------- + + .. autoclass:: Channel + :members: + :undoc-members: diff --git a/kombu/transport/__init__.py b/kombu/transport/__init__.py index 180a27b4b..1b948aad2 100644 --- a/kombu/transport/__init__.py +++ b/kombu/transport/__init__.py @@ -45,6 +45,7 @@ def supports_librabbitmq() -> bool | None: 'azureservicebus': 'kombu.transport.azureservicebus:Transport', 'pyro': 'kombu.transport.pyro:Transport', 'gcpubsub': 'kombu.transport.gcpubsub:Transport', + 'rediscluster': 'kombu.transport.rediscluster:Transport', } _transport_cache = {} diff --git a/kombu/transport/rediscluster.py b/kombu/transport/rediscluster.py new file mode 100644 index 000000000..70bdb68f4 --- /dev/null +++ b/kombu/transport/rediscluster.py @@ -0,0 +1,679 @@ +"""Redis cluster transport module for Kombu. + +Features +======== +* Type: Virtual +* Supports Direct: Yes +* Supports Topic: Yes +* Supports Fanout: Yes +* Supports Priority: Yes (If hash_tag is set) +* Supports TTL: No + +Connection String +================= +Connection string has the following format: + +.. code-block:: + + rediscluster://[USER:PASSWORD@]REDIS_CLUSTER_ADDRESS[:PORT][/VIRTUALHOST] + +Transport Options +================= +* ``sep`` +* ``ack_emulation``: (bool) If set to True transport will + simulate Acknowledge of AMQP protocol. +* ``unacked_key`` +* ``unacked_index_key`` +* ``unacked_mutex_key`` +* ``unacked_mutex_expire`` +* ``visibility_timeout`` +* ``unacked_restore_limit`` +* ``fanout_prefix`` +* ``fanout_patterns`` +* ``global_keyprefix``: (str) The global key prefix to be prepended to all keys +* ``hash_tag``: (str) Prefix keys with it, + effective at the same time as global_keyprefix.({hash_tag}{global_keyprefix}key) + used by Kombu +* ``socket_timeout`` +* ``socket_connect_timeout`` +* ``socket_keepalive`` +* ``socket_keepalive_options`` +* ``queue_order_strategy`` +* ``max_connections`` +* ``health_check_interval`` +* ``retry_on_timeout`` +* ``priority_steps`` +""" + +from __future__ import annotations + +import functools +from contextlib import contextmanager +from queue import Empty +from time import time + +from redis.exceptions import (AskError, MovedError, RedisClusterException, + TryAgainError) + +from kombu.exceptions import VersionMismatch +from kombu.log import get_logger +from kombu.transport import virtual +from kombu.transport.virtual.base import Channel as VirtualBaseChannel +from kombu.utils import uuid +from kombu.utils.compat import register_after_fork +from kombu.utils.encoding import bytes_to_str +from kombu.utils.eventio import ERR, READ +from kombu.utils.json import dumps, loads +from kombu.utils.objects import cached_property + +from ..utils.scheduling import cycle_by_name +from .redis import Channel as RedisChannel +from .redis import GlobalKeyPrefixMixin as RedisGlobalKeyPrefixMixin +from .redis import MultiChannelPoller as RedisMultiChannelPoller +from .redis import MutexHeld +from .redis import QoS as RedisQoS +from .redis import Transport as RedisTransport +from .redis import _after_fork_cleanup_channel + +try: + import redis +except ImportError: + redis = None + +logger = get_logger(__name__) +crit, warning = logger.critical, logger.warning + + +@contextmanager +def Mutex(client, name, expire): + """Acquire redis lock in non blocking way. Raise MutexHeld if not successful. + + The internal implementation of lock uses uuid as the key, so it cannot be used in cluster mode. Use setnx instead + """ + lock_id = uuid().encode('utf-8') + acquired = client.set(name, lock_id, ex=expire, nx=True) + try: + if acquired: + yield + else: + raise MutexHeld() + finally: + if acquired: + with client.pipeline() as pipe: + try: + pipe.watch(name) + if client.get(name) == lock_id: + pipe.multi() + pipe.delete(name) + pipe.execute() + return + pipe.unwatch() + except redis.exceptions.WatchError: + pass + + +class GlobalKeyPrefixMixin(RedisGlobalKeyPrefixMixin): + """Mixin to provide common logic for global key prefixing. + + copied from redis.cluster.RedisCluster.pipeline + """ + + def pipeline(self, transaction=False, shard_hint=None): + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return PrefixedRedisPipeline( + nodes_manager=self.nodes_manager, + commands_parser=self.commands_parser, + startup_nodes=self.nodes_manager.startup_nodes, + result_callbacks=self.result_callbacks, + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + read_from_replicas=self.read_from_replicas, + reinitialize_steps=self.reinitialize_steps, + lock=self._lock, + global_keyprefix=self.global_keyprefix, + ) + + +class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.RedisCluster): + """Returns a ``RedisCluster`` client that prefixes the keys it uses.""" + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + redis.RedisCluster.__init__(self, *args, **kwargs) + + def pubsub(self, **kwargs): + return PrefixedRedisPubSub( + self, + global_keyprefix=self.global_keyprefix, + **kwargs, + ) + + def keyslot(self, key): + return super().keyslot(f'{self.global_keyprefix}{key}') + + +class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.cluster.ClusterPipeline): + """Custom Redis cluster pipeline that takes global_keyprefix into consideration.""" + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + redis.cluster.ClusterPipeline.__init__(self, *args, **kwargs) + + +class PrefixedRedisPubSub(redis.cluster.ClusterPubSub): + """Redis cluster pubsub client that takes global_keyprefix into consideration.""" + + PUBSUB_COMMANDS = ( + "SUBSCRIBE", + "UNSUBSCRIBE", + "PSUBSCRIBE", + "PUNSUBSCRIBE", + ) + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + super().__init__(*args, **kwargs) + + def _prefix_args(self, args): + args = list(args) + command = args.pop(0) + + if command in self.PUBSUB_COMMANDS: + args = [ + self.global_keyprefix + str(arg) + for arg in args + ] + + return [command, *args] + + def parse_response(self, *args, **kwargs): + ret = super().parse_response(*args, **kwargs) + if ret is None: + return ret + if not isinstance(ret, list): + return ret + + message_type, *channels, message = ret + return [ + message_type, + *[channel[len(self.global_keyprefix):] for channel in channels], + message, + ] + + def execute_command(self, *args, **kwargs): + return super().execute_command(*self._prefix_args(args), **kwargs) + + +class QoS(RedisQoS): + """Redis cluster Ack Emulation. + + Redis doesn't support transaction, if keys are located on different slots/nodes. + We must ensure all keys related to transaction are stored on a single slot. + We can use hash tag to do that. + Then we can take the node holding the slot as a single Redis instance, and run transaction on that node. + + Because node.redis_connection(redis.client.Redis) is not override-able, global_prefix cannot + take effect in transaction. So we need to add prefix manually. + """ + + def restore_visible(self, start=0, num=10, interval=10): + self._vrestore_count += 1 + if (self._vrestore_count - 1) % interval: + return + with self.channel.conn_or_acquire() as client: + ceil = time() - self.visibility_timeout + try: + node = client.nodes_manager.get_node_from_slot(client.keyslot(self.unacked_mutex_key)) + with Mutex(node.redis_connection, self.unacked_mutex_key, + self.unacked_mutex_expire): + visible = client.zrevrangebyscore( + self.unacked_index_key, ceil, 0, + start=num and start, num=num, withscores=True) + for tag, score in visible or []: + self.restore_by_tag(tag, client) + except MutexHeld: + pass + + def restore_by_tag(self, tag, client=None, leftmost=False): + + def restore_transaction(pipe): + p = pipe.hget(self.channel.global_keyprefix + self.unacked_key, tag) + pipe.multi() + self._remove_from_indices(tag, pipe, key_prefix=self.channel.global_keyprefix) + if p: + M, EX, RK = loads(bytes_to_str(p)) + self.channel._do_restore_message(M, EX, RK, pipe, leftmost, key_prefix=self.channel.global_keyprefix) + + with self.channel.conn_or_acquire(client) as client: + if self.channel.hash_tag: + node = client.nodes_manager.get_node_from_slot(client.keyslot(self.unacked_key)) + node.redis_connection.transaction(restore_transaction, + self.channel.global_keyprefix + self.unacked_key) + else: + # Without transactions, problems may occur + p = client.hget(self.unacked_key, tag) + with client.pipeline() as pipe: + self._remove_from_indices(tag, pipe) + if p: + M, EX, RK = loads(bytes_to_str(p)) + self.channel._do_restore_message(M, EX, RK, pipe, leftmost) + pipe.execute() + + def _remove_from_indices(self, delivery_tag, pipe=None, key_prefix=''): + with self.pipe_or_acquire(pipe) as pipe: + return pipe.zrem(key_prefix + self.unacked_index_key, delivery_tag) \ + .hdel(key_prefix + self.unacked_key, delivery_tag) + + +class MultiChannelPoller(RedisMultiChannelPoller): + """Async I/O poller for Redis cluster transport. + + Add _chan_active_queues_to_conn to record queue to redis.Connection mapping + """ + + def __init__(self): + super().__init__() + # channel1 + # |-> redis.connection1(fd1 <-> node1) + # |-> redis.connection2(fd2 <-> node2) + # channel2 + # |-> redis.connection3(fd3 <-> node1) + # |-> redis.connection4(fd4 <-> node3) + + # (channel, queue) -> conn + self._chan_active_queues_to_conn = {} + + def close(self): + for fd in self._chan_to_sock.values(): + try: + self.poller.unregister(fd) + except (KeyError, ValueError): + pass + self._channels.clear() + self._fd_to_chan.clear() + self._chan_to_sock.clear() + self._chan_active_queues_to_conn.clear() + + def _register(self, channel, client, conn, type): + if (channel, client, conn, type) in self._chan_to_sock: + self._unregister(channel, client, conn, type) + if conn._sock is None: + # We closed the connection when exception occurred during `_brpop_read` + conn.connect() + + sock = conn._sock + self._fd_to_chan[sock.fileno()] = (channel, conn, type) + self._chan_to_sock[(channel, client, conn, type)] = sock + self.poller.register(sock, self.eventflags) + + def _unregister(self, channel, client, conn, type): + self.poller.unregister(self._chan_to_sock[(channel, client, conn, type)]) + + def _get_conns_for_channel(self, channel): + conns = set() + for queue in channel.active_queues: + if (channel, queue) not in self._chan_active_queues_to_conn: + slot = channel.client.keyslot(queue) + node = channel.client.nodes_manager.get_node_from_slot(slot, read_from_replicas=False) + # Different queues use different connections + conn = node.redis_connection.connection_pool.get_connection("_") + self._chan_active_queues_to_conn[(channel, queue)] = conn + conns.add(self._chan_active_queues_to_conn[(channel, queue)]) + return conns + + def _register_BRPOP(self, channel): + conns = self._get_conns_for_channel(channel) + + for conn in conns: + ident = (channel, channel.client, conn, 'BRPOP') + if conn._sock is None or ident not in self._chan_to_sock: + channel._in_poll = False + self._register(*ident) + if not channel._in_poll: + channel._brpop_start() + + def _register_LISTEN(self, channel): + conn = channel.subclient.connection + ident = (channel, channel.subclient, conn, 'LISTEN') + if conn._sock is None or ident not in self._chan_to_sock: + channel._in_listen = False + self._register(*ident) + if not channel._in_listen: + channel._subscribe() + + def on_readable(self, fileno): + chan, conn, type = self._fd_to_chan[fileno] + if chan.qos.can_consume(): + try: + chan.handlers[type](**{'conn': conn}) + except MovedError: + # When a key is moved, the connection previously used to access the key + # needs to be replaced with the new connection after the move. + # The connection will be rebuilt in the next loop. + self._unregister_connection(conn, fileno=fileno) + raise Empty() + + def _unregister_connection(self, redis_connection, fileno=None): + if not fileno and redis_connection._sock: + fileno = redis_connection._sock.fileno() + + self._fd_to_chan.pop(fileno, None) + for channel, client, conn, type in list(self._chan_to_sock.keys()): + if conn == redis_connection: + del self._chan_to_sock[(channel, client, conn, type)] + + for channel, queue in list(self._chan_active_queues_to_conn.keys()): + if self._chan_active_queues_to_conn[(channel, queue)] == redis_connection: + del self._chan_active_queues_to_conn[(channel, queue)] + try: + self.poller.unregister(redis_connection._sock) + except (KeyError, ValueError): + pass + + def handle_event(self, fileno, event): + if event & READ: + return self.on_readable(fileno), self + elif event & ERR: + chan, conn, type = self._fd_to_chan[fileno] + chan._poll_error(conn, type) + + +class Channel(RedisChannel): + """Redis Cluster Channel.""" + + QoS = QoS + + _client = None + _in_poll = False + _in_poll_connections = set() + _in_listen = False + + hash_tag = '' + unacked_key = 'unacked' + unacked_index_key = 'unacked_index' + unacked_mutex_key = 'unacked_mutex' + global_keyprefix = '' + + from_transport_options = ( + RedisChannel.from_transport_options + + ('hash_tag',) + ) + + def __init__(self, connection, *args, **kwargs): + VirtualBaseChannel.__init__(self, connection, *args, **kwargs) + if not self.ack_emulation: + self.QoS = virtual.QoS + self._registered = False + self._queue_cycle = cycle_by_name(self.queue_order_strategy)() + self.ResponseError = self._get_response_error() + self.active_fanout_queues = set() + self.auto_delete_queues = set() + self._fanout_to_queue = {} + self.handlers = {'BRPOP': self._brpop_read, 'LISTEN': self._receive} + + if self.fanout_prefix: + if isinstance(self.fanout_prefix, str): + self.keyprefix_fanout = self.fanout_prefix + else: + self.keyprefix_fanout = '' + + self.connection.cycle.add(self) + self._registered = True + self.connection_errors = self.connection.connection_errors + + if register_after_fork is not None: + register_after_fork(self, _after_fork_cleanup_channel) + + if not self.hash_tag: + self.priority_steps = [0] + else: + self.global_keyprefix = f'{self.hash_tag}{self.global_keyprefix}' + + self.Client = self._get_client() + + def _after_fork(self): + self._disconnect_pools() + + def _disconnect_pools(self): + client = self._client + if client is not None: + client.disconnect_connection_pools() + client.close() + self._client = None + self._in_poll_connections.clear() + + def _on_connection_disconnect(self, connection): + if self._in_poll is not None and connection in self._in_poll_connections: + self._in_poll = None + self._in_poll_connections.discard(connection) + if self._in_listen is connection: + self._in_listen = None + if self.connection and self.connection.cycle: + self.connection.cycle._on_connection_disconnect(connection) + + def _restore(self, message, leftmost=False): + if not self.ack_emulation: + return super()._restore(message) + tag = message.delivery_tag + + def restore_transaction(pipe): + P = pipe.hget(self.global_keyprefix + self.unacked_key, tag) + pipe.multi() + pipe.hdel(self.global_keyprefix + self.unacked_key, tag) + if P: + M, EX, RK = loads(bytes_to_str(P)) + self._do_restore_message(M, EX, RK, pipe, leftmost, key_prefix=self.global_keyprefix) + + with self.conn_or_acquire() as client: + if self.hash_tag: + node = client.nodes_manager.get_node_from_slot(client.keyslot(self.unacked_key)) + node.redis_connection.transaction(restore_transaction, self.global_keyprefix + self.unacked_key) + else: + # Without transactions, problems may occur + P = client.hget(self.unacked_key, tag) + with client.pipeline() as pipe: + pipe.hdel(self.unacked_key, tag) + if P: + M, EX, RK = loads(bytes_to_str(P)) + self._do_restore_message(M, EX, RK, pipe, leftmost) + pipe.execute() + + def _brpop_start(self, timeout=1): + queues = self._queue_cycle.consume(len(self.active_queues)) + if not queues: + return + pri_queues = [self._q_for_pri(queue, pri) for pri in self.priority_steps + for queue in queues] + self._in_poll = True + + node_to_keys = {} + for key in pri_queues: + node = self.client.nodes_manager.get_node_from_slot(self.client.keyslot(key)) + node_to_keys.setdefault(f'{node.host}:{node.port}', []).append(key) + + for chan, client, conn, cmd in self.connection.cycle._chan_to_sock: + expected = (self, self.client, 'BRPOP') + keys = node_to_keys.get(f'{conn.host}:{conn.port}') + + if keys and (chan, client, cmd) == expected: + command_args = ['BRPOP', *keys, timeout] + if self.global_keyprefix: + command_args = self.client._prefix_args(command_args) + conn.send_command(*command_args) + self._in_poll_connections.add(conn) + + def _brpop_read(self, **options): + conn = options.pop('conn', None) + try: + try: + dest__item = conn.read_response('BRPOP', **options) + if dest__item: + key, value = dest__item + key = key[len(self.global_keyprefix):] + dest__item = key, value + except self.connection_errors: + if conn is not None: + conn.disconnect() + # Remove the failed node from the startup nodes before we try + # to reinitialize the cluster + target_node = self.client.nodes_manager.startup_nodes.pop(f'{conn.host}:{conn.port}', None) + # Reset the cluster node's connection + target_node.redis_connection = None + self.client.nodes_manager.initialize() + raise + except MovedError: + # poller need to remove conn + self.client.nodes_manager.initialize() + raise + except (TryAgainError, AskError): + raise Empty() + + if dest__item: + dest, item = dest__item + dest = bytes_to_str(dest).rsplit(self.sep, 1)[0] + self._queue_cycle.rotate(dest) + self.connection._deliver(loads(bytes_to_str(item)), dest) + return True + else: + raise Empty() + finally: + self._in_poll_connections.discard(conn) + # To avoid inconsistencies between _in_poll and _in_poll_connections in abnormal situations, + # _in_poll is set to None after any connection being read. + self._in_poll = None + + def _receive(self, **kwargs): + super()._receive() + + def _poll_error(self, conn, type, **options): + if type == 'LISTEN': + self.subclient.parse_response() + else: + conn.read_response(type) + + def close(self): + self._closing = True + if self._in_poll or len(self._in_poll_connections) != 0: + try: + for conn in self._in_poll_connections.copy(): + self._brpop_read(**{'conn': conn}) + except Empty: + pass + if not self.closed: + self.connection.cycle.discard(self) + client = self.__dict__.get('client') + if client is not None: + for queue in self._fanout_queues: + if queue in self.auto_delete_queues: + self.queue_delete(queue, client=client) + self._disconnect_pools() + self._close_clients() + VirtualBaseChannel.close(self) + + def _close_clients(self): + for attr in 'client', 'subclient': + try: + client = self.__dict__[attr] + if attr == 'client': + client.disconnect_connection_pools() + client.close() + if attr == 'subclient': + connection, client.connection = client.connection, None + # The Redis server will automatically detect the disconnection of the client connection + # and clean up all subscription states of the client. + connection.disconnect() + except (KeyError, AttributeError, self.ResponseError): + pass + + def _connparams(self, asynchronous=False): + conn_params = super()._connparams(asynchronous=asynchronous) + # connection_class and db is not supported in redis.client.Redis + # connection_pool_class is only effective when the url parameter is not empty + conn_params.pop('db', None) + conn_params.pop('connection_class', None) + connection_cls = redis.Connection + if self.connection.client.ssl: + connection_cls = redis.SSLConnection + + if asynchronous: + channel = self + + class ManagedConnection(connection_cls): + def disconnect(self, *args): + super().disconnect(*args) + if channel._registered: + channel._on_connection_disconnect(self) + + class ManagedConnectionPool(redis.ConnectionPool): + def __init__(self, *args, **kwargs): + kwargs['connection_class'] = ManagedConnection + super().__init__(*args, **kwargs) + + conn_params['connection_pool_class'] = ManagedConnectionPool + + conn_params['url'] = f'redis://{conn_params["host"]}:{conn_params["port"]}' + return conn_params + + def _create_client(self, asynchronous=False): + if self._client is None: + conn_params = self._connparams(asynchronous=asynchronous) + self._client = self.Client(**conn_params) + return self._client + + def _get_client(self): + if redis.VERSION < (4, 1, 0): + raise VersionMismatch( + 'Redis cluster transport requires redis-py versions 4.1.0 or later. ' + 'You have {0.__version__}'.format(redis)) + + if self.global_keyprefix: + return functools.partial( + PrefixedStrictRedis, + global_keyprefix=self.global_keyprefix, + ) + + return redis.cluster.RedisCluster + + @cached_property + def subclient(self): + client = self._create_client(asynchronous=True) + pubsub_client = client.pubsub() + # Init connection and connection pool in client. + # `ClusterPubSub` initializes the connection and connection pool + # when `execute_command` is called for the first time + pubsub_client.ping() + return pubsub_client + + def _do_restore_message(self, payload, exchange, routing_key, + pipe, leftmost=False, key_prefix=''): + try: + try: + payload['headers']['redelivered'] = True + payload['properties']['delivery_info']['redelivered'] = True + except KeyError: + pass + for queue in self._lookup(exchange, routing_key): + pri = self._get_message_priority(payload, reverse=False) + + (pipe.lpush if leftmost else pipe.rpush)( + key_prefix + self._q_for_pri(queue, pri), dumps(payload), + ) + except Exception: + crit('Could not restore message: %r', payload, exc_info=True) + + +class Transport(RedisTransport): + """Redis Cluster Transport.""" + + Channel = Channel + + driver_type = 'rediscluster' + driver_name = 'rediscluster' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cycle = MultiChannelPoller() diff --git a/t/unit/transport/test_rediscluster.py b/t/unit/transport/test_rediscluster.py new file mode 100644 index 000000000..a6887e28f --- /dev/null +++ b/t/unit/transport/test_rediscluster.py @@ -0,0 +1,1239 @@ +from __future__ import annotations + +import base64 +import copy +import socket +from collections import defaultdict +from itertools import count +from queue import Empty +from queue import Queue as _Queue +from unittest.mock import ANY, Mock, call, patch + +import pytest +from redis.exceptions import MovedError, TryAgainError + +from kombu import Connection, Consumer, Exchange, Producer, Queue +from kombu.exceptions import VersionMismatch +from kombu.transport import redis as _redis +from kombu.transport import rediscluster as redis +from kombu.transport import virtual +from kombu.utils import eventio +from kombu.utils.json import dumps + + +class _poll(eventio._select): + + def register(self, fd, flags): + if flags & eventio.READ: + self._rfd.add(fd) + + def poll(self, timeout): + events = [] + for fd in self._rfd: + if fd.data: + events.append((fd.fileno(), eventio.READ)) + return events + + +_redis.poll = _poll + +pytest.importorskip('redis') + + +class RedisCommandBase: + data = defaultdict(dict) + sets = defaultdict(set) + hashes = defaultdict(dict) + queues = {} + + def __init__(self, *args, **kwargs): + self.connection_pool = RedisConnectionPool() + + def sadd(self, key, value): + if key not in self.data: + self.data[key] = set() + self.data[key].add(value) + + def smembers(self, key): + return self.data.get(key) + + def hset(self, key, k, v): + self.hashes[key][k] = v + + def hget(self, key, k): + return self.hashes[key].get(k) + + def hdel(self, key, k): + self.hashes[key].pop(k, None) + + def ping(self): + return True + + def zrevrangebyscore(self): + return [] + + def llen(self, key): + try: + return self.queues[key].qsize() + except KeyError: + return 0 + + def lpush(self, key, value): + self.queues[key].put_nowait(value) + + def set(self, key, value, **kwargs): + self.data[key] = value + return True + + def get(self, key): + return self.data[key] + + def delete(self, key): + self.data.pop(key, None) + + def rpop(self, key): + try: + return self.queues[key].get_nowait() + except (KeyError, Empty): + pass + + def brpop(self, keys, timeout=None): + for key in keys: + try: + item = self.queues[key].get_nowait() + except Empty: + pass + else: + return key, item + + def zadd(self, key, *args): + (mapping,) = args + for item in mapping: + self.sets[key].add(item) + + def zrem(self, key, *args): + self.sets.pop(key, None) + + def srem(self, key, *args): + self.sets.pop(key, None) + + def pipeline(self, *args, **kwargs): + pass + + def transaction(self, func, *watches, **kwargs): + with self.pipeline() as pipe: + if watches: + pipe.watch(*watches) + func(pipe) + exec_value = pipe.execute() + return exec_value + + +class RedisPipelineBase: + def __init__(self, client): + self.client = client + self.command_stack = [] + + def __enter__(self): + return self + + def __exit__(self, *args) -> None: + pass + + def __getattr__(self, key): + if key not in self.__dict__: + def _add(*args, **kwargs): + self.command_stack.append((getattr(self.client, key), args, kwargs)) + return self + + return _add + return self.__dict__[key] + + def watch(self, key): + pass + + def unwatch(self): + pass + + def multi(self): + pass + + def execute(self): + stack = list(self.command_stack) + self.command_stack[:] = [] + return [fun(*args, **kwargs) for fun, args, kwargs in stack] + + +class RedisConnection: + class _socket: + filenos = count(30) + + def __init__(self, *args): + self._fileno = next(self.filenos) + self.data = [] + + def fileno(self): + return self._fileno + + def __init__(self, host="localhost", port=6379): + self._sock = self._socket() + self.host = host + self.port = port + + def disconnect(self): + pass + + def send_command(self, cmd, *args, **kwargs): + self._sock.data.append((cmd, args)) + + def read_response(self, *args, **kwargs): + try: + cmd, queues = self._sock.data.pop() + except IndexError: + raise Empty() + queues = list(queues) + self._sock.data = [] + if cmd == 'BRPOP': + queues.pop() + item = None + for key in queues: + try: + res = RedisCommandBase.queues[key].get_nowait() + except Empty: + pass + else: + item = key, res + if item: + return item + raise Empty() + + +class RedisConnectionPool: + def __init__(self, *args, **kargs): + self._available_connections = [] + self._in_use_connections = set() + + def get_connection(self, command): + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + return connection + + def make_connection(self): + return RedisConnection() + + +class ClusterPipeline(RedisPipelineBase): + pass + + +class RedisPipeline(RedisPipelineBase): + pass + + +class Redis(RedisCommandBase): + + def pipeline(self): + return RedisPipeline(self) + + +class RedisCluster(RedisCommandBase): + db_0 = {} + + def __init__(self, *args, **kwargs): + self.nodes_manager = NodesManager() + + def disconnect_connection_pools(self): + pass + + def close(self): + pass + + def pipeline(self): + return ClusterPipeline(self) + + def keyslot(self, key): + return 0 + + +DEFAULT_PORT = 6379 +DEFAULT_HOST = 'localhost' + + +class NodesManager: + def __init__(self): + node = ClusterNode() + self.nodes_cache = {0: node} + self.startup_nodes = {f'{node.host}:{node.port}': node} + + def get_node_from_slot(self, slot, **kwargs): + return self.nodes_cache.get(slot) + + def initialize(self): + pass + + +class ClusterNode: + def __init__(self, host="localhost", port=6379): + self.host = host + self.port = port + self.name = f'{self.host}:{self.port}' + self.redis_connection = Redis() + + +class Channel(redis.Channel): + + def _get_client(self): + return RedisCluster + + def _new_queue(self, queue, **kwargs): + for pri in self.priority_steps: + self.client.queues[self._q_for_pri(queue, pri)] = _Queue() + + +class Transport(redis.Transport): + Channel = Channel + connection_errors = (KeyError,) + channel_errors = (IndexError,) + + +class test_Channel: + + def setup_method(self): + self.connection = self.create_connection() + self.channel = self.connection.default_channel + + def create_connection(self, **kwargs): + kwargs.setdefault('transport_options', {'fanout_patterns': True}) + return Connection(transport=Transport, **kwargs) + + def test_init(self): + hash_tag = None + channel = self.create_connection().channel() + assert channel.priority_steps == [0] + assert channel._registered is True + assert channel.global_keyprefix == '' + + hash_tag = '{tag}' + channel = self.create_connection(transport_options={'hash_tag': hash_tag}).channel() + assert channel.priority_steps == [0, 3, 6, 9] + assert channel._registered is True + assert channel.global_keyprefix == '{tag}' + assert channel.unacked_mutex_key == 'unacked_mutex' + assert channel.unacked_index_key == 'unacked_index' + assert channel.unacked_key == 'unacked' + assert channel.keyprefix_fanout == '/{db}.' + assert channel.keyprefix_queue == '_kombu.binding.%s' + + global_keyprefix = 'foo' + channel = self.create_connection(transport_options={'global_keyprefix': global_keyprefix}).channel() + assert channel.priority_steps == [0] + assert channel._registered is True + assert channel.global_keyprefix == global_keyprefix + + def test_after_fork(self): + channel = self.create_connection().channel() + channel._after_fork() + assert channel._client is None + + def test_sep_transport_option(self): + with Connection(transport=Transport, transport_options={ + 'sep': ':', + }) as conn: + key = conn.default_channel.keyprefix_queue % 'celery' + conn.default_channel.client.sadd(key, 'celery::celery') + + assert conn.default_channel.sep == ':' + assert conn.default_channel.get_table('celery') == [ + ('celery', '', 'celery'), + ] + + def test_ack_emulation_transport_option(self): + conn = Connection(transport=Transport, transport_options={ + 'ack_emulation': False, + }) + + chan = conn.channel() + assert not chan.ack_emulation + assert chan.QoS == virtual.QoS + + def test_do_restore_message(self): + client = Mock(name='client') + pl1 = {'body': 'BODY'} + spl1 = dumps(pl1) + lookup = self.channel._lookup = Mock(name='_lookup') + lookup.return_value = {'george', 'elaine'} + self.channel._do_restore_message( + pl1, 'ex', 'rkey', client, + ) + client.rpush.assert_has_calls([ + call('george', spl1), call('elaine', spl1), + ], any_order=True) + + client = Mock(name='client') + pl2 = {'body': 'BODY2', 'headers': {'x-funny': 1}} + headers_after = dict(pl2['headers'], redelivered=True) + spl2 = dumps(dict(pl2, headers=headers_after)) + self.channel._do_restore_message( + pl2, 'ex', 'rkey', client, + ) + client.rpush.assert_any_call('george', spl2) + client.rpush.assert_any_call('elaine', spl2) + + client.rpush.side_effect = KeyError() + with patch('kombu.transport.rediscluster.crit') as crit: + self.channel._do_restore_message( + pl2, 'ex', 'rkey', client, + ) + crit.assert_called() + + def test_do_restore_message_celery(self): + payload = { + "body": base64.b64encode(dumps([ + [], + {}, + { + "callbacks": None, + "errbacks": None, + "chain": None, + "chord": None, + }, + ]).encode()).decode(), + "content-encoding": "utf-8", + "content-type": "application/json", + "headers": { + "lang": "py", + "task": "common.tasks.test_task", + "id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "shadow": None, + "eta": None, + "expires": None, + "group": None, + "group_index": None, + "retries": 0, + "timelimit": [None, None], + "root_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "parent_id": None, + "argsrepr": "()", + "kwargsrepr": "{}", + "origin": "gen3437@Desktop", + "ignore_result": False, + }, + "properties": { + "correlation_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "reply_to": "512f2489-ca40-3585-bc10-9b801a981782", + "delivery_mode": 2, + "delivery_info": { + "exchange": "", + "routing_key": "celery", + }, + "priority": 3, + "body_encoding": "base64", + "delivery_tag": "badb725e-9c3e-45be-b0a4-07e44630519f", + }, + } + result_payload = copy.deepcopy(payload) + result_payload['headers']['redelivered'] = True + result_payload['properties']['delivery_info']['redelivered'] = True + queue = 'celery' + + client = Mock(name='client') + lookup = self.channel._lookup = Mock(name='_lookup') + lookup.return_value = [queue] + + self.channel._do_restore_message( + payload, 'exchange', 'routing_key', client, + ) + + client.rpush.assert_called_with(self.channel._q_for_pri(queue, 3), + dumps(result_payload)) + + def test_restore_messages(self): + message = Mock(name='message') + message.delivery_tag = mock_tag = 'tag' + channel = self.create_connection().channel() + _do_restore_message = channel._do_restore_message = Mock() + channel.client.hset('unacked', mock_tag, message) + with patch('kombu.transport.rediscluster.loads') as loads: + loads.return_value = 'M', 'EX', 'RK' + channel._restore(message) + _do_restore_message.assert_called_with('M', 'EX', 'RK', ANY, False) + + def test_qos_restore_visible(self): + channel = self.create_connection().channel() + client = channel.client + zrevrangebyscore = client.zrevrangebyscore = Mock() + zrevrangebyscore.return_value = [ + (1, 10), + (2, 20), + (3, 30), + ] + qos = redis.QoS(channel) + restore = qos.restore_by_tag = Mock(name='restore_by_tag') + qos._vrestore_count = 1 + qos.restore_visible() + zrevrangebyscore.assert_not_called() + assert qos._vrestore_count == 2 + + qos._vrestore_count = 0 + qos.restore_visible() + restore.assert_has_calls([ + call(1, client), call(2, client), call(3, client), + ]) + assert qos._vrestore_count == 1 + + qos._vrestore_count = 0 + restore.reset_mock() + zrevrangebyscore.return_value = [] + qos.restore_visible() + restore.assert_not_called() + assert qos._vrestore_count == 1 + + qos._vrestore_count = 0 + set = client.set = Mock() + set.side_effect = redis.MutexHeld() + qos.restore_visible() + + def test_restore_by_tag(self): + channel = self.create_connection(transport_options={'hash_tag': '{tag}'}).channel() + qos = redis.QoS(channel) + _do_restore_message = channel._do_restore_message = Mock() + with patch('kombu.transport.rediscluster.loads') as loads: + loads.return_value = 'M', 'EX', 'RK' + qos.restore_by_tag('test', channel.client) + _do_restore_message.assert_called_with('M', 'EX', 'RK', ANY, False, key_prefix='{tag}') + + def test_restore(self): + channel = self.create_connection(transport_options={'hash_tag': '{tag}'}).channel() + message = Mock() + _do_restore_message = channel._do_restore_message = Mock() + with patch('kombu.transport.rediscluster.loads') as loads: + loads.return_value = 'M', 'EX', 'RK' + channel._restore(message) + _do_restore_message.assert_called_with('M', 'EX', 'RK', ANY, False, key_prefix='{tag}') + + def test_basic_consume_when_fanout_queue(self): + self.channel.exchange_declare(exchange='txconfan', type='fanout') + self.channel.queue_declare(queue='txconfanq') + self.channel.queue_bind(queue='txconfanq', exchange='txconfan') + + assert 'txconfanq' in self.channel._fanout_queues + self.channel.basic_consume('txconfanq', False, None, 1) + assert 'txconfanq' in self.channel.active_fanout_queues + assert self.channel._fanout_to_queue.get('txconfan') == 'txconfanq' + + @patch('redis.cluster.RedisCluster.execute_command') + @patch('redis.cluster.NodesManager.initialize') + def test_get_prefixed_client(self, mock_initialize, mock_execute_command): + self.channel.global_keyprefix = "test_" + PrefixedRedis = redis.Channel._get_client(self.channel) + assert isinstance(PrefixedRedis(startup_nodes=[ClusterNode()]), redis.PrefixedStrictRedis) + + @patch("redis.cluster.RedisCluster.keyslot") + @patch("redis.cluster.RedisCluster.execute_command") + @patch("redis.cluster.NodesManager.initialize") + def test_global_keyprefix(self, mock_initialize, mock_execute_command, mock_keyslot): + with Connection(transport=Transport) as conn: + client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()]) + + channel = conn.channel() + channel._create_client = Mock() + channel._create_client.return_value = client + + body = {'hello': 'world'} + channel._put_fanout('exchange', body, '') + mock_execute_command.assert_called_with( + 'PUBLISH', + 'foo_/{db}.exchange', + dumps(body) + ) + + client.keyslot('a') + mock_keyslot.assert_called_with('foo_a') + + @patch("redis.cluster.RedisCluster.execute_command") + @patch("redis.cluster.NodesManager.initialize") + def test_global_keyprefix_queue_bind(self, mock_initialize, mock_execute_command): + with Connection(transport=Transport) as conn: + client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()]) + + channel = conn.channel() + channel._create_client = Mock() + channel._create_client.return_value = client + + channel._queue_bind('default', '', None, 'queue') + mock_execute_command.assert_called_with( + 'SADD', + 'foo__kombu.binding.default', + '\x06\x16\x06\x16queue' + ) + + @patch("redis.cluster.RedisCluster.execute_command") + @patch('redis.cluster.ClusterPubSub.execute_command') + @patch('redis.cluster.NodesManager.initialize') + def test_global_keyprefix_pubsub(self, mock_initialize, mock_pubsub, mock_execute_command): + with Connection(transport=Transport) as conn: + client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()]) + + channel = conn.channel() + channel.global_keyprefix = 'foo_' + channel._create_client = Mock() + channel._create_client.return_value = client + channel.subclient.connection = Mock() + channel._fanout_queues.update(a=('a', '')) + channel.active_fanout_queues.add('a') + + channel._subscribe() + mock_pubsub.assert_called_with( + 'PSUBSCRIBE', + 'foo_/{db}.a', + ) + + @patch("redis.cluster.RedisCluster.execute_command") + @patch('redis.cluster.NodesManager.initialize') + def test_get_client(self, mock_initialize, mock_execute_command): + import redis as R + KombuRedis = redis.Channel._get_client(self.channel) + assert isinstance(KombuRedis(startup_nodes=[ClusterNode()]), R.cluster.RedisCluster) + + Rv = getattr(R, 'VERSION', None) + try: + R.VERSION = (2, 4, 0) + with pytest.raises(VersionMismatch): + redis.Channel._get_client(self.channel) + finally: + if Rv is not None: + R.VERSION = Rv + + @patch("redis.cluster.RedisCluster.execute_command") + @patch('redis.cluster.NodesManager.initialize') + def test_prefixed_pipeline(self, mock_initialize, mock_execute_command): + client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()]) + pipeline = client.pipeline() + send_cluster_commands = pipeline.send_cluster_commands = Mock() + pipeline.set("a", "1") + pipeline.set("b", "2") + pipeline.execute() + assert send_cluster_commands.call_args[0][0][0].args == ('SET', 'foo_a', '1') + assert send_cluster_commands.call_args[0][0][1].args == ('SET', 'foo_b', '2') + + def test_brpop_read_raises(self): + channel = self.create_connection().channel() + conn = RedisConnection() + read_response = conn.read_response = Mock() + initialize = channel.client.nodes_manager.initialize = Mock() + read_response.side_effect = KeyError('foo') + + with pytest.raises(KeyError): + channel._brpop_read(conn=conn) + + initialize.assert_called_with() + assert channel.client.nodes_manager.startup_nodes == {} + + read_response.side_effect = TryAgainError('foo') + + with pytest.raises(Empty): + channel._brpop_read(conn=conn) + + read_response.side_effect = MovedError('1 0.0.0.0:0') + + initialize.reset_mock() + with pytest.raises(MovedError): + channel._brpop_read(conn=conn) + initialize.assert_called_with() + + def test_brpop_read_gives_None(self): + conn = RedisConnection() + read_response = conn.read_response = Mock() + read_response.return_value = None + + with pytest.raises(redis.Empty): + self.channel._brpop_read(conn=conn) + + def test_poll_error(self): + channel = self.create_connection().channel() + conn = RedisConnection() + with pytest.raises(Empty): + channel._poll_error(conn, 'BRPOP') + + conn = RedisConnection() + conn._sock.data = [('BRPOP', ('test_Redis',))] + with pytest.raises(Empty): + channel._poll_error(conn, 'BRPOP') + assert conn._sock.data == [] + + def test_redis_on_disconnect_channel_only_if_was_registered(self): + """Test should check if the _on_disconnect method is called only + if the channel was registered into the poller.""" + # given: mock pool and client + pool = Mock(name='pool') + client = Mock( + name='client', + ping=Mock(return_value=True) + ) + + # create RedisConnectionMock class + # for the possibility to run disconnect method + class RedisConnectionMock: + def disconnect(self, *args): + pass + + # override Channel method with given mocks + class XChannel(Channel): + connection_class = RedisConnectionMock + + def __init__(self, *args, **kwargs): + self._pool = pool + # counter to check if the method was called + self.on_disconnect_count = 0 + super().__init__(*args, **kwargs) + + def _get_client(self): + return lambda *_, **__: client + + def _on_connection_disconnect(self, connection): + # increment the counter when the method is called + self.on_disconnect_count += 1 + + # create the channel + chan = XChannel(Mock( + _used_channel_ids=[], + channel_max=1, + channels=[], + client=Mock( + transport_options={}, + hostname="127.0.0.1", + virtual_host=None))) + # create the _connparams with overridden connection_class + connparams = chan._connparams(asynchronous=True) + # create redis.Connection + assert connparams['connection_pool_class'].__name__ == 'ManagedConnectionPool' + redis_connection_pool = connparams['connection_pool_class']() + with patch('redis.connection.AbstractConnection.connect'): + redis_connection = redis_connection_pool.get_connection('-') + # the connection was added to the cycle + chan.connection.cycle.add.assert_called_once() + # the channel was registered + assert chan._registered + # than disconnect the Redis connection + redis_connection.disconnect() + # the on_disconnect counter should be incremented + assert chan.on_disconnect_count == 1 + + +class test_Redis: + + def setup_method(self): + self.connection = Connection(transport=Transport) + self.exchange = Exchange('test_Redis', type='direct') + self.queue = Queue('test_Redis', self.exchange, 'test_Redis') + + def teardown_method(self): + self.connection.close() + + def test_publish_get(self): + channel = self.connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + self.queue(channel).declare() + + producer.publish({'hello': 'world'}) + + assert self.queue(channel).get().payload == {'hello': 'world'} + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + + def test_publish_consume(self): + redis.poll = _poll + connection = Connection(transport=Transport) + channel = connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + consumer = Consumer(channel, queues=[self.queue]) + + producer.publish({'hello2': 'world2'}) + _received = [] + + def callback(message_data, message): + _received.append(message_data) + message.ack() + + consumer.register_callback(callback) + consumer.consume() + + assert channel in channel.connection.cycle._channels + try: + connection.drain_events(timeout=1) + assert _received + with pytest.raises(socket.timeout): + connection.drain_events(timeout=0.01) + finally: + channel.close() + + def test_purge(self): + channel = self.connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + self.queue(channel).declare() + + for i in range(10): + producer.publish({'hello': f'world-{i}'}) + + assert channel._size('test_Redis') == 10 + assert self.queue(channel).purge() == 10 + channel.close() + + def test_db_values(self): + Connection(virtual_host=1, + transport=Transport).channel() + + Connection(virtual_host='1', + transport=Transport).channel() + + Connection(virtual_host='/1', + transport=Transport).channel() + + with pytest.raises(Exception): + Connection('redis:///foo').channel() + + def test_db_port(self): + c1 = Connection(port=None, transport=Transport).channel() + c1.close() + + c2 = Connection(port=9999, transport=Transport).channel() + c2.close() + + def test_close_poller_not_active(self): + c = Connection(transport=Transport).channel() + cycle = c.connection.cycle + c.close() + assert c not in cycle._channels + + def test_close_ResponseError(self): + c = Connection(transport=Transport).channel() + c.client.bgsave_raises_ResponseError = True + c.close() + + def test_close_in_poll(self): + c = Connection(transport=Transport).channel() + conn = RedisConnection() + conn._sock.data = [('BRPOP', ('test_Redis',))] + c._in_poll_connections.add(conn) + c._in_poll = True + c.close() + assert conn._sock.data == [] + + def test_get__Empty(self): + channel = self.connection.channel() + with pytest.raises(Empty): + channel._get('does-not-exist') + channel.close() + + +class test_MultiChannelPoller: + + def setup_method(self): + self.Poller = redis.MultiChannelPoller + self.connection = Connection(transport=redis.Transport) + + def test_init(self): + p = self.Poller() + assert p._chan_active_queues_to_conn == {} + + def test_on_poll_start(self): + p = self.Poller() + p._channels = [] + p.on_poll_start() + p._register_BRPOP = Mock(name='_register_BRPOP') + p._register_LISTEN = Mock(name='_register_LISTEN') + + chan1 = Mock(name='chan1') + p._channels = [chan1] + chan1.active_queues = [] + chan1.active_fanout_queues = [] + p.on_poll_start() + + chan1.active_queues = ['q1'] + chan1.active_fanout_queues = ['q2'] + chan1.qos.can_consume.return_value = False + + p.on_poll_start() + p._register_LISTEN.assert_called_with(chan1) + p._register_BRPOP.assert_not_called() + + chan1.qos.can_consume.return_value = True + p._register_LISTEN.reset_mock() + p.on_poll_start() + + p._register_BRPOP.assert_called_with(chan1) + p._register_LISTEN.assert_called_with(chan1) + + def test_on_poll_init(self): + p = self.Poller() + chan1 = Mock(name='chan1') + p._channels = [] + poller = Mock(name='poller') + p.on_poll_init(poller) + assert p.poller is poller + + p._channels = [chan1] + p.on_poll_init(poller) + chan1.qos.restore_visible.assert_called_with( + num=chan1.unacked_restore_limit, + ) + + def test_handle_event(self): + p = self.Poller() + chan = Mock(name='chan') + conn = Mock(name='conn') + p._fd_to_chan[13] = chan, conn, 'BRPOP' + chan.handlers = {'BRPOP': Mock(name='BRPOP')} + + chan.qos.can_consume.return_value = False + p.handle_event(13, redis.READ) + chan.handlers['BRPOP'].assert_not_called() + + chan.qos.can_consume.return_value = True + p.handle_event(13, redis.READ) + chan.handlers['BRPOP'].assert_called_with(conn=conn) + + p.handle_event(13, redis.ERR) + chan._poll_error.assert_called_with(conn, 'BRPOP') + + p.handle_event(13, ~(redis.READ | redis.ERR)) + + def test_fds(self): + p = self.Poller() + p._fd_to_chan = {1: 2} + assert p.fds == p._fd_to_chan + + def test_close_unregisters_fds(self): + p = self.Poller() + poller = p.poller = Mock() + p._chan_to_sock.update({1: 1, 2: 2, 3: 3}) + + p.close() + + assert poller.unregister.call_count == 3 + u_args = poller.unregister.call_args_list + + assert sorted(u_args) == [ + ((1,), {}), + ((2,), {}), + ((3,), {}), + ] + + def test_close_when_unregister_raises_KeyError(self): + p = self.Poller() + p.poller = Mock() + p._chan_to_sock.update({1: 1}) + p.poller.unregister.side_effect = KeyError(1) + p.close() + + def test_close_resets_state(self): + p = self.Poller() + p.poller = Mock() + p._channels = Mock() + p._fd_to_chan = Mock() + p._chan_to_sock = Mock() + p._chan_active_queues_to_conn = Mock() + + p._chan_to_sock.itervalues.return_value = [] + p._chan_to_sock.values.return_value = [] + + p.close() + p._channels.clear.assert_called_with() + p._fd_to_chan.clear.assert_called_with() + p._chan_to_sock.clear.assert_called_with() + p._chan_active_queues_to_conn.clear.assert_called_with() + + def test_register_when_registered_reregisters(self): + p = self.Poller() + p.poller = Mock() + channel, client, conn, type = Mock(), Mock(), Mock(), Mock() + sock = conn._sock = Mock() + sock.fileno.return_value = 10 + + p._chan_to_sock = {(channel, client, conn, type): 6} + p._register(channel, client, conn, type) + p.poller.unregister.assert_called_with(6) + assert p._fd_to_chan[10] == (channel, conn, type) + assert p._chan_to_sock[(channel, client, conn, type)] == sock + p.poller.register.assert_called_with(sock, p.eventflags) + + conn._sock = None + + def after_connected(): + conn._sock = Mock() + + conn.connect.side_effect = after_connected + p._register(channel, client, conn, type) + conn.connect.assert_called_with() + + def test_get_conns_for_channel(self): + p = self.Poller() + channel = Mock() + channel.active_queues = ['queue'] + p._chan_active_queues_to_conn = {} + conns = p._get_conns_for_channel(channel) + assert p._chan_active_queues_to_conn[(channel, 'queue')] == conns.pop() + + def test_register_BRPOP(self): + p = self.Poller() + conn = Mock() + conn._sock = None + get_conns_for_channel = p._get_conns_for_channel = Mock() + get_conns_for_channel.return_value = [conn] + + channel = Mock() + channel.active_queues = [] + p._register = Mock() + + channel._in_poll = False + p._register_BRPOP(channel) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 + + conn._sock = Mock() + p._chan_to_sock[(channel, channel.client, conn, 'BRPOP')] = True + channel._in_poll = True + p._register_BRPOP(channel) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 + + def test_register_LISTEN(self): + p = self.Poller() + conn = Mock() + conn._sock = None + get_conns_for_channel = p.get_conns_for_channel = Mock() + get_conns_for_channel.return_value = [conn] + + channel = Mock() + conn._sock = None + channel._in_listen = False + p._register = Mock() + + p._register_LISTEN(channel) + p._register.assert_called_with(channel, channel.subclient, channel.subclient.connection, 'LISTEN') + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 + + channel._in_listen = True + p._chan_to_sock[(channel, channel.subclient, channel.subclient.connection, 'LISTEN')] = 3 + channel.subclient.connection._sock = Mock() + p._register_LISTEN(channel) + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 + + def test_on_readable(self): + p = self.Poller() + channel, conn, conn2, _brpop_read, _receive = Mock(), Mock(), Mock(), Mock(), Mock() + channel.handlers = {'BRPOP': _brpop_read, 'LISTEN': _receive} + p._fd_to_chan = {0: (channel, conn, 'BRPOP'), 1: (channel, conn2, 'BRPOP')} + p._chan_to_sock = {(channel, channel.client, conn, 'BRPOP'): 0} + + p.on_readable(0) + _brpop_read.assert_called_with(conn=conn) + + _brpop_read.side_effect = MovedError('1 0.0.0.0:0') + conn._sock.fileno.return_value = 0 + conn2._sock.fileno.return_value = 1 + with pytest.raises(Empty): + p.on_readable(0) + assert p._fd_to_chan == {1: (channel, conn2, 'BRPOP')} + assert p._chan_to_sock == {} + + p._fd_to_chan = {0: (channel, conn, 'LISTEN')} + p.on_readable(0) + _receive.assert_called_with(conn=conn) + + def test_on_readable_when_moved(self): + p = self.Poller() + channel, conn, _brpop_read, _receive = Mock(), Mock(), Mock(), Mock() + channel.handlers = {'BRPOP': _brpop_read, 'LISTEN': _receive} + sock = conn._sock = Mock() + sock.fileno.return_value = 0 + _get_conns_for_channel = p._get_conns_for_channel = Mock() + _get_conns_for_channel.return_value = [conn] + + p._register_BRPOP(channel) + assert p._fd_to_chan == {0: (channel, conn, 'BRPOP')} + assert p._chan_to_sock == {(channel, channel.client, conn, 'BRPOP'): sock} + + _brpop_read.side_effect = MovedError('1 0.0.0.0:0') + poller_unregister = p.poller.unregister = Mock() + with pytest.raises(Empty): + p.on_readable(0) + assert p._fd_to_chan == {} + assert p._chan_to_sock == {} + poller_unregister.assert_called_with(sock) + + def create_get(self, events=None, queues=None, fanouts=None): + _pr = [] if events is None else events + _aq = [] if queues is None else queues + _af = [] if fanouts is None else fanouts + p = self.Poller() + p.poller = Mock() + p.poller.poll.return_value = _pr + + p._register_BRPOP = Mock() + p._register_LISTEN = Mock() + + channel = Mock() + p._channels = [channel] + channel.active_queues = _aq + channel.active_fanout_queues = _af + + return p, channel + + def test_get_no_actions(self): + p, channel = self.create_get() + + with pytest.raises(redis.Empty): + p.get(Mock()) + + def test_qos_reject(self): + p, channel = self.create_get() + qos = redis.QoS(channel) + qos._remove_from_indices = Mock(name='_remove_from_indices') + qos.reject(1234) + qos._remove_from_indices.assert_called_with(1234) + + def test_qos_requeue(self): + p, channel = self.create_get() + qos = redis.QoS(channel) + qos.restore_by_tag = Mock(name='restore_by_tag') + qos.reject(1234, True) + qos.restore_by_tag.assert_called_with(1234, leftmost=True) + + def test_get_brpop_qos_allow(self): + p, channel = self.create_get(queues=['a_queue']) + channel.qos.can_consume.return_value = True + + with pytest.raises(redis.Empty): + p.get(Mock()) + + p._register_BRPOP.assert_called_with(channel) + + def test_get_brpop_qos_disallow(self): + p, channel = self.create_get(queues=['a_queue']) + channel.qos.can_consume.return_value = False + + with pytest.raises(redis.Empty): + p.get(Mock()) + + p._register_BRPOP.assert_not_called() + + def test_get_listen(self): + p, channel = self.create_get(fanouts=['f_queue']) + + with pytest.raises(redis.Empty): + p.get(Mock()) + + p._register_LISTEN.assert_called_with(channel) + + def test_get_receives_ERR(self): + p, channel = self.create_get(events=[(1, eventio.ERR)]) + conn = Mock() + p._fd_to_chan[1] = (channel, conn, 'BRPOP') + + with pytest.raises(redis.Empty): + p.get(Mock()) + + channel._poll_error.assert_called_with(conn, 'BRPOP') + + def test_get_receives_multiple(self): + p, channel = self.create_get(events=[(1, eventio.ERR), + (1, eventio.ERR)]) + conn = Mock() + p._fd_to_chan[1] = (channel, conn, 'BRPOP') + + with pytest.raises(redis.Empty): + p.get(Mock()) + + channel._poll_error.assert_called_with(conn, 'BRPOP') + + +class test_Mutex: + + def test_mutex(self): + client = Redis() + set = client.set = Mock() + + # Won + set.return_value = True + held = False + with redis.Mutex(client, 'foo1', 100): + held = True + assert held + + # Did not win + set.return_value = False + held = False + with pytest.raises(redis.MutexHeld): + with redis.Mutex(client, 'foo1', 100): + held = True + assert not held + + +class test_GlobalKeyPrefixMixin: + global_keyprefix = "prefix_" + hash_tag = "{tag}" + mixin = redis.GlobalKeyPrefixMixin() + mixin.global_keyprefix = global_keyprefix + mixin.hash_tag = hash_tag + + def test_prefix_simple_args(self): + for command in self.mixin.PREFIXED_SIMPLE_COMMANDS: + prefixed_args = self.mixin._prefix_args([command, "fake_key"]) + assert prefixed_args == [ + command, + f"{self.global_keyprefix}fake_key" + ] + + def test_prefix_delete_args(self): + prefixed_args = self.mixin._prefix_args([ + "DEL", + "fake_key", + "fake_key2", + "fake_key3" + ]) + + assert prefixed_args == [ + "DEL", + f"{self.global_keyprefix}fake_key", + f"{self.global_keyprefix}fake_key2", + f"{self.global_keyprefix}fake_key3", + ] + + def test_prefix_brpop_args(self): + prefixed_args = self.mixin._prefix_args([ + "BRPOP", + "fake_key", + "fake_key2", + "not_prefixed" + ]) + + assert prefixed_args == [ + "BRPOP", + f"{self.global_keyprefix}fake_key", + f"{self.global_keyprefix}fake_key2", + "not_prefixed", + ] + + def test_prefix_evalsha_args(self): + prefixed_args = self.mixin._prefix_args([ + "EVALSHA", + "not_prefixed", + "not_prefixed", + "fake_key", + "not_prefixed", + ]) + + assert prefixed_args == [ + "EVALSHA", + "not_prefixed", + "not_prefixed", + f"{self.global_keyprefix}fake_key", + "not_prefixed", + ]