Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions djcelery_transactions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# coding=utf-8
from celery.task import task as base_task, Task
from celery import current_app
import djcelery_transactions.transaction_signals
from django.db import transaction
from django.db.transaction import get_connection
from functools import partial
import threading

from django.db.transaction import get_connection
from celery import current_app
from celery.task import task as base_task, Task
from djcelery_transactions.settings import EXEC_ON_SAVEPOINT
import djcelery_transactions.transaction_signals


# Thread-local data (task queue).
_thread_data = threading.local()
Expand Down Expand Up @@ -72,6 +74,9 @@ def _send_tasks(**kwargs):
Called after a transaction is committed or we leave a transaction
management block in which no changes were made (effectively a commit).
"""
sid = kwargs.pop('sid', None)
if sid is not None and not EXEC_ON_SAVEPOINT:
return
queue = _get_task_queue()
while queue:
cls, args, kwargs = queue.pop(0)
Expand Down
10 changes: 10 additions & 0 deletions djcelery_transactions/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
'''
Created on Jan 8, 2015

@author: jakob
'''
from django.conf import settings

# Whether or not to execute tasks on savepoint commits
EXEC_ON_SAVEPOINT = getattr(settings, 'DJCELERY_TRANSACTIONS_EXEC_ON_SAVEPOINT', True)
6 changes: 3 additions & 3 deletions djcelery_transactions/transaction_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TransactionSignals(object):
"""A container for the transaction signals."""

def __init__(self):
self.post_commit = Signal()
self.post_commit = Signal(providing_args=['sid'])
self.post_rollback = Signal()


Expand Down Expand Up @@ -68,7 +68,7 @@ def __patched__exit__(self, exc_type, exc_value, traceback):
if sid is not None:
try:
connection.savepoint_commit(sid)
transaction.signals.post_commit.send(None)
transaction.signals.post_commit.send(None, sid=sid)
except DatabaseError:
try:
connection.savepoint_rollback(sid)
Expand All @@ -83,7 +83,7 @@ def __patched__exit__(self, exc_type, exc_value, traceback):
# Commit transaction
try:
connection.commit()
transaction.signals.post_commit.send(None)
transaction.signals.post_commit.send(None, sid=None)
except DatabaseError:
try:
connection.rollback()
Expand Down