64
64
from .types .audit_log import AuditLog as AuditLogPayload
65
65
from .types .guild import Guild as GuildPayload
66
66
from .types .message import Message as MessagePayload
67
+ from .types .monetization import Entitlement as EntitlementPayload
67
68
from .types .threads import Thread as ThreadPayload
68
69
from .types .user import PartialUser as PartialUserPayload
69
70
from .user import User
@@ -988,11 +989,21 @@ def __init__(
988
989
self .guild_id = guild_id
989
990
self .exclude_ended = exclude_ended
990
991
992
+ self ._filter = None
993
+
994
+ if self .before and self .after :
995
+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
996
+ self ._filter = lambda e : int (e ["id" ]) > self .after .id
997
+ elif self .after :
998
+ self ._retrieve_entitlements = self ._retrieve_entitlements_after_strategy
999
+ else :
1000
+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
1001
+
991
1002
self .state = state
992
1003
self .get_entitlements = state .http .list_entitlements
993
1004
self .entitlements = asyncio .Queue ()
994
1005
995
- async def next (self ) -> BanEntry :
1006
+ async def next (self ) -> Entitlement :
996
1007
if self .entitlements .empty ():
997
1008
await self .fill_entitlements ()
998
1009
@@ -1014,30 +1025,57 @@ async def fill_entitlements(self):
1014
1025
if not self ._get_retrieve ():
1015
1026
return
1016
1027
1028
+ data = await self ._retrieve_entitlements (self .retrieve )
1029
+
1030
+ if self ._filter :
1031
+ data = list (filter (self ._filter , data ))
1032
+
1033
+ if len (data ) < 100 :
1034
+ self .limit = 0 # terminate loop
1035
+
1036
+ for element in data :
1037
+ await self .entitlements .put (Entitlement (data = element , state = self .state ))
1038
+
1039
+ async def _retrieve_entitlements (self , retrieve ) -> list [Entitlement ]:
1040
+ """Retrieve entitlements and update next parameters."""
1041
+ raise NotImplementedError
1042
+
1043
+ async def _retrieve_entitlements_before_strategy (
1044
+ self , retrieve : int
1045
+ ) -> list [EntitlementPayload ]:
1046
+ """Retrieve entitlements using before parameter."""
1017
1047
before = self .before .id if self .before else None
1018
- after = self .after .id if self .after else None
1019
1048
data = await self .get_entitlements (
1020
1049
self .state .application_id ,
1021
1050
before = before ,
1022
- after = after ,
1023
- limit = self .retrieve ,
1051
+ limit = retrieve ,
1024
1052
user_id = self .user_id ,
1025
1053
guild_id = self .guild_id ,
1026
1054
sku_ids = self .sku_ids ,
1027
1055
exclude_ended = self .exclude_ended ,
1028
1056
)
1057
+ if data :
1058
+ if self .limit is not None :
1059
+ self .limit -= retrieve
1060
+ self .before = Object (id = int (data [- 1 ]["id" ]))
1061
+ return data
1029
1062
1030
- if not data :
1031
- # no data, terminate
1032
- return
1033
-
1034
- if self .limit :
1035
- self .limit -= self .retrieve
1036
-
1037
- if len (data ) < 100 :
1038
- self .limit = 0 # terminate loop
1039
-
1040
- self .after = Object (id = int (data [- 1 ]["id" ]))
1041
-
1042
- for element in reversed (data ):
1043
- await self .entitlements .put (Entitlement (data = element , state = self .state ))
1063
+ async def _retrieve_entitlements_after_strategy (
1064
+ self , retrieve : int
1065
+ ) -> list [EntitlementPayload ]:
1066
+ """Retrieve entitlements using after parameter."""
1067
+ after = self .after .id if self .after else None
1068
+ data = await self .get_entitlements (
1069
+ self .state .application_id ,
1070
+ after = after ,
1071
+ limit = retrieve ,
1072
+ user_id = self .user_id ,
1073
+ guild_id = self .guild_id ,
1074
+ sku_ids = self .sku_ids ,
1075
+ exclude_ended = self .exclude_ended ,
1076
+ )
1077
+ if data :
1078
+ if self .limit is not None :
1079
+ self .limit -= retrieve
1080
+ self .after = Object (id = int (data [- 1 ]["id" ]))
1081
+ return data
0 commit comments