Skip to content

Commit

Permalink
Add support to celery beat through custom Schedulers (#65)
Browse files Browse the repository at this point in the history
* Include `.idea` folder to `.gitignore`

* Add `TenantAwareSchedulerMixin` and subclasses

The subclasses of this mixin (provided they're mixed in with `PersistentScheduler` or `Scheduler`) call the task's `delay` function inside each tenant's context

* Document usage of schedulers in beat integration

* Apply suggestions from code review
  • Loading branch information
edg956 authored Mar 30, 2023
1 parent 3e52d8c commit a2bbd18
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ coverage.xml
# Sphinx documentation
docs/_build/

# IDEs
.idea
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ def some_task():
Celery beat integration
-----------------------

This package does not provide support for scheduling periodic tasks inside given schema. Instead, you can use `{django_tenants,django_tenants_schemas}.utils.{get_tenant_model,tenant_context}` methods to launch given tasks within specific tenant.
In order to run celery beat tasks in a multi-tenant environment, you've got two options:
- Use a dispatching task that will send a task for each tenant
- Use a custom scheduler

Let's say that you would like to run a `reset_remaining_jobs` tasks periodically, for every tenant that you have. Instead of scheduling the task for each schema separately, you can schedule one dispatcher task that will iterate over all schemas and send specific task for each schema you want, instead:
For example, let's say that you would like to run a `reset_remaining_jobs` task periodically, for every tenant that you have.

### Dispatcher task pattern
You can schedule one dispatcher task that will iterate over all schemas and send that task within the schema's context:

```python
from django_tenants.utils import get_tenant_model, tenant_context
Expand All @@ -124,6 +129,27 @@ The `reset_remaining_jobs_in_all_schemas` task (called the dispatch task) should

That way you have full control over which schemas the task should be scheduled in.


### Custom scheduler
If you are using the standard `Scheduler` or `PersistentScheduler` classes provided by `celery`, you can transition to using this package's `TenantAwareScheduler` or `TenantAwarePersistentScheduler` classes. You should then specify the scheduler you want to use in your invocation of `beat`, e.g.:

```bash
celery -A proj beat --scheduler=tenant_schemas_celery.scheduler.TenantAwareScheduler
```

#### Caveats
`TenantAwareSchedulerMixin` uses a subclass of `SchedulerEntry` that allows the user to provide specific schemas to run a task on. This might prove useful if you have a task you only want to run in the `public` schema or on a subset of your tenants. In order to set that, you must configure `tenant_schemas` in the task's definition as such:

```python
app.conf.beat_schedule = {
"my-task": {
"task": "myapp.tasks.my_task",
"schedule": schedules.crontab(minute="*/5"),
"tenant_schemas": ["public"]
}
}
```

Compatibility changes
=====================

Expand Down
115 changes: 115 additions & 0 deletions tenant_schemas_celery/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import logging
from typing import List, Optional

from celery.beat import PersistentScheduler, ScheduleEntry, Scheduler
from django_tenants.utils import get_tenant_model, tenant_context, get_public_schema_name

logger = logging.getLogger(__name__)


class TenantAwareScheduleEntry(ScheduleEntry):
    """Schedule entry carrying an optional list of tenant schema names.

    ``tenant_schemas`` is ``None`` or a list of schema names.  The entry only
    stores, compares and pickles the value; interpreting it is left to the
    scheduler that applies the entry.
    """

    # Order of the positional tuple emitted by ``__reduce__``.  The leading
    # fields are forwarded to ``ScheduleEntry.__init__``; the trailing
    # ``tenant_schemas`` is specific to this subclass.
    _PICKLED_FIELDS = (
        "name",
        "task",
        "last_run_at",
        "total_run_count",
        "schedule",
        "args",
        "kwargs",
        "options",
        "tenant_schemas",
    )

    # Class-level default; shadowed per instance in ``__init__``.
    tenant_schemas: Optional[List[str]] = None

    def __init__(self, *args, **kwargs):
        """Build an entry from keyword config or from a pickled tuple.

        Two call shapes are supported:

        * keyword construction (beat schedule config), where
          ``tenant_schemas`` is an optional keyword;
        * the positional tuple produced by ``__reduce__``, whose last item is
          ``tenant_schemas``.

        In both cases ``tenant_schemas`` is consumed here and never forwarded:
        the parent initializer does not know about it, and passing it along
        positionally would bind it to whichever parent parameter follows
        ``options`` (``relative`` in current Celery -- see the Celery
        ``ScheduleEntry`` reference).
        """
        # Always pop the keyword so it cannot leak into the parent call, even
        # when positional arguments are present.
        tenant_schemas = kwargs.pop("tenant_schemas", None)

        if len(args) == len(self._PICKLED_FIELDS):
            # Unpickled from database: split off our trailing extra field.
            *args, tenant_schemas = args

        self.tenant_schemas = tenant_schemas
        super().__init__(*args, **kwargs)

    def update(self, other):
        """Update values from another entry.
        Will only update `tenant_schemas` and "editable" fields:
        ``task``, ``schedule``, ``args``, ``kwargs``, ``options``.
        """
        vars(self).update(
            {
                "task": other.task,
                "schedule": other.schedule,
                "args": other.args,
                "kwargs": other.kwargs,
                "options": other.options,
                "tenant_schemas": other.tenant_schemas,
            }
        )

    def __reduce__(self):
        """Needed for Pickle serialization.

        Returns the class plus a positional tuple laid out as
        ``_PICKLED_FIELDS``, which ``__init__`` recognises by its length.
        """
        return self.__class__, tuple(
            getattr(self, field) for field in self._PICKLED_FIELDS
        )

    def editable_fields_equal(self, other):
        """Return True when every editable field, including
        ``tenant_schemas``, compares equal between ``self`` and ``other``."""
        for attr in (
            "task",
            "args",
            "kwargs",
            "options",
            "schedule",
            "tenant_schemas",
        ):
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True


class TenantAwareSchedulerMixin:
    """Scheduler mixin that sends a due entry once per selected tenant.

    Meant to be combined with a scheduler class that provides
    ``apply_async`` (e.g. Celery's ``Scheduler`` or ``PersistentScheduler``).
    """

    # Entry class used for schedule entries, so ``tenant_schemas`` is kept.
    Entry = TenantAwareScheduleEntry

    def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None):
        """Send ``entry`` inside the context of each selected tenant.

        Tenant selection:

        * ``entry.tenant_schemas is None`` -- every tenant except the public
          schema;
        * otherwise -- only tenants whose ``schema_name`` is listed.

        A failure while sending for one tenant is logged and does not stop
        the remaining tenants from being processed.

        See https://github.com/celery/celery/blob/c571848023be732a1a11d46198cf831a522cfb54/celery/beat.py#L277
        """

        tenants = get_tenant_model().objects.all()

        if entry.tenant_schemas is None:
            tenants = tenants.exclude(schema_name=get_public_schema_name())
        else:
            tenants = tenants.filter(schema_name__in=entry.tenant_schemas)

        logger.info(
            "TenantAwareScheduler: Sending due task %s (%s) to %s tenants",
            entry.name,
            entry.task,
            "all" if entry.tenant_schemas is None else str(len(tenants)),
        )

        for tenant in tenants:
            with tenant_context(tenant):
                # Identify the tenant by ``schema_name``: it is the field the
                # queryset above already relies on, whereas a ``name``
                # attribute is not guaranteed on every tenant model.
                logger.debug(
                    "Sending due task %s (%s) to tenant %s",
                    entry.name,
                    entry.task,
                    tenant.schema_name,
                )
                try:
                    result = self.apply_async(
                        entry, producer=producer, advance=False
                    )
                except Exception as exc:
                    logger.exception(exc)
                else:
                    # ``result`` may lack ``id`` (e.g. be ``None``); this
                    # branch is outside the ``try``, so guard the lookup
                    # rather than abort the remaining tenants.
                    logger.debug(
                        "%s sent. id->%s",
                        entry.task,
                        getattr(result, "id", None),
                    )


class TenantAwareScheduler(TenantAwareSchedulerMixin, Scheduler):
    """``Scheduler`` combined with ``TenantAwareSchedulerMixin``."""


class TenantAwarePersistentScheduler(TenantAwareSchedulerMixin, PersistentScheduler):
    """``PersistentScheduler`` combined with ``TenantAwareSchedulerMixin``."""
147 changes: 147 additions & 0 deletions tenant_schemas_celery/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from tempfile import NamedTemporaryFile
from typing import Any, List, Mapping, Optional, Tuple, TypedDict

from celery import schedules, uuid
from celery.beat import Scheduler
from django.db import connection
from django_tenants.utils import get_tenant_model, schema_context, get_public_schema_name
from pytest import fixture, mark
from tenant_schemas_celery.app import CeleryApp

from .scheduler import (
TenantAwarePersistentScheduler,
TenantAwareSchedulerMixin,
)


class FakeScheduler(TenantAwareSchedulerMixin, Scheduler):
    """Scheduler test double that records sends instead of publishing them."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One ``(active schema name, entry)`` pair per ``apply_async`` call,
        # in call order.
        self._sent: List[Tuple[str, TenantAwareSchedulerMixin.Entry]] = []

    def apply_async(self, entry, producer=None, advance=True, **kwargs):
        """Record the schema active on the connection together with
        ``entry`` and return an ``AsyncResult`` with a fresh id."""
        record = (connection.schema_name, entry)
        self._sent.append(record)
        fake_task_id = uuid()
        return self.app.AsyncResult(fake_task_id)


class ScheduledEntryConfig(TypedDict, total=False):
    """Shape of one ``beat_schedule`` entry as used by these tests.

    Every key is optional (``total=False``).
    """

    task: str
    # Not restricted to ``int``: the parametrized configs in this module pass
    # ``schedules.crontab(...)`` objects here.
    schedule: Any
    args: Optional[Tuple[Any, ...]]
    kwargs: Optional[Mapping[str, Any]]
    options: Optional[Mapping[str, Any]]
    # Schema names the entry is limited to.
    tenant_schemas: Optional[List[str]]


# Parametrization shared by the scheduler test classes: each value is a
# complete ``beat_schedule`` mapping handed to the ``config`` argument.
COMMON_PARAMETERS = mark.parametrize(
    "config",
    (
        # A single entry without ``tenant_schemas``.
        {
            "test_task": {
                "task": "test_task",
                "schedule": schedules.crontab(minute="*"),
            }
        },
        # A single entry limited to one schema.
        {
            "test_tenant_specific_task": {
                "task": "tenant_specific_task",
                "schedule": schedules.crontab(minute="*"),
                "tenant_schemas": ["tenant1"],
            }
        },
        # A schema-limited entry and an unrestricted entry side by side.
        {
            "test_tenant_specific_task": {
                "task": "tenant_specific_task",
                "schedule": schedules.crontab(minute="*"),
                "tenant_schemas": ["tenant1", "tenant2"],
            },
            "test_generic_task": {
                "task": "generic_task",
                "schedule": schedules.crontab(minute="*"),
            },
        },
    ),
)


@fixture
def app(config: Mapping[str, ScheduledEntryConfig]) -> CeleryApp:
    """Return a ``CeleryApp`` named ``test_app`` (not set as the current app)
    whose ``beat_schedule`` is the parametrized ``config``."""
    celery_app = CeleryApp("test_app", set_as_current=False)
    celery_app.conf.beat_schedule = config
    return celery_app


@COMMON_PARAMETERS
class TestTenantAwareSchedulerMixin:
    """Exercise ``TenantAwareSchedulerMixin`` through ``FakeScheduler``."""

    @fixture
    def scheduler(self, app: CeleryApp) -> FakeScheduler:
        """Scheduler under test, built from the parametrized app."""
        return FakeScheduler(app)

    def test_schedule_setup_properly(
        self,
        scheduler: FakeScheduler,
        config: Mapping[str, ScheduledEntryConfig],
    ):
        """Every configured entry is present and keeps its task, schedule and
        ``tenant_schemas``."""
        # Distinct loop name so the ``config`` fixture argument is not shadowed.
        for key, entry_config in config.items():
            assert key in scheduler.schedule
            entry = scheduler.schedule[key]

            assert entry.task == entry_config["task"]
            assert entry.schedule == schedules.crontab(minute="*")
            assert entry.tenant_schemas == entry_config.get("tenant_schemas", None)

    @fixture
    def tenants(self) -> None:
        """Create two tenants, ``tenant1`` and ``tenant2``, from the public
        schema."""
        with schema_context(get_public_schema_name()):
            get_tenant_model().objects.create(
                name="Tenant1", schema_name="tenant1"
            )
            get_tenant_model().objects.create(
                name="Tenant2", schema_name="tenant2"
            )

    @mark.django_db
    def test_apply_entry(self, scheduler: FakeScheduler, tenants: None):
        """Each entry is sent once per selected tenant and to no other schema."""
        for entry in scheduler.schedule.values():
            scheduler.apply_entry(entry)

            # Mirror the selection rule of ``apply_entry``: an unrestricted
            # entry targets every tenant except the public schema.
            if entry.tenant_schemas is None:
                expected = set(
                    get_tenant_model()
                    .objects.exclude(schema_name=get_public_schema_name())
                    .values_list("schema_name", flat=True)
                )
            else:
                expected = set(entry.tenant_schemas)

            for schema_name in expected:
                assert (schema_name, entry) in scheduler._sent

            # Nothing may be sent outside the selected schemas.
            assert {schema for schema, _ in scheduler._sent} == expected

            scheduler._sent.clear()


@COMMON_PARAMETERS
class TestTenantAwarePersistentScheduler:
    """This is mostly to test that the serialization of `TenantAwareScheduleEntry`s works.
    This is because `PersistentScheduler`'s `schedule` property does a dynamic lookup on the `shelve` db,
    which forces pickling/unpickling of the entries.
    """

    @fixture
    def scheduler(self, app: CeleryApp) -> TenantAwarePersistentScheduler:
        """Yield a persistent scheduler backed by a temporary schedule file."""
        with NamedTemporaryFile() as file:
            persistent = TenantAwarePersistentScheduler(
                app, schedule_filename=str(file.name)
            )
            try:
                yield persistent
            finally:
                # Close the scheduler before the temporary file is removed so
                # its store is not left open (see Celery's
                # ``PersistentScheduler.close``).
                persistent.close()

    def test_schedule_setup_properly(
        self,
        scheduler: TenantAwarePersistentScheduler,
        config: Mapping[str, ScheduledEntryConfig],
    ):
        """Every configured entry is present and keeps its task, schedule and
        ``tenant_schemas`` when read back from the scheduler."""
        # Distinct loop name so the ``config`` fixture argument is not shadowed.
        for key, entry_config in config.items():
            assert key in scheduler.schedule
            entry = scheduler.schedule[key]

            assert entry.task == entry_config["task"]
            assert entry.schedule == schedules.crontab(minute="*")
            assert entry.tenant_schemas == entry_config.get("tenant_schemas", None)

0 comments on commit a2bbd18

Please sign in to comment.