Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 124 additions & 50 deletions faststream/redis/subscriber/usecases/stream_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,51 @@ async def _consume(self, *args: Any, start_signal: "Event") -> None:
start_signal.set()
await super()._consume(*args, start_signal=start_signal)

async def _create_group(self, reset_counter: bool = False) -> None:
    """Create the consumer group for this stream if it does not exist yet.

    When ``reset_counter`` is True the group is created at id ``"0"`` so it
    re-reads the stream from the very beginning; otherwise it starts at the
    subscriber's current position (``"$"`` — only new entries — when the
    position is the special ``">"`` marker).
    """
    if reset_counter:
        start_id = "0"
    else:
        # ">" is a read-time marker, not a valid creation id; map it to "$".
        start_id = "$" if self.last_id == ">" else self.last_id
    try:
        await self._client.xgroup_create(
            name=self.stream_sub.name,
            id=start_id,
            groupname=self.stream_sub.group,
            mkstream=True,
        )
    except ResponseError as e:
        # BUSYGROUP ("... already exists") just means another consumer
        # created the group first — that is fine for us.
        if "already exists" not in str(e):
            raise

def _protect_read_from_group_removal(
    self,
    read_func: Callable[[], Awaitable[ReadResponse]],
    stream: "StreamSub",
) -> Callable[[], Awaitable[ReadResponse]]:
    """Wrap *read_func* so recoverable consumer-group errors trigger one retry.

    Two error situations are handled transparently:

    * ``NOGROUP`` — the group vanished (most likely Redis was flushed):
      recreate the group from the start of the stream and reset the
      subscriber position to ``">"``.
    * position out of range ("smaller than the first available entry" /
      "greater than the maximum id") — the group was modified by a third
      party: reset the position to ``"$"`` (new entries only).

    Any other ``ResponseError`` is re-raised unchanged.
    """

    async def _read_from_group_removal() -> ReadResponse:
        try:
            return await read_func()
        except ResponseError as e:
            err_msg = str(e)
            if "NOGROUP" in err_msg:
                # Most likely Redis was flushed, so we need to reset our group.
                await self._create_group(reset_counter=True)
                # Important: reset our internal position too.
                stream.last_id = ">"
            elif (
                "smaller than the first available entry" in err_msg
                or "greater than the maximum id" in err_msg
            ):
                # Group was modified by a third party: reset our position
                # to an existing id.
                stream.last_id = "$"
            else:
                # Unknown error: bare `raise` keeps the original traceback.
                raise
            # State was repaired above — retry the read exactly once.
            return await read_func()

    return _read_from_group_removal

@override
async def start(self) -> None:
client = self._client
Expand Down Expand Up @@ -112,10 +157,7 @@ async def start(self) -> None:
raise

if stream.min_idle_time is None:

