-
-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support to celery beat through custom
Scheduler
s (#65)
* Include `.idea` folder to `.gitignore` * Add `TenantAwareSchedulerMixin` and subclasses The subclasses of this mixin (provided they're mixed up with `PersistentScheduler` or `Scheduler`) call the task's `delay` function inside of each tenant's context * Document usage of schedulers in beat integration * Apply suggestions from code review
- Loading branch information
Showing
4 changed files
with
292 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,3 +55,5 @@ coverage.xml | |
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# IDEs | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import logging | ||
from typing import List, Optional | ||
|
||
from celery.beat import PersistentScheduler, ScheduleEntry, Scheduler | ||
from django_tenants.utils import get_tenant_model, tenant_context, get_public_schema_name | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TenantAwareScheduleEntry(ScheduleEntry): | ||
tenant_schemas: Optional[List[str]] = None | ||
|
||
def __init__(self, *args, **kwargs): | ||
if args: | ||
# Unpickled from database | ||
self.tenant_schemas = args[-1] | ||
else: | ||
# Initialized from code | ||
self.tenant_schemas = kwargs.pop("tenant_schemas", None) | ||
|
||
super().__init__(*args, **kwargs) | ||
|
||
def update(self, other): | ||
"""Update values from another entry. | ||
Will only update `tenant_schemas` and "editable" fields: | ||
``task``, ``schedule``, ``args``, ``kwargs``, ``options``. | ||
""" | ||
vars(self).update( | ||
{ | ||
"task": other.task, | ||
"schedule": other.schedule, | ||
"args": other.args, | ||
"kwargs": other.kwargs, | ||
"options": other.options, | ||
"tenant_schemas": other.tenant_schemas, | ||
} | ||
) | ||
|
||
def __reduce__(self): | ||
"""Needed for Pickle serialization""" | ||
return self.__class__, ( | ||
self.name, | ||
self.task, | ||
self.last_run_at, | ||
self.total_run_count, | ||
self.schedule, | ||
self.args, | ||
self.kwargs, | ||
self.options, | ||
self.tenant_schemas, | ||
) | ||
|
||
def editable_fields_equal(self, other): | ||
for attr in ( | ||
"task", | ||
"args", | ||
"kwargs", | ||
"options", | ||
"schedule", | ||
"tenant_schemas", | ||
): | ||
if getattr(self, attr) != getattr(other, attr): | ||
return False | ||
return True | ||
|
||
|
||
class TenantAwareSchedulerMixin: | ||
Entry = TenantAwareScheduleEntry | ||
|
||
def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None): | ||
""" | ||
See https://github.com/celery/celery/blob/c571848023be732a1a11d46198cf831a522cfb54/celery/beat.py#L277 | ||
""" | ||
|
||
tenants = get_tenant_model().objects.all() | ||
|
||
if entry.tenant_schemas is None: | ||
tenants = tenants.exclude(schema_name=get_public_schema_name()) | ||
else: | ||
tenants = tenants.filter(schema_name__in=entry.tenant_schemas) | ||
|
||
logger.info( | ||
"TenantAwareScheduler: Sending due task %s (%s) to %s tenants", | ||
entry.name, | ||
entry.task, | ||
"all" if entry.tenant_schemas is None else str(len(tenants)), | ||
) | ||
|
||
for tenant in tenants: | ||
with tenant_context(tenant): | ||
logger.debug( | ||
"Sending due task %s (%s) to tenant %s", | ||
entry.name, | ||
entry.task, | ||
tenant.name, | ||
) | ||
try: | ||
result = self.apply_async( | ||
entry, producer=producer, advance=False | ||
) | ||
except Exception as exc: | ||
logger.exception(exc) | ||
else: | ||
logger.debug("%s sent. id->%s", entry.task, result.id) | ||
|
||
|
||
class TenantAwareScheduler(TenantAwareSchedulerMixin, Scheduler): | ||
pass | ||
|
||
|
||
class TenantAwarePersistentScheduler( | ||
TenantAwareSchedulerMixin, PersistentScheduler | ||
): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from tempfile import NamedTemporaryFile | ||
from typing import Any, List, Mapping, Optional, Tuple, TypedDict | ||
|
||
from celery import schedules, uuid | ||
from celery.beat import Scheduler | ||
from django.db import connection | ||
from django_tenants.utils import get_tenant_model, schema_context, get_public_schema_name | ||
from pytest import fixture, mark | ||
from tenant_schemas_celery.app import CeleryApp | ||
|
||
from .scheduler import ( | ||
TenantAwarePersistentScheduler, | ||
TenantAwareSchedulerMixin, | ||
) | ||
|
||
|
||
class FakeScheduler(TenantAwareSchedulerMixin, Scheduler): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._sent: List[Tuple[str, TenantAwareSchedulerMixin.Entry]] = [] | ||
|
||
def apply_async(self, entry, producer=None, advance=True, **kwargs): | ||
self._sent.append((connection.schema_name, entry)) | ||
return self.app.AsyncResult(uuid()) | ||
|
||
|
||
class ScheduledEntryConfig(TypedDict, total=False): | ||
task: str | ||
schedule: int | ||
args: Optional[Tuple[Any, ...]] | ||
kwargs: Optional[Mapping[str, Any]] | ||
options: Optional[Mapping[str, Any]] | ||
tenant_schemas: Optional[List[str]] | ||
|
||
|
||
COMMON_PARAMETERS = mark.parametrize( | ||
"config", | ||
( | ||
{ | ||
"test_task": { | ||
"task": "test_task", | ||
"schedule": schedules.crontab(minute="*"), | ||
} | ||
}, | ||
{ | ||
"test_tenant_specific_task": { | ||
"task": "tenant_specific_task", | ||
"schedule": schedules.crontab(minute="*"), | ||
"tenant_schemas": ["tenant1"], | ||
} | ||
}, | ||
{ | ||
"test_tenant_specific_task": { | ||
"task": "tenant_specific_task", | ||
"schedule": schedules.crontab(minute="*"), | ||
"tenant_schemas": ["tenant1", "tenant2"], | ||
}, | ||
"test_generic_task": { | ||
"task": "generic_task", | ||
"schedule": schedules.crontab(minute="*"), | ||
}, | ||
}, | ||
), | ||
) | ||
|
||
|
||
@fixture | ||
def app(config: Mapping[str, ScheduledEntryConfig]) -> CeleryApp: | ||
app = CeleryApp("test_app", set_as_current=False) | ||
app.conf.beat_schedule = config | ||
return app | ||
|
||
|
||
@COMMON_PARAMETERS | ||
class TestTenantAwareSchedulerMixin: | ||
@fixture | ||
def scheduler(self, app: CeleryApp) -> FakeScheduler: | ||
return FakeScheduler(app) | ||
|
||
def test_schedule_setup_properly( | ||
self, | ||
scheduler: FakeScheduler, | ||
config: Mapping[str, ScheduledEntryConfig], | ||
): | ||
for key, config in config.items(): | ||
assert key in scheduler.schedule | ||
entry = scheduler.schedule[key] | ||
|
||
assert entry.task == config["task"] | ||
assert entry.schedule == schedules.crontab(minute="*") | ||
assert entry.tenant_schemas == config.get("tenant_schemas", None) | ||
|
||
@fixture | ||
def tenants(self) -> None: | ||
with schema_context(get_public_schema_name()): | ||
get_tenant_model().objects.create( | ||
name="Tenant1", schema_name="tenant1" | ||
) | ||
get_tenant_model().objects.create( | ||
name="Tenant2", schema_name="tenant2" | ||
) | ||
|
||
@mark.django_db | ||
def test_apply_entry(self, scheduler: FakeScheduler, tenants: None): | ||
for task_name, entry in scheduler.schedule.items(): | ||
scheduler.apply_entry(entry) | ||
|
||
schemas = ( | ||
entry.tenant_schemas | ||
or get_tenant_model().objects.values_list( | ||
"schema_name", flat=True | ||
) | ||
) | ||
|
||
for schema_name in schemas: | ||
assert (schema_name, entry) in scheduler._sent | ||
|
||
scheduler._sent.clear() | ||
|
||
|
||
@COMMON_PARAMETERS | ||
class TestTenantAwarePersistentScheduler: | ||
"""This is mostly to test that the serialization of `TenantAwareSchedulerEntry`s works | ||
This is because `PersistentScheduler`'s `schedule` property does a dynamic lookup on the `shelve` db, | ||
which forces pickling/unpickling of the entries. | ||
""" | ||
|
||
@fixture | ||
def scheduler(self, app: CeleryApp) -> TenantAwarePersistentScheduler: | ||
with NamedTemporaryFile() as file: | ||
yield TenantAwarePersistentScheduler( | ||
app, schedule_filename=str(file.name) | ||
) | ||
|
||
def test_schedule_setup_properly( | ||
self, | ||
scheduler: TenantAwarePersistentScheduler, | ||
config: Mapping[str, ScheduledEntryConfig], | ||
): | ||
for key, config in config.items(): | ||
assert key in scheduler.schedule | ||
entry = scheduler.schedule[key] | ||
|
||
assert entry.task == config["task"] | ||
assert entry.schedule == schedules.crontab(minute="*") | ||
assert entry.tenant_schemas == config.get("tenant_schemas", None) |