Skip to content

Commit 4ae7f4a

Browse files
authored
[commands] command typehint updates (#425)
* Add ability to parse Unions into ext.commands * pipe was 3.10, not 3.8 * Fix optional parsing when optional is the last argument * Add Annotated support * Run black * revamp errors with proper useful messages and details on objects * update changelog with changes * run black
1 parent 9628522 commit 4ae7f4a

File tree

3 files changed

+186
-28
lines changed

3 files changed

+186
-28
lines changed

docs/changelog.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ Master
2424
- Fix websocket reconnection event.
2525
- Fix another websocket reconnect issue where it tried to decode nonexistent headers.
2626

27+
- ext.commands
28+
- Additions
29+
- Added support for the following typing constructs in command signatures:
30+
- ``Union[A, B]`` / ``A | B``
31+
- ``Optional[T]`` / ``T | None``
32+
- ``Annotated[T, converter]`` (accessible through the ``typing_extensions`` module on older python versions)
33+
2734

2835
2.7.0
2936
======

twitchio/ext/commands/core.py

Lines changed: 141 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import itertools
2929
import copy
30+
import types
3031
from typing import Any, Union, Optional, Callable, Awaitable, Tuple, TYPE_CHECKING, List, Type, Set, TypeVar
3132
from typing_extensions import Literal
3233

@@ -36,12 +37,32 @@
3637
from . import builtin_converter
3738

3839
if TYPE_CHECKING:
40+
import sys
41+
3942
from twitchio import Message, Chatter, PartialChatter, Channel, User, PartialUser
4043
from . import Cog, Bot
4144
from .stringparser import StringParser
45+
46+
if sys.version_info >= (3, 10):
47+
UnionT = Union[types.UnionType, Union]
48+
else:
49+
UnionT = Union
50+
51+
4252
__all__ = ("Command", "command", "Group", "Context", "cooldown")
4353

4454

55+
class EmptyArgumentSentinel:
56+
def __repr__(self) -> str:
57+
return "<EMPTY>"
58+
59+
def __eq__(self, __value: object) -> bool:
60+
return False
61+
62+
63+
EMPTY = EmptyArgumentSentinel()
64+
65+
4566
def _boolconverter(param: str):
4667
param = param.lower()
4768
if param in {"yes", "y", "1", "true", "on"}:
@@ -114,40 +135,127 @@ def full_name(self) -> str:
114135
return self._name
115136
return f"{self.parent.full_name} {self._name}"
116137

117-
def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Union[Callable[..., Any]]:
138+
def _is_optional_argument(self, converter: Any):
139+
return (getattr(converter, "__origin__", None) is Union or isinstance(converter, types.UnionType)) and type(
140+
None
141+
) in converter.__args__
142+
143+
def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Context, str], Any]:
144+
# print(type(converter), converter.__args__)
145+
146+
args = converter.__args__ # type: ignore # pyright doesnt like this
147+
148+
async def _resolve(context: Context, arg: str) -> Any:
149+
t = EMPTY
150+
last = None
151+
152+
for original in args:
153+
underlying = self._resolve_converter(name, original)
154+
155+
try:
156+
t: Any = underlying(context, arg)
157+
if inspect.iscoroutine(t):
158+
t = await t
159+
160+
break
161+
except Exception as l:
162+
last = l
163+
t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back
164+
continue
165+
166+
if t is EMPTY:
167+
raise UnionArgumentParsingFailed(name, args)
168+
169+
return t
170+
171+
return _resolve
172+
173+
def resolve_optional_callback(self, name: str, converter: Any) -> Callable[[Context, str], Any]:
174+
underlying = self._resolve_converter(name, converter.__args__[0])
175+
176+
async def _resolve(context: Context, arg: str) -> Any:
177+
try:
178+
t: Any = underlying(context, arg)
179+
if inspect.iscoroutine(t):
180+
t = await t
181+
182+
except Exception:
183+
return EMPTY # instruct the parser to roll back and ignore this argument
184+
185+
return t
186+
187+
return _resolve
188+
189+
def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]:
118190
if (
119191
isinstance(converter, type)
120192
and converter.__module__.startswith("twitchio")
121193
and converter in builtin_converter._mapping
122194
):
123-
return builtin_converter._mapping[converter]
124-
return converter
195+
return self._convert_builtin_type(name, converter, builtin_converter._mapping[converter])
196+
197+
elif converter is bool:
198+
converter = self._convert_builtin_type(name, bool, _boolconverter)
199+
200+
elif converter in (str, int):
201+
converter = self._convert_builtin_type(name, converter, converter) # type: ignore
202+
203+
elif self._is_optional_argument(converter):
204+
return self.resolve_optional_callback(name, converter)
205+
206+
elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union:
207+
return self.resolve_union_callback(name, converter) # type: ignore
208+
209+
elif hasattr(converter, "__metadata__"): # Annotated
210+
annotated = converter.__metadata__ # type: ignore
211+
return self._resolve_converter(name, annotated[0])
212+
213+
return converter # type: ignore
214+
215+
def _convert_builtin_type(
216+
self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]]
217+
) -> Callable[[Context, str], Awaitable[Any]]:
218+
async def resolve(_, arg: str) -> Any:
219+
try:
220+
t = converter(arg)
221+
222+
if inspect.iscoroutine(t):
223+
t = await t
224+
225+
return t
226+
except Exception as e:
227+
raise ArgumentParsingFailed(
228+
f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`",
229+
original=e,
230+
argname=arg_name,
231+
expected=original,
232+
) from e
233+
234+
return resolve
125235

126236
async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any:
127237
converter = param.annotation
238+
128239
if converter is param.empty:
129240
if param.default in (param.empty, None):
130241
converter = str
131242
else:
132243
converter = type(param.default)
133-
true_converter = self._resolve_converter(converter)
244+
245+
true_converter = self._resolve_converter(param.name, converter)
134246

135247
try:
136-
if true_converter in (int, str):
137-
argument = true_converter(parsed)
138-
elif true_converter is bool:
139-
argument = _boolconverter(parsed)
140-
else:
141-
argument = true_converter(context, parsed)
248+
argument = true_converter(context, parsed)
142249
if inspect.iscoroutine(argument):
143250
argument = await argument
144-
except BadArgument:
251+
except BadArgument as e:
252+
if e.name is None:
253+
e.name = param.name
254+
145255
raise
146256
except Exception as e:
147257
raise ArgumentParsingFailed(
148-
f"Invalid argument parsed at `{param.name}` in command `{self.name}`."
149-
f" Expected type {converter} got {type(parsed)}.",
150-
e,
258+
f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None
151259
) from e
152260
return argument
153261

@@ -170,26 +278,40 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di
170278
try:
171279
argument = parsed.pop(index)
172280
except (KeyError, IndexError):
281+
if self._is_optional_argument(param.annotation): # parameter is optional and at the end.
282+
args.append(param.default if param.default is not param.empty else None)
283+
continue
284+
173285
if param.default is param.empty:
174-
raise MissingRequiredArgument(param)
286+
raise MissingRequiredArgument(argname=param.name)
287+
175288
args.append(param.default)
176289
else:
177-
argument = await self._convert_types(context, param, argument)
178-
args.append(argument)
290+
_parsed_arg = await self._convert_types(context, param, argument)
291+
292+
if _parsed_arg is EMPTY:
293+
parsed[index] = argument
294+
index -= 1
295+
args.append(param.default if param.default is not param.empty else None)
296+
297+
continue
298+
else:
299+
args.append(_parsed_arg)
300+
179301
elif param.kind == param.KEYWORD_ONLY:
180302
rest = " ".join(parsed.values())
181303
if rest.startswith(" "):
182304
rest = rest.lstrip(" ")
183305
if rest:
184306
rest = await self._convert_types(context, param, rest)
185307
elif param.default is param.empty:
186-
raise MissingRequiredArgument(param)
308+
raise MissingRequiredArgument(argname=param.name)
187309
else:
188310
rest = param.default
189311
kwargs[param.name] = rest
190312
parsed.clear()
191313
break
192-
elif param.VAR_POSITIONAL:
314+
elif param.kind == param.VAR_POSITIONAL:
193315
args.extend([await self._convert_types(context, param, argument) for argument in parsed.values()])
194316
parsed.clear()
195317
break

twitchio/ext/commands/errors.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
2222
DEALINGS IN THE SOFTWARE.
2323
"""
24+
from __future__ import annotations
25+
26+
from typing import Optional, TYPE_CHECKING
27+
28+
if TYPE_CHECKING:
29+
from .core import Command
2430

2531

2632
class TwitchCommandError(Exception):
@@ -38,29 +44,52 @@ class InvalidCog(TwitchCommandError):
3844

3945

4046
class MissingRequiredArgument(TwitchCommandError):
41-
pass
47+
def __init__(self, *args, argname: Optional[str] = None) -> None:
48+
self.name: str = argname or "unknown"
49+
50+
if args:
51+
super().__init__(*args)
52+
else:
53+
super().__init__(f"Missing required argument `{self.name}`")
4254

4355

4456
class BadArgument(TwitchCommandError):
45-
def __init__(self, message: str):
57+
def __init__(self, message: str, argname: Optional[str] = None):
58+
self.name: str = argname # type: ignore # this'll get fixed in the parser handler
4659
self.message = message
4760
super().__init__(message)
4861

4962

5063
class ArgumentParsingFailed(BadArgument):
51-
def __init__(self, message: str, original: Exception):
52-
self.original = original
53-
super().__init__(message)
64+
def __init__(
65+
self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None
66+
):
67+
self.original: Exception = original
68+
self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none.
69+
self.expected_type: Optional[type] = expected
70+
71+
Exception.__init__(self, message) # bypass badArgument
72+
73+
74+
class UnionArgumentParsingFailed(ArgumentParsingFailed):
75+
def __init__(self, argname: str, expected: tuple[type, ...]):
76+
self.name: str = argname
77+
self.expected_type: tuple[type, ...] = expected
78+
79+
self.message = f"Failed to convert argument `{self.name}` to any of the valid options"
80+
Exception.__init__(self, self.message)
5481

5582

5683
class CommandNotFound(TwitchCommandError):
57-
pass
84+
def __init__(self, message: str, name: str) -> None:
85+
self.name: str = name
86+
super().__init__(message)
5887

5988

6089
class CommandOnCooldown(TwitchCommandError):
61-
def __init__(self, command, retry_after):
62-
self.command = command
63-
self.retry_after = retry_after
90+
def __init__(self, command: Command, retry_after: float):
91+
self.command: Command = command
92+
self.retry_after: float = retry_after
6493
super().__init__(f"Command <{command.name}> is on cooldown. Try again in ({retry_after:.2f})s")
6594

6695

0 commit comments

Comments
 (0)