Skip to content

Commit fbcb1b8

Browse files
authored
Added second-based scheduler intervals. (#450)
1 parent 23422c7 commit fbcb1b8

File tree

8 files changed

+161
-79
lines changed

taskiq/api/scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
from datetime import timedelta
2+
from typing import Optional
3+
14
from taskiq.cli.scheduler.run import run_scheduler_loop
25
from taskiq.scheduler.scheduler import TaskiqScheduler
36

47

58
async def run_scheduler_task(
69
scheduler: TaskiqScheduler,
710
run_startup: bool = False,
11+
interval: Optional[timedelta] = None,
812
) -> None:
913
"""
1014
Run scheduler task.
@@ -20,4 +24,4 @@ async def run_scheduler_task(
2024
if run_startup:
2125
await scheduler.startup()
2226
while True:
23-
await run_scheduler_loop(scheduler)
27+
await run_scheduler_loop(scheduler, interval)

taskiq/cli/scheduler/args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class SchedulerArgs:
1717
fs_discover: bool = False
1818
tasks_pattern: Sequence[str] = ("**/tasks.py",)
1919
skip_first_run: bool = False
20+
update_interval: Optional[int] = None
2021

2122
@classmethod
2223
def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs":
@@ -80,6 +81,15 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs":
8081
"This option skips running tasks immediately after scheduler start."
8182
),
8283
)
84+
parser.add_argument(
85+
"--update-interval",
86+
type=int,
87+
default=None,
88+
help=(
89+
"Interval in seconds to check for new tasks. "
90+
"If not specified, scheduler will run once a minute."
91+
),
92+
)
8393

8494
namespace = parser.parse_args(args)
8595
# If there are any patterns specified, remove default.

taskiq/cli/scheduler/run.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from datetime import datetime, timedelta
55
from logging import basicConfig, getLevelName, getLogger
6-
from typing import Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Set, Tuple
77

88
import pytz
99
from pycron import is_now
@@ -55,7 +55,7 @@ async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]:
5555

5656
async def get_all_schedules(
5757
scheduler: TaskiqScheduler,
58-
) -> Dict[ScheduleSource, List[ScheduledTask]]:
58+
) -> List[Tuple[ScheduleSource, List[ScheduledTask]]]:
5959
"""
6060
Task to update all schedules.
6161
@@ -71,7 +71,7 @@ async def get_all_schedules(
7171
schedules = await asyncio.gather(
7272
*[get_schedules(source) for source in scheduler.sources],
7373
)
74-
return dict(zip(scheduler.sources, schedules))
74+
return list(zip(scheduler.sources, schedules))
7575

7676

7777
def get_task_delay(task: ScheduledTask) -> Optional[int]:
@@ -98,12 +98,10 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:
9898
task_time = to_tz_aware(task.time)
9999
if task_time <= now:
100100
return 0
101-
one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0)
102-
if task_time <= one_min_ahead:
103-
delay = task_time - now
104-
if delay.microseconds:
105-
return int(delay.total_seconds()) + 1
106-
return int(delay.total_seconds())
101+
delay = task_time - now
102+
if delay.microseconds:
103+
return int(delay.total_seconds()) + 1
104+
return int(delay.total_seconds())
107105
return None
108106

109107

@@ -145,21 +143,41 @@ async def delayed_send(
145143
await scheduler.on_ready(source, task)
146144

147145

148-
async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
146+
async def run_scheduler_loop( # noqa: C901
147+
scheduler: TaskiqScheduler,
148+
interval: Optional[timedelta] = None,
149+
) -> None:
149150
"""
150151
Runs scheduler loop.
151152
152153
This function imports taskiq scheduler
153154
and runs tasks when needed.
154155
155156
:param scheduler: current scheduler.
157+
:param interval: interval to check for schedule updates.
156158
"""
157159
loop = asyncio.get_event_loop()
158-
running_schedules = set()
160+
running_schedules: Dict[str, asyncio.Task[Any]] = {}
161+
ran_cron_jobs: Set[str] = set()
162+
current_minute = datetime.now(tz=pytz.UTC).minute
159163
while True:
160-
# We use this method to correctly sleep for one minute.
164+
now = datetime.now(tz=pytz.UTC)
165+
# If minute changed, we need to clear
166+
# ran_cron_jobs set and update current minute.
167+
if now.minute != current_minute:
168+
current_minute = now.minute
169+
ran_cron_jobs.clear()
170+
# If interval is not None, we need to
171+
# calculate next run time using it.
172+
if interval is not None:
173+
next_run = now + interval
174+
# otherwise we need assume that
175+
# we will run it at the start of the next minute.
176+
# as crontab does.
177+
else:
178+
next_run = (now + timedelta(minutes=1)).replace(second=1, microsecond=0)
161179
scheduled_tasks = await get_all_schedules(scheduler)
162-
for source, task_list in scheduled_tasks.items():
180+
for source, task_list in scheduled_tasks:
163181
logger.debug("Got %d schedules from source %s.", len(task_list), source)
164182
for task in task_list:
165183
try:
@@ -172,16 +190,37 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
172190
task.schedule_id,
173191
)
174192
continue
175-
if task_delay is not None:
176-
send_task = loop.create_task(
177-
delayed_send(scheduler, source, task, task_delay),
178-
)
179-
running_schedules.add(send_task)
180-
send_task.add_done_callback(running_schedules.discard)
181-
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
182-
minutes=1,
183-
)
184-
delay = next_minute - datetime.now()
193+
# If task delay is None, we don't need to run it.
194+
if task_delay is None:
195+
continue
196+
# If task is delayed for more than next_run,
197+
# we don't need to run it, because we will
198+
# run it in the next iteration.
199+
if now + timedelta(seconds=task_delay) >= next_run:
200+
continue
201+
# If task is already running, we don't need to run it again.
202+
if task.schedule_id in running_schedules and task_delay < 1:
203+
continue
204+
# If task is cron job, we need to check if
205+
# we already ran it this minute.
206+
if task.cron is not None:
207+
if task.schedule_id in ran_cron_jobs:
208+
continue
209+
ran_cron_jobs.add(task.schedule_id)
210+
send_task = loop.create_task(
211+
delayed_send(scheduler, source, task, task_delay),
212+
# We need to set the name of the task
213+
# to be able to discard its reference
214+
# after it is done.
215+
name=f"schedule_{task.schedule_id}",
216+
)
217+
running_schedules[task.schedule_id] = send_task
218+
send_task.add_done_callback(
219+
lambda task_future: running_schedules.pop(
220+
task_future.get_name().removeprefix("schedule_"),
221+
),
222+
)
223+
delay = next_run - datetime.now(tz=pytz.UTC)
185224
logger.debug(
186225
"Sleeping for %.2f seconds before getting schedules.",
187226
delay.total_seconds(),
@@ -226,6 +265,10 @@ async def run_scheduler(args: SchedulerArgs) -> None:
226265
for source in scheduler.sources:
227266
await source.startup()
228267

268+
interval = None
269+
if args.update_interval:
270+
interval = timedelta(seconds=args.update_interval)
271+
229272
logger.info("Starting scheduler.")
230273
await scheduler.startup()
231274
logger.info("Startup completed.")
@@ -239,7 +282,7 @@ async def run_scheduler(args: SchedulerArgs) -> None:
239282
await asyncio.sleep(delay.total_seconds())
240283
logger.info("First run skipped. The scheduler is now running.")
241284
try:
242-
await run_scheduler_loop(scheduler)
285+
await run_scheduler_loop(scheduler, interval)
243286
except asyncio.CancelledError:
244287
logger.warning("Shutting down scheduler.")
245288
await scheduler.shutdown()
Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import uuid
12
from logging import getLogger
2-
from typing import List
3+
from typing import Dict, List
34

45
from taskiq.abc.broker import AsyncBroker
56
from taskiq.abc.schedule_source import ScheduleSource
@@ -13,20 +14,26 @@ class LabelScheduleSource(ScheduleSource):
1314

1415
def __init__(self, broker: AsyncBroker) -> None:
1516
self.broker = broker
17+
self.schedules: Dict[str, ScheduledTask] = {}
1618

17-
async def get_schedules(self) -> List["ScheduledTask"]:
19+
async def startup(self) -> None:
1820
"""
19-
Collect schedules for all tasks.
20-
21-
this function checks labels for all
22-
tasks available to the broker.
21+
Start up the schedule source.
2322
23+
This function iterates over all tasks
24+
available to the broker and collects
25+
schedules from their labels.
2426
If task has a schedule label,
25-
it will be parsed and returned.
27+
it will be parsed and added to the
28+
scheduler list.
2629
27-
:return: list of schedules.
30+
Every time a schedule is added, a random
31+
schedule id is generated. Please be aware that
32+
they are different for every startup.
33+
34+
:return: None
2835
"""
29-
schedules = []
36+
self.schedules.clear()
3037
for task_name, task in self.broker.get_all_tasks().items():
3138
if task.broker != self.broker:
3239
# if task broker doesn't match self, something is probably wrong
@@ -40,20 +47,36 @@ async def get_schedules(self) -> List["ScheduledTask"]:
4047
continue
4148
labels = schedule.get("labels", {})
4249
labels.update(task.labels)
43-
schedules.append(
44-
ScheduledTask(
45-
task_name=task_name,
46-
labels=labels,
47-
args=schedule.get("args", []),
48-
kwargs=schedule.get("kwargs", {}),
49-
cron=schedule.get("cron"),
50-
time=schedule.get("time"),
51-
cron_offset=schedule.get("cron_offset"),
52-
),
50+
schedule_id = uuid.uuid4().hex
51+
52+
self.schedules[schedule_id] = ScheduledTask(
53+
task_name=task_name,
54+
labels=labels,
55+
schedule_id=schedule_id,
56+
args=schedule.get("args", []),
57+
kwargs=schedule.get("kwargs", {}),
58+
cron=schedule.get("cron"),
59+
time=schedule.get("time"),
60+
cron_offset=schedule.get("cron_offset"),
5361
)
54-
return schedules
5562

56-
def post_send(self, scheduled_task: ScheduledTask) -> None:
63+
return await super().startup()
64+
65+
async def get_schedules(self) -> List["ScheduledTask"]:
66+
"""
67+
Collect schedules for all tasks.
68+
69+
This function returns the schedules that
70+
were collected from task labels on startup.
71+
72+
One-off `time` schedules are dropped from
73+
this list by `post_send`.
74+
75+
:return: list of schedules.
76+
"""
77+
return list(self.schedules.values())
78+
79+
def post_send(self, task: "ScheduledTask") -> None:
5780
"""
5881
Remove `time` schedule from task's scheduler list.
5982
@@ -62,22 +85,7 @@ def post_send(self, scheduled_task: ScheduledTask) -> None:
6285
6386
:param scheduled_task: task that just have sent
6487
"""
65-
if scheduled_task.cron or not scheduled_task.time:
88+
if task.cron or not task.time:
6689
return # it's scheduled task with cron label, do not remove this trigger.
6790

68-
for task_name, task in self.broker.get_all_tasks().items():
69-
if task.broker != self.broker:
70-
# if task broker doesn't match self, something is probably wrong
71-
logger.warning(
72-
f"Broker for {task_name} `{task.broker}` doesn't "
73-
f"match scheduler's broker `{self.broker}`",
74-
)
75-
continue
76-
if scheduled_task.task_name != task_name:
77-
continue
78-
79-
schedule_list = task.labels.get("schedule", []).copy()
80-
for idx, schedule in enumerate(schedule_list):
81-
if schedule.get("time") == scheduled_task.time:
82-
task.labels.get("schedule", []).pop(idx)
83-
return
91+
self.schedules.pop(task.schedule_id, None)

tests/cli/scheduler/test_task_delays.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
def test_should_run_success() -> None:
12-
hour = datetime.datetime.utcnow().hour
12+
hour = datetime.datetime.now(datetime.timezone.utc).hour
1313
delay = get_task_delay(
1414
ScheduledTask(
1515
task_name="",
@@ -97,18 +97,26 @@ def test_time_utc_with_local_zone() -> None:
9797
assert delay is not None and delay >= 0
9898

9999

100+
@freeze_time("2023-01-14 12:00:00")
100101
def test_time_localtime_without_zone() -> None:
101102
time = datetime.datetime.now(tz=pytz.FixedOffset(240)).replace(tzinfo=None)
103+
time_to_run = time - datetime.timedelta(seconds=1)
104+
102105
delay = get_task_delay(
103106
ScheduledTask(
104107
task_name="",
105108
labels={},
106109
args=[],
107110
kwargs={},
108-
time=time - datetime.timedelta(seconds=1),
111+
time=time_to_run,
109112
),
110113
)
111-
assert delay is None
114+
115+
expected_delay = time_to_run.replace(tzinfo=pytz.UTC) - datetime.datetime.now(
116+
pytz.UTC,
117+
)
118+
119+
assert delay == int(expected_delay.total_seconds())
112120

113121

114122
@freeze_time("2023-01-14 12:00:00")

tests/cli/scheduler/test_updater.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ async def test_get_schedules_success() -> None:
5656
schedules = await get_all_schedules(
5757
TaskiqScheduler(InMemoryBroker(), sources),
5858
)
59-
assert schedules == {
60-
sources[0]: schedules1,
61-
sources[1]: schedules2,
62-
}
59+
assert schedules == [
60+
(sources[0], schedules1),
61+
(sources[1], schedules2),
62+
]
6363

6464

6565
@pytest.mark.anyio
@@ -81,7 +81,7 @@ async def test_get_schedules_error() -> None:
8181
schedules = await get_all_schedules(
8282
TaskiqScheduler(InMemoryBroker(), [source1, source2]),
8383
)
84-
assert schedules == {
85-
source1: source1.schedules,
86-
source2: [],
87-
}
84+
assert schedules == [
85+
(source1, source1.schedules),
86+
(source2, []),
87+
]

0 commit comments

Comments
 (0)