diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index bd47d87..66e2cac 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -1,12 +1,14 @@ # coding=utf-8 -from celery.task import task as base_task, Task -from celery import current_app -import djcelery_transactions.transaction_signals from django.db import transaction +from django.db.transaction import get_connection from functools import partial import threading -from django.db.transaction import get_connection +from celery import current_app +from celery.task import task as base_task, Task +from djcelery_transactions.settings import EXEC_ON_SAVEPOINT +import djcelery_transactions.transaction_signals + # Thread-local data (task queue). _thread_data = threading.local() @@ -72,6 +74,9 @@ def _send_tasks(**kwargs): Called after a transaction is committed or we leave a transaction management block in which no changes were made (effectively a commit). """ + sid = kwargs.pop('sid', None) + if sid is not None and not EXEC_ON_SAVEPOINT: + return queue = _get_task_queue() while queue: cls, args, kwargs = queue.pop(0) diff --git a/djcelery_transactions/settings.py b/djcelery_transactions/settings.py new file mode 100644 index 0000000..5543fc0 --- /dev/null +++ b/djcelery_transactions/settings.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +''' +Created on Jan 8, 2015 + +@author: jakob +''' +from django.conf import settings + +# Whether or not to execute tasks on savepoint commits +EXEC_ON_SAVEPOINT = getattr(settings, 'DJCELERY_TRANSACTIONS_EXEC_ON_SAVEPOINT', True) \ No newline at end of file diff --git a/djcelery_transactions/transaction_signals.py b/djcelery_transactions/transaction_signals.py index d1ae2de..3722b63 100644 --- a/djcelery_transactions/transaction_signals.py +++ b/djcelery_transactions/transaction_signals.py @@ -39,7 +39,7 @@ class TransactionSignals(object): """A container for the transaction signals.""" def __init__(self): - self.post_commit = Signal() + self.post_commit = Signal(providing_args=['sid']) self.post_rollback = Signal() @@ -68,7 +68,7 @@ def __patched__exit__(self, exc_type, exc_value, traceback): if sid is not None: try: connection.savepoint_commit(sid) - transaction.signals.post_commit.send(None) + transaction.signals.post_commit.send(None, sid=sid) except DatabaseError: try: connection.savepoint_rollback(sid) @@ -83,7 +83,7 @@ def __patched__exit__(self, exc_type, exc_value, traceback): # Commit transaction try: connection.commit() - transaction.signals.post_commit.send(None) + transaction.signals.post_commit.send(None, sid=None) except DatabaseError: try: connection.rollback()