optimize stream write
RB387 committed Dec 15, 2024
1 parent e8ea12c commit d712e3d
Showing 3 changed files with 55 additions and 34 deletions.
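
The change drops the pre-registered Lua Script objects (publish_to_stream_script on ArqRedis, _get_job_from_stream_script on every Job) and instead queues the scripts as plain EVAL commands on whichever pipeline is already in use. The script then travels in the same buffered round trip as the surrounding commands, Job instances no longer register a script apiece, and _unclaim_job can write into its caller's open pipeline instead of creating and executing one of its own; that appears to be the "optimize stream write" of the title. Below is a minimal standalone sketch of the two call styles with redis-py's asyncio client; the Lua body and key names are illustrative, not arq's.

import asyncio

from redis.asyncio import Redis

LUA = "return redis.call('SET', KEYS[1], ARGV[1])"  # stand-in script, not arq's


async def main() -> None:
    redis = Redis()  # assumes a local Redis on the default port

    # Before: a pre-registered Script object, awaited on each call
    # (arq passed its open pipeline as client=; here it is called directly).
    script = redis.register_script(LUA)
    await script(keys=['k1'], args=['v1'])

    # After: a plain EVAL queued on a pipeline like any other buffered command.
    async with redis.pipeline(transaction=True) as pipe:
        pipe.eval(LUA, 1, 'k2', 'v2')  # script, numkeys, then keys and args
        await pipe.execute()

    await redis.aclose()


asyncio.run(main())
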
15 changes: 10 additions & 5 deletions arq/connections.py
@@ -123,7 +123,6 @@ def __init__(
             kwargs['connection_pool'] = pool_or_conn
         self.expires_extra_ms = expires_extra_ms
         super().__init__(**kwargs)
-        self.publish_to_stream_script = self.register_script(publish_job_lua)

     async def enqueue_job(
         self,
@@ -194,10 +193,16 @@ async def enqueue_job(
             else:
                 stream_key = _queue_name + stream_key_suffix
                 job_message_id_key = job_message_id_prefix + job_id
-                await self.publish_to_stream_script(
-                    keys=[stream_key, job_message_id_key],
-                    args=[job_id, str(enqueue_time_ms), str(expires_ms)],
-                    client=pipe,
+                pipe.eval(
+                    publish_job_lua,
+                    2,
+                    # keys
+                    stream_key,
+                    job_message_id_key,
+                    # args
+                    job_id,
+                    str(enqueue_time_ms),
+                    str(expires_ms),
                 )

             try:
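
With this hunk, the stream publish in enqueue_job becomes one more command buffered on the transaction pipeline and is executed together with the rest of the enqueue commands, rather than through a separately awaited Script call.
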
11 changes: 6 additions & 5 deletions arq/jobs.py
@@ -82,7 +82,7 @@ class Job:
     Holds data a reference to a job.
     """

-    __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer', '_get_job_from_stream_script'
+    __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer'

     def __init__(
         self,
@@ -95,7 +95,6 @@ def __init__(
         self._redis = redis
         self._queue_name = _queue_name
         self._deserializer = _deserializer
-        self._get_job_from_stream_script = redis.register_script(get_job_from_stream_lua)

     async def result(
         self, timeout: Optional[float] = None, *, poll_delay: float = 0.5, pole_delay: Optional[float] = None
@@ -152,9 +151,11 @@ async def info(self) -> Optional[JobDef]:
         if info:
             async with self._redis.pipeline(transaction=True) as tr:
                 tr.zscore(self._queue_name, self.job_id)
-                await self._get_job_from_stream_script(
-                    keys=[self._queue_name + stream_key_suffix, job_message_id_prefix + self.job_id],
-                    client=tr,
+                tr.eval(
+                    get_job_from_stream_lua,
+                    2,
+                    self._queue_name + stream_key_suffix,
+                    job_message_id_prefix + self.job_id,
                 )
                 delayed_score, job_info = await tr.execute()

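Since the script is now inlined at the call site, Job no longer registers get_job_from_stream_lua per instance, which is why the __slots__ entry disappears; info() still reads the delayed score and the stream message data in one transaction, with the EVAL queued on tr next to zscore before tr.execute().
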
63 changes: 39 additions & 24 deletions arq/worker.py
@@ -12,6 +12,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import uuid4

+from redis.asyncio.client import Pipeline
 from redis.exceptions import ResponseError, WatchError

 from arq.cron import CronJob
@@ -33,7 +34,7 @@
     retry_key_prefix,
     stream_key_suffix,
 )
-from .lua import publish_delayed_job_lua
+from .lua import publish_delayed_job_lua, publish_job_lua
 from .utils import (
     args_to_string,
     import_string,
@@ -592,7 +593,9 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None:
                 ongoing_exists = await pipe.exists(in_progress_key)

                 if ongoing_exists:
-                    await self._unclaim_job(job)
+                    await pipe.unwatch()
+                    await self._unclaim_job(job, pipe)
+                    await pipe.execute()
                     self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('job %s already running elsewhere', job_id)
@@ -604,6 +607,9 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None:
                     await pipe.execute()
                 except (ResponseError, WatchError):
                     # job already started elsewhere since we got 'existing'
+                    pipe.multi()
+                    await self._unclaim_job(job, pipe)
+                    await pipe.execute()
                     self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('multi-exec error, job %s already started elsewhere', job_id)
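
Both start_jobs branches now reuse the pipeline they already hold instead of letting _unclaim_job open a second one: the "already running elsewhere" branch unwatches first so the re-publish commands are buffered and executed unconditionally, and the ResponseError/WatchError branch starts a fresh MULTI on the same pipeline before re-publishing.
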
@@ -612,24 +618,27 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None:
                     t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
                     self.tasks[job_id] = t

-    async def _unclaim_job(self, job: JobMetaInfo) -> None:
-        async with self.pool.pipeline(transaction=True) as pipe:
-            stream_key = self.queue_name + stream_key_suffix
-            job_message_id_key = job_message_id_prefix + job.job_id
-
-            pipe.xack(stream_key, self.consumer_group_name, job.message_id)
-            pipe.xdel(stream_key, job.message_id)
-            job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms
-            if job_message_id_expire <= 0:
-                job_message_id_expire = self.expires_extra_ms
-
-            await self.pool.publish_to_stream_script(
-                keys=[stream_key, job_message_id_key],
-                args=[job.job_id, str(job.score), str(job_message_id_expire)],
-                client=pipe,
-            )
-
-            await pipe.execute()
+    async def _unclaim_job(self, job: JobMetaInfo, pipe: Pipeline) -> None:
+        stream_key = self.queue_name + stream_key_suffix
+        job_message_id_key = job_message_id_prefix + job.job_id
+
+        pipe.xack(stream_key, self.consumer_group_name, job.message_id)
+        pipe.xdel(stream_key, job.message_id)
+        job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms
+        if job_message_id_expire <= 0:
+            job_message_id_expire = self.expires_extra_ms
+
+        pipe.eval(
+            publish_job_lua,
+            2,
+            # keys
+            stream_key,
+            job_message_id_key,
+            # args
+            job.job_id,
+            str(job.score),
+            str(job_message_id_expire),
+        )

     async def run_job(self, job_id: str, message_id: str, score: int) -> None:  # noqa: C901
         start_ms = timestamp_ms()
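
_unclaim_job is now purely a command-buffering helper: it takes the caller's pipeline, queues XACK, XDEL and the publish_job_lua EVAL, and leaves watching and execution to the caller, so the extra pipeline and round trip the old version created for itself go away. Below is a standalone sketch of that caller-owned-pipeline pattern with redis-py's asyncio client; the key names, stream name and requeue helper are invented for illustration and are not arq's API.

import asyncio

from redis.asyncio import Redis
from redis.asyncio.client import Pipeline


def requeue(pipe: Pipeline, stream_key: str, msg_id: str, job_id: str) -> None:
    # Buffer commands only; the caller decides when (and whether) to execute them.
    pipe.xack(stream_key, 'workers', msg_id)
    pipe.xdel(stream_key, msg_id)
    pipe.xadd(stream_key, {'job_id': job_id})


async def main() -> None:
    redis = Redis()  # assumes a local Redis on the default port
    in_progress_key, stream_key = 'in-progress:job-1', 'jobs-stream'
    await redis.set(in_progress_key, '1', ex=30)  # pretend another worker claimed the job

    async with redis.pipeline(transaction=True) as pipe:
        await pipe.watch(in_progress_key)       # immediate command while watching
        if await pipe.exists(in_progress_key):  # claimed elsewhere, so give the job back
            await pipe.unwatch()                # stop watching; later commands are buffered
            requeue(pipe, stream_key, '1-1', 'job-1')
            await pipe.execute()                # all buffered commands in one round trip

    await redis.aclose()


asyncio.run(main())
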
@@ -877,10 +886,16 @@ async def finish_job(
                 tr.zadd(self.queue_name, {job_id: score + incr_score})
             else:
                 job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms
-                await self.pool.publish_to_stream_script(
-                    keys=[stream_key, job_message_id_key],
-                    args=[job_id, str(score), str(job_message_id_expire)],
-                    client=tr,
+                tr.eval(
+                    publish_job_lua,
+                    2,
+                    # keys
+                    stream_key,
+                    job_message_id_key,
+                    # args
+                    job_id,
+                    str(score),
+                    str(job_message_id_expire),
                 )
             if delete_keys:
                 tr.delete(*delete_keys)
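
finish_job follows the same pattern as enqueue_job: when the job goes back onto the stream, the publish_job_lua EVAL is buffered on the same transaction (tr) as the surrounding bookkeeping commands such as the tr.delete above, instead of going through the Script object that used to live on self.pool.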
