Skip to content

Commit 4fe27bb

Browse files
committed
feat: add the ability to use cooldowns inline via await cooldown.increment()
1 parent 2fb7b39 commit 4fe27bb

File tree

8 files changed

+62
-7
lines changed

8 files changed

+62
-7
lines changed

cooldowns/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,6 @@
5151
"get_all_cooldowns",
5252
)
5353

54-
__version__ = "2.1.0"
54+
__version__ = "2.2.0"
5555
VersionInfo = namedtuple("VersionInfo", "major minor micro releaselevel serial")
56-
version_info = VersionInfo(major=2, minor=1, micro=0, releaselevel="final", serial=0)
56+
version_info = VersionInfo(major=2, minor=2, micro=0, releaselevel="final", serial=0)

cooldowns/cooldown.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,16 @@ def __init__(
245245
if cooldown_id:
246246
utils.shared_cooldown_refs[cooldown_id] = self
247247

248+
async def increment(self, *args, **kwargs) -> "Cooldown":
249+
"""Inline equivalent to using async with"""
250+
if not self._clean_task:
251+
self._clean_task = asyncio.create_task(self._keep_buckets_clear())
252+
253+
last_bucket = await self.get_bucket(*args, **kwargs)
254+
bucket: TP = self._get_cooldown_for_bucket(last_bucket)
255+
await bucket.increment()
256+
return self
257+
248258
async def __aenter__(self) -> "Cooldown":
249259
if not self._clean_task:
250260
self._clean_task = asyncio.create_task(self._keep_buckets_clear())
@@ -442,7 +452,9 @@ def __repr__(self) -> str:
442452
return f"Cooldown(limit={self.limit}, time_period={self.time_period}, func={self._func})"
443453

444454
@property
445-
def bucket(self) -> Union[CooldownBucketProtocol, AsyncCooldownBucketProtocol, CallableT]:
455+
def bucket(
456+
self,
457+
) -> Union[CooldownBucketProtocol, AsyncCooldownBucketProtocol, CallableT]:
446458
"""Returns the underlying bucket to process cooldowns against."""
447459
return self._bucket
448460

cooldowns/cooldown_times_per.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def __init__(
4646
def __repr__(self):
4747
return f"<CooldownTimesPer(limit={self.limit}, current={self.current}, time_period={self.time_period})>"
4848

49-
async def __aenter__(self) -> "CooldownTimesPer":
49+
async def increment(self) -> "CooldownTimesPer":
50+
"""Inline equivalent to using async with"""
5051
if self.current == 0:
5152
raise CallableOnCooldown(
5253
self._cooldown.func, self._cooldown, self.next_reset
@@ -61,6 +62,9 @@ async def __aenter__(self) -> "CooldownTimesPer":
6162

6263
return self
6364

65+
async def __aenter__(self) -> "CooldownTimesPer":
66+
return await self.increment()
67+
6468
async def __aexit__(self, *_) -> None:
6569
...
6670

cooldowns/protocols/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .bucket import CooldownBucketProtocol, AsyncCooldownBucketProtocol, CallableT
22

3-
__all__ = ("CooldownBucketProtocol", "AsyncCooldownBucketProtocol","CallableT")
3+
__all__ = ("CooldownBucketProtocol", "AsyncCooldownBucketProtocol", "CallableT")

cooldowns/protocols/bucket.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
CallableT = Callable[..., Any] | Coroutine[Any, Any, Any]
44

5+
56
class CooldownBucketProtocol(Protocol):
67
"""CooldownBucketProtocol implementation Protocol."""
78

cooldowns/static_times_per.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def get_next_reset(self, now: datetime.datetime) -> datetime.datetime:
6363
possible_options = [self.next_datetime(now, t) for t in self._reset_times]
6464
return min(possible_options)
6565

66-
async def __aenter__(self) -> StaticTimesPer:
66+
async def increment(self) -> "StaticTimesPer":
67+
"""Inline equivalent to using async with"""
6768
if self.current == 0:
6869
raise CallableOnCooldown(
6970
self._cooldown.func, self._cooldown, self.next_reset
@@ -78,3 +79,6 @@ async def __aenter__(self) -> StaticTimesPer:
7879
self.loop.call_later((reset - now).total_seconds(), self._reset_invoke)
7980

8081
return self
82+
83+
async def __aenter__(self) -> StaticTimesPer:
84+
return await self.increment()

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "function-cooldowns"
3-
version = "2.1.0"
3+
version = "2.2.0"
44
description = "A simplistic decorator based approach to rate limiting function calls."
55
authors = [{name = "Skelmis", email="[email protected]"}]
66
license = "UNLICENSE"
@@ -31,6 +31,7 @@ attrs = "^25.1.0"
3131
pytest = "^8.3.4"
3232
freezegun = "^1.5.1"
3333
pytest-asyncio = "^0.25.3"
34+
black = "^25.1.0"
3435

3536
[build-system]
3637
requires = ["poetry-core"]

tests/test_cooldown.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ async def test_func(*args, **kwargs):
136136

137137
await test_func(2)
138138

139+
139140
@pytest.mark.asyncio
140141
async def test_custom_callable_as_bucket():
141142
def first_arg(*args):
@@ -152,6 +153,7 @@ async def test_func(*args, **kwargs):
152153

153154
await test_func(2)
154155

156+
155157
@pytest.mark.asyncio
156158
async def test_async_custom_callable_as_bucket():
157159
async def first_arg(*args):
@@ -168,6 +170,7 @@ async def test_func(*args, **kwargs):
168170

169171
await test_func(2)
170172

173+
171174
@pytest.mark.asyncio
172175
async def test_async_bucket_process():
173176
class CustomBucket(Enum):
@@ -392,3 +395,33 @@ async def test():
392395
assert _cooldown.get_cooldown_times_per(await _cooldown.get_bucket()) is None
393396
await test()
394397
assert _cooldown.get_cooldown_times_per(await _cooldown.get_bucket()) is not None
398+
399+
400+
@pytest.mark.asyncio
401+
async def test_inline_cooldowns():
402+
# Can be called once every second
403+
# Default bucket is ALL arguments
404+
@cooldown(1, 1, bucket=CooldownBucket.all)
405+
async def test_func(*args, **kwargs) -> (tuple, dict):
406+
return args, kwargs
407+
408+
_cooldown: Cooldown = getattr(test_func, "_cooldowns")[0]
409+
# Call it once, so its on cooldown after this
410+
data = await test_func(1, two=2)
411+
assert data == ((1,), {"two": 2})
412+
413+
with pytest.raises(CallableOnCooldown):
414+
# Since this uses the same arguments
415+
# as the previous call, it comes under
416+
# the same bucket, and thus gets rate-limited
417+
await _cooldown.increment(1, two=2)
418+
419+
# Shouldn't error as it comes under the
420+
# bucket _HashableArguments(1) rather then
421+
# the bucket _HashableArguments(1, two=2)
422+
# which are completely different
423+
await test_func(1)
424+
425+
await _cooldown.increment(1, two=3)
426+
with pytest.raises(CallableOnCooldown):
427+
await _cooldown.increment(1, two=3)

0 commit comments

Comments
 (0)