2424import argparse
2525import inspect
2626import textwrap
27- from typing import Any , Callable , Dict , Iterable , List , Optional , Type , TypeVar
27+ from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Type , TypeVar
2828
2929import drgn
3030
@@ -244,6 +244,29 @@ def _call(self,
244244 """
245245 raise NotImplementedError ()
246246
247+ def _valid_input_types (self ) -> Set [str ]:
248+ """
249+ Returns a set of strings which are the canonicalized names of valid input types
250+ for this command
251+ """
252+ assert self .input_type is not None
253+ return {type_canonicalize_name (self .input_type )}
254+
255+ def __input_type_check (
256+ self , objs : Iterable [drgn .Object ]) -> Iterable [drgn .Object ]:
257+ valid_input_types = self ._valid_input_types ()
258+ prev_type = None
259+ for obj in objs :
260+ cur_type = type_canonical_name (obj .type_ )
261+ if cur_type not in valid_input_types or (prev_type and
262+ cur_type != prev_type ):
263+ raise CommandError (
264+ self .name ,
265+ f'expected input of type { self .input_type } , but received '
266+ f'type { obj .type_ } ' )
267+ prev_type = cur_type
268+ yield obj
269+
247270 def __invalid_memory_objects_check (self , objs : Iterable [drgn .Object ],
248271 fatal : bool ) -> Iterable [drgn .Object ]:
249272 """
@@ -281,15 +304,19 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
281304 # the command is running.
282305 #
283306 try :
284- result = self ._call (objs )
307+ if self .input_type and objs :
308+ result = self ._call (self .__input_type_check (objs ))
309+ else :
310+ result = self ._call (objs )
311+
285312 if result is not None :
286313 #
287314 # The whole point of the SingleInputCommands are that
288315 # they don't stop executing in the first encounter of
289316 # a bad dereference. That's why we check here whether
290317 # the command that we are running is a subclass of
291318 # SingleInputCommand and we set the `fatal` flag
292- # accordinly .
319+ # accordingly .
293320 #
294321 yield from self .__invalid_memory_objects_check (
295322 result , not issubclass (self .__class__ , SingleInputCommand ))
@@ -634,22 +661,6 @@ def pretty_print(self, objs: Iterable[drgn.Object]) -> None:
634661 # pylint: disable=missing-docstring
635662 raise NotImplementedError
636663
637- def check_input_type (self ,
638- objs : Iterable [drgn .Object ]) -> Iterable [drgn .Object ]:
639- """
640- This function acts as a generator, checking that each passed object
641- matches the input type for the command
642- """
643- assert self .input_type is not None
644- type_name = type_canonicalize_name (self .input_type )
645- for obj in objs :
646- if type_canonical_name (obj .type_ ) != type_name :
647- raise CommandError (
648- self .name ,
649- f'expected input of type { self .input_type } , but received '
650- f'type { obj .type_ } ' )
651- yield obj
652-
653664 def _call ( # type: ignore[return]
654665 self ,
655666 objs : Iterable [drgn .Object ]) -> Optional [Iterable [drgn .Object ]]:
@@ -658,7 +669,7 @@ def _call( # type: ignore[return]
658669 verifying the types as we go.
659670 """
660671 assert self .input_type is not None
661- self .pretty_print (self . check_input_type ( objs ) )
672+ self .pretty_print (objs )
662673
663674
664675class Locator (Command ):
@@ -673,6 +684,26 @@ class Locator(Command):
673684
674685 output_type : Optional [str ] = None
675686
687+ def _valid_input_types (self ) -> Set [str ]:
688+ """
689+ Some Locators support multiple input types. Check for InputHandler
690+ implementations to expand the set of valid input types.
691+ """
692+ assert self .input_type is not None
693+ valid_types = [type_canonicalize_name (self .input_type )]
694+
695+ for (_ , method ) in inspect .getmembers (self , inspect .ismethod ):
696+ if hasattr (method , "input_typename_handled" ):
697+ valid_types .append (
698+ type_canonicalize_name (method .input_typename_handled ))
699+
700+ valid_types += [
701+ type_canonicalize_name (type_ )
702+ for type_ , class_ in Walker .allWalkers .items ()
703+ ]
704+
705+ return set (valid_types )
706+
676707 def no_input (self ) -> Iterable [drgn .Object ]:
677708 # pylint: disable=missing-docstring
678709 raise CommandError (self .name , 'command requires an input' )
0 commit comments