Skip to content

Commit bc62815

Browse files
committed
Add cache lock in SchemaFile.get_or_create
Use a Redis cache lock when getting or creating a new SchemaFile in order to prevent race conditions. Implement tests. Refs. TS-2457
1 parent ea4437f commit bc62815

File tree

2 files changed

+67
-22
lines changed

2 files changed

+67
-22
lines changed

django/thunderstore/schema_server/models/file.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from django.db import models
99
from django.utils import timezone
1010

11+
from thunderstore.cache.utils import get_cache
1112
from thunderstore.core.mixins import S3FileMixin
1213

14+
cache = get_cache("legacy")
15+
CACHE_LOCK_TIMEOUT = 30
16+
1317

1418
def get_schema_file_path(_, filename: str) -> str:
1519
return f"schema/sha256/{filename}"
@@ -39,27 +43,30 @@ def get_or_create(cls, content: bytes) -> "SchemaFile":
3943
hash.update(content)
4044
checksum = hash.hexdigest()
4145

42-
if existing := cls.objects.filter(checksum_sha256=checksum).first():
43-
return existing
46+
lock_key = f"lock.schemafile.{checksum}"
47+
with cache.lock(lock_key, timeout=CACHE_LOCK_TIMEOUT, blocking_timeout=None):
48+
if existing := cls.objects.filter(checksum_sha256=checksum).first():
49+
return existing
50+
51+
gzipped = io.BytesIO()
52+
with gzip.GzipFile(fileobj=gzipped, mode="wb") as f:
53+
f.write(content)
54+
timestamp = timezone.now()
4455

45-
gzipped = io.BytesIO()
46-
with gzip.GzipFile(fileobj=gzipped, mode="wb") as f:
47-
f.write(content)
48-
timestamp = timezone.now()
56+
file = ContentFile(
57+
# TODO: This is immediately passed to BytesIO again, meaning
58+
# we're just wasting memory. Find a way to pass this to
59+
# the Django model without the inefficiency.
60+
gzipped.getvalue(),
61+
name=f"{checksum}.json.gz",
62+
)
4963

50-
file = ContentFile(
51-
# TODO: This is immediately passed to BytesIO again, meaning
52-
# we're just wasting memory. Find a way to pass this to
53-
# the Django model without the inefficiency.
54-
gzipped.getvalue(),
55-
name=f"{checksum}.json.gz",
56-
)
57-
return cls.objects.create(
58-
data=file,
59-
content_type="application/json",
60-
content_encoding="gzip",
61-
last_modified=timestamp,
62-
checksum_sha256=checksum,
63-
file_size=len(content),
64-
gzip_size=file.size,
65-
)
64+
return cls.objects.create(
65+
data=file,
66+
content_type="application/json",
67+
content_encoding="gzip",
68+
last_modified=timestamp,
69+
checksum_sha256=checksum,
70+
file_size=len(content),
71+
gzip_size=file.size,
72+
)

django/thunderstore/schema_server/models/tests/test_file.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import concurrent.futures
12
import time
23
from datetime import timedelta
34
from hashlib import sha256
5+
from unittest.mock import MagicMock, patch
46

57
import pytest
68
from django.utils import timezone
@@ -92,3 +94,39 @@ def test_schema_server_file_get_or_create_deduplication():
9294
file_b = SchemaFile.get_or_create(test_data_b)
9395
assert file_a != file_b
9496
assert SchemaFile.objects.count() == 2
97+
98+
99+
@pytest.mark.django_db
100+
def test_get_or_create_schema_file_cache_lock_acquired_and_released():
101+
test_data = b"Hello world!"
102+
103+
mock_lock = MagicMock()
104+
mock_lock.__enter__.return_value = True
105+
mock_lock.__exit__.return_value = None
106+
107+
with patch("django_redis.cache.RedisCache.lock", return_value=mock_lock):
108+
file = SchemaFile.get_or_create(test_data)
109+
110+
assert file is not None
111+
mock_lock.__enter__.assert_called_once()
112+
mock_lock.__exit__.assert_called_once()
113+
114+
115+
@pytest.mark.django_db(transaction=True)
116+
def test_get_or_create_schema_file_parallel():
117+
test_data = b"Hello world!"
118+
119+
def call_get_or_create():
120+
return SchemaFile.get_or_create(test_data)
121+
122+
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
123+
futures = [executor.submit(call_get_or_create) for _ in range(3)]
124+
results = [f.result() for f in futures]
125+
126+
pks = {file.pk for file in results}
127+
assert len(pks) == 1
128+
129+
assert (
130+
SchemaFile.objects.filter(checksum_sha256=results[0].checksum_sha256).count()
131+
== 1
132+
)

0 commit comments

Comments
 (0)