def read(
_: str,
) -> Awaitable[ReadResponse]:
def _xreadgroup_call() -> Awaitable[ReadResponse]:
return client.xreadgroup(
groupname=stream.group,
consumername=stream.consumer,
Expand All @@ -125,17 +167,34 @@ def read(
noack=stream.no_ack,
)

else:
protected_read_func = self._protect_read_from_group_removal(
read_func=_xreadgroup_call,
stream=stream,
)

async def read(_: str) -> ReadResponse:
stream_message = await client.xautoclaim(
async def read(
_: str,
) -> ReadResponse:
return await protected_read_func()
else:
def _xautoclaim_call() -> Awaitable[Any]:
return client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_xautoclaim_call,
stream=stream,
)

async def read(_: str) -> ReadResponse:
stream_message = await protected_autoclaim()

stream_name = self.stream_sub.name.encode()
(next_id, messages, _) = stream_message

Expand All @@ -149,7 +208,6 @@ async def read(_: str) -> ReadResponse:
return ((stream_name, messages),)

else:

def read(
last_id: str,
) -> Awaitable[ReadResponse]:
Expand All @@ -161,6 +219,30 @@ def read(

await super().start(read)

async def _get_one_message(self, timeout: float) -> Optional[ReadResponse]:
    """Fetch at most one raw stream entry, blocking up to *timeout* seconds.

    Uses ``XREADGROUP`` (wrapped with group-removal protection) when a
    consumer group/consumer pair is configured, plain ``XREAD`` otherwise.
    """
    if not (self.stream_sub.group and self.stream_sub.consumer):
        # No consumer group configured: a plain XREAD is enough.
        return await self._client.xread(
            {self.stream_sub.name: self.last_id},
            block=math.ceil(timeout * 1000),  # Redis expects milliseconds
            count=1,
        )

    def _read_via_group() -> Awaitable[ReadResponse]:
        return self._client.xreadgroup(
            groupname=self.stream_sub.group,
            consumername=self.stream_sub.consumer,
            streams={self.stream_sub.name: self.last_id},
            block=math.ceil(timeout * 1000),  # Redis expects milliseconds
            count=1,
        )

    # Go through the protected wrapper so a vanished/modified group is
    # repaired transparently and the read retried once.
    guarded_read = self._protect_read_from_group_removal(
        read_func=_read_via_group,
        stream=self.stream_sub,
    )
    return await guarded_read()

@override
async def get_one(
self,
Expand All @@ -170,34 +252,31 @@ async def get_one(
assert not self.calls, (
"You can't use `get_one` method if subscriber has registered handlers."
)
if self.min_idle_time is None:
if self.stream_sub.group and self.stream_sub.consumer:
stream_message = await self._client.xreadgroup(
if self.min_idle_time is None:# utilise _get_one_message corrigé ci-dessus
stream_message = await self._get_one_message(timeout)
if not stream_message:
return None

((stream_name, ((message_id, raw_message),)),) = stream_message
else:
def _autoclaim_call() -> Awaitable[Any]:
return self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
streams={self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
count=1,
)
else:
stream_message = await self._client.xread(
{self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_autoclaim_call,
stream=self.stream_sub,
)
stream_message = await protected_autoclaim()
if not stream_message:
return None

((stream_name, ((message_id, raw_message),)),) = stream_message
else:
stream_message = await self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)
(next_id, messages, _) = stream_message
# Update start_id for next call
self.autoclaim_start_id = next_id
Expand Down Expand Up @@ -241,33 +320,28 @@ async def __aiter__(self) -> AsyncIterator["RedisStreamMessage"]: # type: ignor

while True:
if self.min_idle_time is None:
if self.stream_sub.group and self.stream_sub.consumer:
stream_message = await self._client.xreadgroup(
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
streams={self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
count=1,
)
else:
stream_message = await self._client.xread(
{self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
count=1,
)
stream_message = await self._get_one_message(timeout)
if not stream_message:
continue

((stream_name, ((message_id, raw_message),)),) = stream_message
else:
stream_message = await self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
def _autoclaim_call() -> Awaitable[Any]:
return self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_autoclaim_call,
stream=self.stream_sub,
)
stream_message = await protected_autoclaim()

(next_id, messages, _) = stream_message
# Update start_id for next call
self.autoclaim_start_id = next_id
Expand Down
46 changes: 46 additions & 0 deletions tests/brokers/redis/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,52 @@ async def handler(msg: RedisStreamMessage) -> None:
assert queue_len == 0, (
f"Redis stream must be empty here, found {queue_len} messages"
)
async def test_consume_from_group(
    self,
    queue: str,
) -> None:
    """Consume via a consumer group and recover after the group is wiped.

    First publish must be delivered through XREADGROUP; after FLUSHALL
    destroys the stream and its group, the subscriber is expected to
    recreate the group (one xgroup_create call) and keep consuming.
    """
    event = asyncio.Event()

    consume_broker = self.get_broker(apply_types=True)

    @consume_broker.subscriber(
        stream=StreamSub(queue, group="group", consumer=queue),
    )
    async def handler(msg: RedisMessage) -> None:
        event.set()

    async with self.patch_broker(consume_broker) as br:
        await br.start()
        redis_client = br._connection
        # Spies are installed AFTER start(), so they only count calls made
        # by the running consumer loop — not the initial group creation.
        with (
            patch.object(redis_client, "xreadgroup", spy_decorator(redis_client.xreadgroup)) as m_readgroup,
            patch.object(redis_client, "xgroup_create", spy_decorator(redis_client.xgroup_create)) as m_group_create
        ):
            # Publish and wait (bounded) until the handler fires.
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("hello", stream=queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=3,
            )
            await asyncio.sleep(0.1)
            # Message must have been delivered through the group read path.
            m_readgroup.mock.assert_called_once()
            assert event.is_set()
            # Simulate a third party wiping Redis: stream AND group are gone.
            await redis_client.flushall()
            event.clear()
            await asyncio.sleep(0.1)
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("hello again", stream=queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=3,
            )

            await asyncio.sleep(0.1)
            # NOGROUP recovery must have recreated the group exactly once.
            m_group_create.mock.assert_called_once()

            assert event.is_set()

async def test_get_one(
self,
Expand Down
15 changes: 8 additions & 7 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.