Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion app/eventyay/api/views/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,8 +1247,14 @@ def confirm(self, request, **kwargs):
return self.retrieve(request, [], **kwargs)

@action(detail=True, methods=['POST'])
@transaction.atomic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting transaction.atomic here doesn't work, because this refund function already handle an PaymentException:

try:
    r.payment_provider.execute_refund(r)
except PaymentException as e:

Read the Avoid catching exceptions inside atomic! note.

def refund(self, request, **kwargs):
payment = self.get_object()
# Acquire row-level lock on payment to prevent concurrent refund race conditions.
# We reuse DRF's filtering & permission logic while issuing a single SELECT ... FOR UPDATE.
queryset = self.filter_queryset(self.get_queryset().select_for_update())
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
payment = get_object_or_404(queryset, **{self.lookup_field: lookup_value})
amount = serializers.DecimalField(max_digits=10, decimal_places=2).to_internal_value(
request.data.get('amount', str(payment.amount))
)
Expand All @@ -1265,6 +1271,7 @@ def refund(self, request, **kwargs):

full_refund_possible = payment.payment_provider.payment_refund_supported(payment)
partial_refund_possible = payment.payment_provider.payment_partial_refund_supported(payment)
# Calculate available amount under lock - prevents TOCTOU race condition
available_amount = payment.amount - payment.refunded_amount

if amount <= 0:
Expand Down
65 changes: 65 additions & 0 deletions src/tests/api/test_orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3989,6 +3989,71 @@ def test_refund_create_invalid_payment(token_client, organizer, event, order):
)
assert resp.status_code == 400

@pytest.mark.django_db(transaction=True)
def test_refund_concurrent_race_condition_prevented(token_client, organizer, event, order):

import threading
from decimal import Decimal
from django.db import connection
from django.test.utils import CaptureQueriesContext
from django.db import connection
from django.test import Client
import json

with scopes_disabled():
payment = order.payments.first()
payment.amount = Decimal('100.00')
payment.state = OrderPayment.PAYMENT_STATE_CONFIRMED
payment.save()

results = []
errors = []

def make_refund_request(amount):
try:
client = Client()
client.defaults['HTTP_AUTHORIZATION'] = token_client.defaults.get('HTTP_AUTHORIZATION')
response = client.post(
f'/api/v1/organizers/{organizer.slug}/events/{event.slug}/orders/{order.code}/payments/{payment.local_id}/refund/',
data=json.dumps({'amount': str(amount)}),
content_type='application/json',
)
results.append({'status': response.status_code, 'amount': amount})
except Exception as e:
errors.append(str(e))
finally:
connection.close()

thread1 = threading.Thread(target=make_refund_request, args=(Decimal('80.00'),))
thread2 = threading.Thread(target=make_refund_request, args=(Decimal('80.00'),))

thread1.start()
thread2.start()

thread1.join(timeout=5)
thread2.join(timeout=5)

assert len(errors) == 0, f"Unexpected errors: {errors}"
assert len(results) == 2, f"Expected 2 results, got {len(results)}"

status_codes = sorted([r['status'] for r in results])

assert status_codes == [200, 400], \
f"Expected one success (200) and one failure (400), got {status_codes}"

with scopes_disabled():
payment.refresh_from_db()
refunds = list(payment.refunds.all())
total_refunded = sum(r.amount for r in refunds)

assert len(refunds) == 1, f"Expected exactly 1 refund object, found {len(refunds)}"
assert total_refunded == Decimal('80.00'), \
f"Expected total refunded = $80.00, got {total_refunded}"

available = payment.amount - total_refunded
assert available == Decimal('20.00'), \
f"Expected available = $20.00, got {available}"


@pytest.mark.django_db
def test_order_delete(token_client, organizer, event, order):
Expand Down