Skip to content

Commit 222830d

Browse files
yoggyspre-commit-ci[bot]Lulalaby
authored andcommitted
fix: EntitlementIterator behavior and type-hinting (#2555)
* fix: EntitlementIterator behaviour and type-hinting * style(pre-commit): auto fixes from pre-commit.com hooks * simplify if's * add changelog entry * style(pre-commit): auto fixes from pre-commit.com hooks * revert missclick --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lala Sabathil <[email protected]>
1 parent 38e9d40 commit 222830d

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ These changes are available on the `master` branch, but have not yet been releas
5959
- Fixed `Webhook.send` not including attachment data.
6060
([#2513](https://github.com/Pycord-Development/pycord/pull/2513))
6161

62+
### Fixed
63+
64+
- Fixed `EntitlementIterator` behavior with `limit > 100`.
65+
([#2555](https://github.com/Pycord-Development/pycord/pull/2555))
66+
6267
## [2.6.0] - 2024-07-09
6368

6469
### Added

discord/iterators.py

+56-18
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from .types.audit_log import AuditLog as AuditLogPayload
6565
from .types.guild import Guild as GuildPayload
6666
from .types.message import Message as MessagePayload
67+
from .types.monetization import Entitlement as EntitlementPayload
6768
from .types.threads import Thread as ThreadPayload
6869
from .types.user import PartialUser as PartialUserPayload
6970
from .user import User
@@ -988,11 +989,21 @@ def __init__(
988989
self.guild_id = guild_id
989990
self.exclude_ended = exclude_ended
990991

992+
self._filter = None
993+
994+
if self.before and self.after:
995+
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
996+
self._filter = lambda e: int(e["id"]) > self.after.id
997+
elif self.after:
998+
self._retrieve_entitlements = self._retrieve_entitlements_after_strategy
999+
else:
1000+
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
1001+
9911002
self.state = state
9921003
self.get_entitlements = state.http.list_entitlements
9931004
self.entitlements = asyncio.Queue()
9941005

995-
async def next(self) -> BanEntry:
1006+
async def next(self) -> Entitlement:
9961007
if self.entitlements.empty():
9971008
await self.fill_entitlements()
9981009

@@ -1014,30 +1025,57 @@ async def fill_entitlements(self):
10141025
if not self._get_retrieve():
10151026
return
10161027

1028+
data = await self._retrieve_entitlements(self.retrieve)
1029+
1030+
if self._filter:
1031+
data = list(filter(self._filter, data))
1032+
1033+
if len(data) < 100:
1034+
self.limit = 0 # terminate loop
1035+
1036+
for element in data:
1037+
await self.entitlements.put(Entitlement(data=element, state=self.state))
1038+
1039+
async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
1040+
"""Retrieve entitlements and update next parameters."""
1041+
raise NotImplementedError
1042+
1043+
async def _retrieve_entitlements_before_strategy(
1044+
self, retrieve: int
1045+
) -> list[EntitlementPayload]:
1046+
"""Retrieve entitlements using before parameter."""
10171047
before = self.before.id if self.before else None
1018-
after = self.after.id if self.after else None
10191048
data = await self.get_entitlements(
10201049
self.state.application_id,
10211050
before=before,
1022-
after=after,
1023-
limit=self.retrieve,
1051+
limit=retrieve,
10241052
user_id=self.user_id,
10251053
guild_id=self.guild_id,
10261054
sku_ids=self.sku_ids,
10271055
exclude_ended=self.exclude_ended,
10281056
)
1057+
if data:
1058+
if self.limit is not None:
1059+
self.limit -= retrieve
1060+
self.before = Object(id=int(data[-1]["id"]))
1061+
return data
10291062

1030-
if not data:
1031-
# no data, terminate
1032-
return
1033-
1034-
if self.limit:
1035-
self.limit -= self.retrieve
1036-
1037-
if len(data) < 100:
1038-
self.limit = 0 # terminate loop
1039-
1040-
self.after = Object(id=int(data[-1]["id"]))
1041-
1042-
for element in reversed(data):
1043-
await self.entitlements.put(Entitlement(data=element, state=self.state))
1063+
async def _retrieve_entitlements_after_strategy(
1064+
self, retrieve: int
1065+
) -> list[EntitlementPayload]:
1066+
"""Retrieve entitlements using after parameter."""
1067+
after = self.after.id if self.after else None
1068+
data = await self.get_entitlements(
1069+
self.state.application_id,
1070+
after=after,
1071+
limit=retrieve,
1072+
user_id=self.user_id,
1073+
guild_id=self.guild_id,
1074+
sku_ids=self.sku_ids,
1075+
exclude_ended=self.exclude_ended,
1076+
)
1077+
if data:
1078+
if self.limit is not None:
1079+
self.limit -= retrieve
1080+
self.after = Object(id=int(data[-1]["id"]))
1081+
return data

0 commit comments

Comments
 (0)