2727
2828import itertools
2929import copy
30+ import types
3031from typing import Any , Union , Optional , Callable , Awaitable , Tuple , TYPE_CHECKING , List , Type , Set , TypeVar
3132from typing_extensions import Literal
3233
3637from . import builtin_converter
3738
3839if TYPE_CHECKING :
40+ import sys
41+
3942 from twitchio import Message , Chatter , PartialChatter , Channel , User , PartialUser
4043 from . import Cog , Bot
4144 from .stringparser import StringParser
45+
46+ if sys .version_info >= (3 , 10 ):
47+ UnionT = Union [types .UnionType , Union ]
48+ else :
49+ UnionT = Union
50+
51+
4252__all__ = ("Command" , "command" , "Group" , "Context" , "cooldown" )
4353
4454
55+ class EmptyArgumentSentinel :
56+ def __repr__ (self ) -> str :
57+ return "<EMPTY>"
58+
59+ def __eq__ (self , __value : object ) -> bool :
60+ return False
61+
62+
63+ EMPTY = EmptyArgumentSentinel ()
64+
65+
4566def _boolconverter (param : str ):
4667 param = param .lower ()
4768 if param in {"yes" , "y" , "1" , "true" , "on" }:
@@ -114,40 +135,127 @@ def full_name(self) -> str:
114135 return self ._name
115136 return f"{ self .parent .full_name } { self ._name } "
116137
117- def _resolve_converter (self , converter : Union [Callable , Awaitable , type ]) -> Union [Callable [..., Any ]]:
138+ def _is_optional_argument (self , converter : Any ):
139+ return (getattr (converter , "__origin__" , None ) is Union or isinstance (converter , types .UnionType )) and type (
140+ None
141+ ) in converter .__args__
142+
143+ def resolve_union_callback (self , name : str , converter : UnionT ) -> Callable [[Context , str ], Any ]:
144+ # print(type(converter), converter.__args__)
145+
146+ args = converter .__args__ # type: ignore # pyright doesnt like this
147+
148+ async def _resolve (context : Context , arg : str ) -> Any :
149+ t = EMPTY
150+ last = None
151+
152+ for original in args :
153+ underlying = self ._resolve_converter (name , original )
154+
155+ try :
156+ t : Any = underlying (context , arg )
157+ if inspect .iscoroutine (t ):
158+ t = await t
159+
160+ break
161+ except Exception as l :
162+ last = l
163+ t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back
164+ continue
165+
166+ if t is EMPTY :
167+ raise UnionArgumentParsingFailed (name , args )
168+
169+ return t
170+
171+ return _resolve
172+
173+ def resolve_optional_callback (self , name : str , converter : Any ) -> Callable [[Context , str ], Any ]:
174+ underlying = self ._resolve_converter (name , converter .__args__ [0 ])
175+
176+ async def _resolve (context : Context , arg : str ) -> Any :
177+ try :
178+ t : Any = underlying (context , arg )
179+ if inspect .iscoroutine (t ):
180+ t = await t
181+
182+ except Exception :
183+ return EMPTY # instruct the parser to roll back and ignore this argument
184+
185+ return t
186+
187+ return _resolve
188+
189+ def _resolve_converter (self , name : str , converter : Union [Callable , Awaitable , type ]) -> Callable [..., Any ]:
118190 if (
119191 isinstance (converter , type )
120192 and converter .__module__ .startswith ("twitchio" )
121193 and converter in builtin_converter ._mapping
122194 ):
123- return builtin_converter ._mapping [converter ]
124- return converter
195+ return self ._convert_builtin_type (name , converter , builtin_converter ._mapping [converter ])
196+
197+ elif converter is bool :
198+ converter = self ._convert_builtin_type (name , bool , _boolconverter )
199+
200+ elif converter in (str , int ):
201+ converter = self ._convert_builtin_type (name , converter , converter ) # type: ignore
202+
203+ elif self ._is_optional_argument (converter ):
204+ return self .resolve_optional_callback (name , converter )
205+
206+ elif isinstance (converter , types .UnionType ) or getattr (converter , "__origin__" , None ) is Union :
207+ return self .resolve_union_callback (name , converter ) # type: ignore
208+
209+ elif hasattr (converter , "__metadata__" ): # Annotated
210+ annotated = converter .__metadata__ # type: ignore
211+ return self ._resolve_converter (name , annotated [0 ])
212+
213+ return converter # type: ignore
214+
215+ def _convert_builtin_type (
216+ self , arg_name : str , original : type , converter : Union [Callable [[str ], Any ], Callable [[str ], Awaitable [Any ]]]
217+ ) -> Callable [[Context , str ], Awaitable [Any ]]:
218+ async def resolve (_ , arg : str ) -> Any :
219+ try :
220+ t = converter (arg )
221+
222+ if inspect .iscoroutine (t ):
223+ t = await t
224+
225+ return t
226+ except Exception as e :
227+ raise ArgumentParsingFailed (
228+ f"Failed to convert `{ arg } ` to expected type { original .__name__ } for argument `{ arg_name } `" ,
229+ original = e ,
230+ argname = arg_name ,
231+ expected = original ,
232+ ) from e
233+
234+ return resolve
125235
126236 async def _convert_types (self , context : Context , param : inspect .Parameter , parsed : str ) -> Any :
127237 converter = param .annotation
238+
128239 if converter is param .empty :
129240 if param .default in (param .empty , None ):
130241 converter = str
131242 else :
132243 converter = type (param .default )
133- true_converter = self ._resolve_converter (converter )
244+
245+ true_converter = self ._resolve_converter (param .name , converter )
134246
135247 try :
136- if true_converter in (int , str ):
137- argument = true_converter (parsed )
138- elif true_converter is bool :
139- argument = _boolconverter (parsed )
140- else :
141- argument = true_converter (context , parsed )
248+ argument = true_converter (context , parsed )
142249 if inspect .iscoroutine (argument ):
143250 argument = await argument
144- except BadArgument :
251+ except BadArgument as e :
252+ if e .name is None :
253+ e .name = param .name
254+
145255 raise
146256 except Exception as e :
147257 raise ArgumentParsingFailed (
148- f"Invalid argument parsed at `{ param .name } ` in command `{ self .name } `."
149- f" Expected type { converter } got { type (parsed )} ." ,
150- e ,
258+ f"Failed to parse `{ parsed } ` for argument { param .name } " , original = e , argname = param .name , expected = None
151259 ) from e
152260 return argument
153261
@@ -170,26 +278,40 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di
170278 try :
171279 argument = parsed .pop (index )
172280 except (KeyError , IndexError ):
281+ if self ._is_optional_argument (param .annotation ): # parameter is optional and at the end.
282+ args .append (param .default if param .default is not param .empty else None )
283+ continue
284+
173285 if param .default is param .empty :
174- raise MissingRequiredArgument (param )
286+ raise MissingRequiredArgument (argname = param .name )
287+
175288 args .append (param .default )
176289 else :
177- argument = await self ._convert_types (context , param , argument )
178- args .append (argument )
290+ _parsed_arg = await self ._convert_types (context , param , argument )
291+
292+ if _parsed_arg is EMPTY :
293+ parsed [index ] = argument
294+ index -= 1
295+ args .append (param .default if param .default is not param .empty else None )
296+
297+ continue
298+ else :
299+ args .append (_parsed_arg )
300+
179301 elif param .kind == param .KEYWORD_ONLY :
180302 rest = " " .join (parsed .values ())
181303 if rest .startswith (" " ):
182304 rest = rest .lstrip (" " )
183305 if rest :
184306 rest = await self ._convert_types (context , param , rest )
185307 elif param .default is param .empty :
186- raise MissingRequiredArgument (param )
308+ raise MissingRequiredArgument (argname = param . name )
187309 else :
188310 rest = param .default
189311 kwargs [param .name ] = rest
190312 parsed .clear ()
191313 break
192- elif param .VAR_POSITIONAL :
314+ elif param .kind == param . VAR_POSITIONAL :
193315 args .extend ([await self ._convert_types (context , param , argument ) for argument in parsed .values ()])
194316 parsed .clear ()
195317 break
0 commit comments