diff --git a/taskiq/api/scheduler.py b/taskiq/api/scheduler.py index 97c59c78..6928b128 100644 --- a/taskiq/api/scheduler.py +++ b/taskiq/api/scheduler.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from typing import Optional + from taskiq.cli.scheduler.run import run_scheduler_loop from taskiq.scheduler.scheduler import TaskiqScheduler @@ -5,6 +8,7 @@ async def run_scheduler_task( scheduler: TaskiqScheduler, run_startup: bool = False, + interval: Optional[timedelta] = None, ) -> None: """ Run scheduler task. @@ -20,4 +24,4 @@ async def run_scheduler_task( if run_startup: await scheduler.startup() while True: - await run_scheduler_loop(scheduler) + await run_scheduler_loop(scheduler, interval) diff --git a/taskiq/cli/scheduler/args.py b/taskiq/cli/scheduler/args.py index 1850f360..d1f6d821 100644 --- a/taskiq/cli/scheduler/args.py +++ b/taskiq/cli/scheduler/args.py @@ -17,6 +17,7 @@ class SchedulerArgs: fs_discover: bool = False tasks_pattern: Sequence[str] = ("**/tasks.py",) skip_first_run: bool = False + update_interval: Optional[int] = None @classmethod def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": @@ -80,6 +81,15 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": "This option skips running tasks immediately after scheduler start." ), ) + parser.add_argument( + "--update-interval", + type=int, + default=None, + help=( + "Interval in seconds to check for new tasks. " + "If not specified, scheduler will run once a minute." + ), + ) namespace = parser.parse_args(args) # If there are any patterns specified, remove default. diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index b1487599..09467dc0 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -3,7 +3,7 @@ import sys from datetime import datetime, timedelta from logging import basicConfig, getLevelName, getLogger -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Set, Tuple import pytz from pycron import is_now @@ -55,7 +55,7 @@ async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]: async def get_all_schedules( scheduler: TaskiqScheduler, -) -> Dict[ScheduleSource, List[ScheduledTask]]: +) -> List[Tuple[ScheduleSource, List[ScheduledTask]]]: """ Task to update all schedules. @@ -71,7 +71,7 @@ async def get_all_schedules( schedules = await asyncio.gather( *[get_schedules(source) for source in scheduler.sources], ) - return dict(zip(scheduler.sources, schedules)) + return list(zip(scheduler.sources, schedules)) def get_task_delay(task: ScheduledTask) -> Optional[int]: @@ -98,12 +98,10 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]: task_time = to_tz_aware(task.time) if task_time <= now: return 0 - one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0) - if task_time <= one_min_ahead: - delay = task_time - now - if delay.microseconds: - return int(delay.total_seconds()) + 1 - return int(delay.total_seconds()) + delay = task_time - now + if delay.microseconds: + return int(delay.total_seconds()) + 1 + return int(delay.total_seconds()) return None @@ -145,7 +143,10 @@ async def delayed_send( await scheduler.on_ready(source, task) -async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: +async def run_scheduler_loop( # noqa: C901 + scheduler: TaskiqScheduler, + interval: Optional[timedelta] = None, +) -> None: """ Runs scheduler loop. @@ -153,13 +154,30 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: and runs tasks when needed. :param scheduler: current scheduler. + :param interval: interval to check for schedule updates. """ loop = asyncio.get_event_loop() - running_schedules = set() + running_schedules: Dict[str, asyncio.Task[Any]] = {} + ran_cron_jobs: Set[str] = set() + current_minute = datetime.now(tz=pytz.UTC).minute while True: - # We use this method to correctly sleep for one minute. + now = datetime.now(tz=pytz.UTC) + # If minute changed, we need to clear + # ran_cron_jobs set and update current minute. + if now.minute != current_minute: + current_minute = now.minute + ran_cron_jobs.clear() + # If interval is not None, we need to + # calculate next run time using it. + if interval is not None: + next_run = now + interval + # otherwise we need assume that + # we will run it at the start of the next minute. + # as crontab does. + else: + next_run = (now + timedelta(minutes=1)).replace(second=1, microsecond=0) scheduled_tasks = await get_all_schedules(scheduler) - for source, task_list in scheduled_tasks.items(): + for source, task_list in scheduled_tasks: logger.debug("Got %d schedules from source %s.", len(task_list), source) for task in task_list: try: @@ -172,16 +190,37 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: task.schedule_id, ) continue - if task_delay is not None: - send_task = loop.create_task( - delayed_send(scheduler, source, task, task_delay), - ) - running_schedules.add(send_task) - send_task.add_done_callback(running_schedules.discard) - next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta( - minutes=1, - ) - delay = next_minute - datetime.now() + # If task delay is None, we don't need to run it. + if task_delay is None: + continue + # If task is delayed for more than next_run, + # we don't need to run it, because we will + # run it in the next iteration. + if now + timedelta(seconds=task_delay) >= next_run: + continue + # If task is already running, we don't need to run it again. + if task.schedule_id in running_schedules and task_delay < 1: + continue + # If task is cron job, we need to check if + # we already ran it this minute. + if task.cron is not None: + if task.schedule_id in ran_cron_jobs: + continue + ran_cron_jobs.add(task.schedule_id) + send_task = loop.create_task( + delayed_send(scheduler, source, task, task_delay), + # We need to set the name of the task + # to be able to discard its reference + # after it is done. + name=f"schedule_{task.schedule_id}", + ) + running_schedules[task.schedule_id] = send_task + send_task.add_done_callback( + lambda task_future: running_schedules.pop( + task_future.get_name().removeprefix("schedule_"), + ), + ) + delay = next_run - datetime.now(tz=pytz.UTC) logger.debug( "Sleeping for %.2f seconds before getting schedules.", delay.total_seconds(), @@ -226,6 +265,10 @@ async def run_scheduler(args: SchedulerArgs) -> None: for source in scheduler.sources: await source.startup() + interval = None + if args.update_interval: + interval = timedelta(seconds=args.update_interval) + logger.info("Starting scheduler.") await scheduler.startup() logger.info("Startup completed.") @@ -239,7 +282,7 @@ async def run_scheduler(args: SchedulerArgs) -> None: await asyncio.sleep(delay.total_seconds()) logger.info("First run skipped. The scheduler is now running.") try: - await run_scheduler_loop(scheduler) + await run_scheduler_loop(scheduler, interval) except asyncio.CancelledError: logger.warning("Shutting down scheduler.") await scheduler.shutdown() diff --git a/taskiq/schedule_sources/label_based.py b/taskiq/schedule_sources/label_based.py index 1f313717..94fd42de 100644 --- a/taskiq/schedule_sources/label_based.py +++ b/taskiq/schedule_sources/label_based.py @@ -1,5 +1,6 @@ +import uuid from logging import getLogger -from typing import List +from typing import Dict, List from taskiq.abc.broker import AsyncBroker from taskiq.abc.schedule_source import ScheduleSource @@ -13,20 +14,26 @@ class LabelScheduleSource(ScheduleSource): def __init__(self, broker: AsyncBroker) -> None: self.broker = broker + self.schedules: Dict[str, ScheduledTask] = {} - async def get_schedules(self) -> List["ScheduledTask"]: + async def startup(self) -> None: """ - Collect schedules for all tasks. - - this function checks labels for all - tasks available to the broker. + Startup the schedule source. + This function iterates over all tasks + available to the broker and collects + schedules from their labels. If task has a schedule label, - it will be parsed and returned. + it will be parsed and added to the + scheduler list. - :return: list of schedules. + Every time schedule is added, the random + schedule id is generated. Please be aware that + they are different for every startup. + + :return: None """ - schedules = [] + self.schedules.clear() for task_name, task in self.broker.get_all_tasks().items(): if task.broker != self.broker: # if task broker doesn't match self, something is probably wrong @@ -40,20 +47,36 @@ async def get_schedules(self) -> List["ScheduledTask"]: continue labels = schedule.get("labels", {}) labels.update(task.labels) - schedules.append( - ScheduledTask( - task_name=task_name, - labels=labels, - args=schedule.get("args", []), - kwargs=schedule.get("kwargs", {}), - cron=schedule.get("cron"), - time=schedule.get("time"), - cron_offset=schedule.get("cron_offset"), - ), + schedule_id = uuid.uuid4().hex + + self.schedules[schedule_id] = ScheduledTask( + task_name=task_name, + labels=labels, + schedule_id=schedule_id, + args=schedule.get("args", []), + kwargs=schedule.get("kwargs", {}), + cron=schedule.get("cron"), + time=schedule.get("time"), + cron_offset=schedule.get("cron_offset"), ) - return schedules - def post_send(self, scheduled_task: ScheduledTask) -> None: + return await super().startup() + + async def get_schedules(self) -> List["ScheduledTask"]: + """ + Collect schedules for all tasks. + + this function checks labels for all + tasks available to the broker. + + If task has a schedule label, + it will be parsed and returned. + + :return: list of schedules. + """ + return list(self.schedules.values()) + + def post_send(self, task: "ScheduledTask") -> None: """ Remove `time` schedule from task's scheduler list. @@ -62,22 +85,7 @@ def post_send(self, scheduled_task: ScheduledTask) -> None: :param scheduled_task: task that just have sent """ - if scheduled_task.cron or not scheduled_task.time: + if task.cron or not task.time: return # it's scheduled task with cron label, do not remove this trigger. - for task_name, task in self.broker.get_all_tasks().items(): - if task.broker != self.broker: - # if task broker doesn't match self, something is probably wrong - logger.warning( - f"Broker for {task_name} `{task.broker}` doesn't " - f"match scheduler's broker `{self.broker}`", - ) - continue - if scheduled_task.task_name != task_name: - continue - - schedule_list = task.labels.get("schedule", []).copy() - for idx, schedule in enumerate(schedule_list): - if schedule.get("time") == scheduled_task.time: - task.labels.get("schedule", []).pop(idx) - return + self.schedules.pop(task.schedule_id, None) diff --git a/tests/cli/scheduler/test_task_delays.py b/tests/cli/scheduler/test_task_delays.py index 1af3c783..2e00fe37 100644 --- a/tests/cli/scheduler/test_task_delays.py +++ b/tests/cli/scheduler/test_task_delays.py @@ -9,7 +9,7 @@ def test_should_run_success() -> None: - hour = datetime.datetime.utcnow().hour + hour = datetime.datetime.now(datetime.timezone.utc).hour delay = get_task_delay( ScheduledTask( task_name="", @@ -97,18 +97,26 @@ def test_time_utc_with_local_zone() -> None: assert delay is not None and delay >= 0 +@freeze_time("2023-01-14 12:00:00") def test_time_localtime_without_zone() -> None: time = datetime.datetime.now(tz=pytz.FixedOffset(240)).replace(tzinfo=None) + time_to_run = time - datetime.timedelta(seconds=1) + delay = get_task_delay( ScheduledTask( task_name="", labels={}, args=[], kwargs={}, - time=time - datetime.timedelta(seconds=1), + time=time_to_run, ), ) - assert delay is None + + expected_delay = time_to_run.replace(tzinfo=pytz.UTC) - datetime.datetime.now( + pytz.UTC, + ) + + assert delay == int(expected_delay.total_seconds()) @freeze_time("2023-01-14 12:00:00") diff --git a/tests/cli/scheduler/test_updater.py b/tests/cli/scheduler/test_updater.py index c2a7b9e5..2ac9ef8d 100644 --- a/tests/cli/scheduler/test_updater.py +++ b/tests/cli/scheduler/test_updater.py @@ -56,10 +56,10 @@ async def test_get_schedules_success() -> None: schedules = await get_all_schedules( TaskiqScheduler(InMemoryBroker(), sources), ) - assert schedules == { - sources[0]: schedules1, - sources[1]: schedules2, - } + assert schedules == [ + (sources[0], schedules1), + (sources[1], schedules2), + ] @pytest.mark.anyio @@ -81,7 +81,7 @@ async def test_get_schedules_error() -> None: schedules = await get_all_schedules( TaskiqScheduler(InMemoryBroker(), [source1, source2]), ) - assert schedules == { - source1: source1.schedules, - source2: [], - } + assert schedules == [ + (source1, source1.schedules), + (source2, []), + ] diff --git a/tests/schedule_sources/test_label_based.py b/tests/schedule_sources/test_label_based.py index 9e683917..fa621b07 100644 --- a/tests/schedule_sources/test_label_based.py +++ b/tests/schedule_sources/test_label_based.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import pytest +import pytz from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.schedule_sources.label_based import LabelScheduleSource @@ -13,7 +14,7 @@ "schedule_label", [ pytest.param([{"cron": "* * * * *"}], id="cron"), - pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: @@ -27,6 +28,7 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [ ScheduledTask( @@ -53,5 +55,6 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [] diff --git a/tests/scheduler/test_label_based_sched.py b/tests/scheduler/test_label_based_sched.py index 156e8498..506990b0 100644 --- a/tests/scheduler/test_label_based_sched.py +++ b/tests/scheduler/test_label_based_sched.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List import pytest +import pytz from freezegun import freeze_time from taskiq.brokers.inmemory_broker import InMemoryBroker @@ -18,7 +19,7 @@ "schedule_label", [ pytest.param([{"cron": "* * * * *"}], id="cron"), - pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: @@ -31,7 +32,9 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: def task() -> None: pass - schedules = await LabelScheduleSource(broker).get_schedules() + source = LabelScheduleSource(broker) + await source.startup() + schedules = await source.get_schedules() assert schedules == [ ScheduledTask( schedule_id=schedules[0].schedule_id, @@ -57,6 +60,7 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [] @@ -69,6 +73,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None: broker=broker, sources=[LabelScheduleSource(broker)], ) + for source in scheduler.sources: + await source.startup() # NOTE: # freeze time to 00:00, so task won't be scheduled by `cron`, only by `time` @@ -77,8 +83,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None: @broker.task( task_name="test_task", schedule=[ - {"time": datetime.utcnow(), "args": [1]}, - {"time": datetime.utcnow() + timedelta(days=1), "args": [2]}, + {"time": datetime.now(pytz.UTC), "args": [1]}, + {"time": datetime.now(pytz.UTC) + timedelta(days=1), "args": [2]}, {"cron": "1 * * * *", "args": [3]}, ], )