Skip to content

Commit 2237866

Browse files
committed
Address concurrency issue on get_or_set
Pulled in fix (with some changes) from sebleier#179
1 parent 9f6b06a commit 2237866

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

redis_cache/backends/base.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from functools import wraps
2+
import time
3+
import uuid
24

35
from django.core.cache.backends.base import (
46
BaseCache, DEFAULT_TIMEOUT, InvalidCacheBackendError,
@@ -442,34 +444,42 @@ def get_or_set(
442444
lock_key = "__lock__" + key
443445
fresh_key = "__fresh__" + key
444446

445-
is_fresh = self._get(client, fresh_key)
446447
value = self._get(client, key)
447-
448+
is_fresh = self._get(client, fresh_key)
448449
if is_fresh:
449450
return value
450451

451-
timeout = self.get_timeout(timeout)
452-
lock = self.lock(lock_key, timeout=lock_timeout)
453-
454-
acquired = lock.acquire(blocking=False)
452+
fresh_timeout = self.get_timeout(timeout)
453+
key_timeout = None if stale_cache_timeout is None else fresh_timeout + stale_cache_timeout
455454

456-
if acquired:
457-
try:
458-
value = func()
459-
except Exception:
460-
raise
455+
token = uuid.uuid1().hex
456+
lock = self.lock(lock_key, timeout=lock_timeout)
457+
acquired = lock.acquire(blocking=False, token=token)
458+
459+
while True:
460+
if acquired:
461+
try:
462+
value = func()
463+
except Exception:
464+
raise
465+
else:
466+
pipeline = client.pipeline()
467+
pipeline.set(key, self.prep_value(value), key_timeout)
468+
pipeline.set(fresh_key, 1, fresh_timeout)
469+
pipeline.execute()
470+
return value
471+
finally:
472+
lock.release()
473+
elif value is None:
474+
time.sleep(lock.sleep)
475+
value = self._get(client, key)
476+
if value is None:
477+
# If there is no value present yet, try to acquire the
478+
# lock again (maybe the other thread died for some reason
479+
# and we should try to compute the value instead).
480+
acquired = lock.acquire(blocking=False, token=token)
461481
else:
462-
key_timeout = (
463-
None if stale_cache_timeout is None else timeout + stale_cache_timeout
464-
)
465-
pipeline = client.pipeline()
466-
pipeline.set(key, self.prep_value(value), key_timeout)
467-
pipeline.set(fresh_key, 1, timeout)
468-
pipeline.execute()
469-
finally:
470-
lock.release()
471-
472-
return value
482+
return value
473483

474484
def _reinsert_keys(self, client):
475485
keys = list(client.scan_iter(match='*'))

tests/testapp/tests/base_tests.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,10 @@ def expensive_function():
511511
self.assertEqual(expensive_function.num_calls, 2)
512512
self.assertEqual(value, 42)
513513

514+
def test_get_or_set_with_none_value(self):
515+
value = self.cache.get_or_set('key', lambda: None, 1, None, 1)
516+
assert value is None
517+
514518
def test_get_or_set_serving_from_stale_value(self):
515519

516520
def expensive_function(x):
@@ -537,6 +541,8 @@ def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_ti
537541
thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
538542
thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1))
539543
thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1))
544+
thread_5 = threading.Thread(target=thread_worker, args=(5, None, 1, None, 1))
545+
thread_6 = threading.Thread(target=thread_worker, args=(6, 'g', 1, None, 1))
540546

541547
# First thread should complete and return its value
542548
thread_0.start() # t = 0, valid from t = .5 - 1.5, stale from t = 1.5 - 2.5
@@ -556,19 +562,31 @@ def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_ti
556562
# before the first thread's stale cache has expired.
557563
time.sleep(.25) # t = 2
558564
thread_4.start()
565+
# Sixth thread will start after fourth thread's value is no longer current, explicitly sets
566+
# none as the result
567+
# valid from t = 4 - 5, stale from t = 5 - 6
568+
time.sleep(1.5) # t = 3.5
569+
thread_5.start()
570+
# Seventh thread will start after sixth value has cached - should get None
571+
time.sleep(1) # t = 4.5
572+
thread_6.start()
559573

560574
thread_0.join()
561575
thread_1.join()
562576
thread_2.join()
563577
thread_3.join()
564578
thread_4.join()
579+
thread_5.join()
580+
thread_6.join()
565581

566582
self.assertEqual(results, {
567583
0: 'a',
568-
1: None,
584+
1: 'a',
569585
2: 'a',
570586
3: 'd',
571-
4: 'a'
587+
4: 'a',
588+
5: None,
589+
6: None,
572590
})
573591

574592
def assertMaxConnection(self, cache, max_num):

0 commit comments

Comments
 (0)