2828}
2929
3030
31- def last_child_of_type (node : SgNode , type_ : str ) -> SgNode | None :
31+ def _get_last_child_of_type (node : SgNode , type_ : str ) -> SgNode | None :
3232 return last_child if (children := node .children ()) and (last_child := children [- 1 ]).kind () == type_ else None
3333
3434
35- def find_identifiers_in_children (node : SgNode ) -> Iterable [SgNode ]:
35+ def _find_identifiers_in_children (node : SgNode ) -> Iterable [SgNode ]:
3636 for child in node .children ():
3737 if child .kind () == "identifier" :
3838 yield child
3939
4040
41- def is_inside_inner_function (root : SgNode , node : SgNode ) -> bool :
41+ def _is_inside_inner_function (root : SgNode , node : SgNode ) -> bool :
4242 for ancestor in node .ancestors ():
4343 if ancestor .kind () == "function_definition" :
4444 return ancestor != root
4545 return False
4646
4747
48- def is_inside_inner_function_or_class (root : SgNode , node : SgNode ) -> bool :
48+ def _is_inside_inner_function_or_class (root : SgNode , node : SgNode ) -> bool :
4949 for ancestor in node .ancestors ():
5050 if ancestor .kind () in {"function_definition" , "class_definition" }:
5151 return ancestor != root
5252 return False
5353
5454
55- def find_identifiers_in_function_parameter (node : SgNode ) -> Iterable [SgNode ]:
56- match node .kind ():
57- case "default_parameter" | "typed_default_parameter" :
58- if name := node .field ("name" ):
59- yield name
60- case "identifier" :
61- yield node
62- case _:
63- yield from find_identifiers_in_children (node )
64-
65-
66- def find_identifiers_in_import (node : SgNode ) -> Iterable [SgNode ]:
55+ def _find_identifiers_in_import_statement (node : SgNode ) -> Iterable [SgNode ]:
6756 match tuple ((child .kind (), child ) for child in node .children ()):
6857 case (("from" , _), _, ("import" , _), * name_nodes ) | (("import" , _), * name_nodes ):
6958 for kind , name_node in name_nodes :
7059 match kind :
7160 case "dotted_name" :
72- if identifier := last_child_of_type (name_node , "identifier" ):
61+ if identifier := _get_last_child_of_type (name_node , "identifier" ):
7362 yield identifier
7463 case "aliased_import" :
7564 if alias := name_node .field ("alias" ):
7665 yield alias
7766
7867
79- def find_identifiers_in_function_body (node : SgNode ) -> Iterable [SgNode ]: # noqa: C901, PLR0912
68+ def _find_identifiers_made_by_node (node : SgNode ) -> Iterable [SgNode ]: # noqa: C901, PLR0912
8069 match node .kind ():
8170 case "assignment" | "augmented_assignment" :
8271 if not (left := node .field ("left" )):
8372 return
8473 match left .kind ():
8574 case "pattern_list" | "tuple_pattern" :
86- yield from find_identifiers_in_children (left )
75+ yield from _find_identifiers_in_children (left )
8776 case "identifier" :
8877 yield left
8978 case "named_expression" :
@@ -94,18 +83,18 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa
9483 yield name
9584 for function in node .find_all (kind = "function_definition" ):
9685 for nonlocal_statement in node .find_all (kind = "nonlocal_statement" ):
97- if is_inside_inner_function (root = function , node = nonlocal_statement ):
86+ if _is_inside_inner_function (root = function , node = nonlocal_statement ):
9887 continue
99- yield from find_identifiers_in_children (nonlocal_statement )
88+ yield from _find_identifiers_in_children (nonlocal_statement )
10089 case "function_definition" :
10190 if name := node .field ("name" ):
10291 yield name
10392 for nonlocal_statement in node .find_all (kind = "nonlocal_statement" ):
104- if is_inside_inner_function (root = node , node = nonlocal_statement ):
93+ if _is_inside_inner_function (root = node , node = nonlocal_statement ):
10594 continue
106- yield from find_identifiers_in_children (nonlocal_statement )
95+ yield from _find_identifiers_in_children (nonlocal_statement )
10796 case "import_from_statement" | "import_statement" :
108- yield from find_identifiers_in_import (node )
97+ yield from _find_identifiers_in_import_statement (node )
10998 case "as_pattern" :
11099 match tuple ((child .kind (), child ) for child in node .children ()):
111100 case (
@@ -116,62 +105,66 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa
116105 case "keyword_pattern" :
117106 match tuple ((child .kind (), child ) for child in node .children ()):
118107 case (("identifier" , _), ("=" , _), ("dotted_name" , alias )):
119- if identifier := last_child_of_type (alias , "identifier" ):
108+ if identifier := _get_last_child_of_type (alias , "identifier" ):
120109 yield identifier
121110 case "list_pattern" | "tuple_pattern" :
122111 for child in node .children ():
123112 if (
124113 child .kind () == "case_pattern"
125- and (dotted_name := last_child_of_type (child , "dotted_name" ))
126- and (identifier := last_child_of_type (dotted_name , "identifier" ))
114+ and (dotted_name := _get_last_child_of_type (child , "dotted_name" ))
115+ and (identifier := _get_last_child_of_type (dotted_name , "identifier" ))
127116 ):
128117 yield identifier
129118 case "splat_pattern" | "global_statement" | "nonlocal_statement" :
130- yield from find_identifiers_in_children (node )
119+ yield from _find_identifiers_in_children (node )
131120 case "dict_pattern" :
132121 for child in node .children ():
133122 if (
134123 child .kind () == "case_pattern"
135124 and (previous_child := child .prev ())
136125 and previous_child .kind () == ":"
137- and (dotted_name := last_child_of_type (child , "dotted_name" ))
138- and (identifier := last_child_of_type (dotted_name , "identifier" ))
126+ and (dotted_name := _get_last_child_of_type (child , "dotted_name" ))
127+ and (identifier := _get_last_child_of_type (dotted_name , "identifier" ))
139128 ):
140129 yield identifier
141130 case "for_statement" :
142131 if left := node .field ("left" ):
143132 yield from left .find_all (kind = "identifier" )
144133
145134
146- def find_definitions_in_scope_grouped_by_name (root : SgNode ) -> dict [str , list [SgNode ]]:
147- definition_map = defaultdict (list )
148-
149- if parameters := root .field ("parameters" ):
150- for parameter in parameters .children ():
151- for identifier in find_identifiers_in_function_parameter (parameter ):
152- definition_map [identifier .text ()].append (parameter )
153-
154- for node in root .find_all (DEFINITION_RULE ):
155- if is_inside_inner_function_or_class (root , node ) or node == root :
135+ def _find_identifiers_in_scope (node : SgNode ) -> Iterable [tuple [SgNode , SgNode ]]:
136+ for child in node .find_all (DEFINITION_RULE ):
137+ if _is_inside_inner_function_or_class (node , child ) or child == node :
156138 continue
157- for identifier in find_identifiers_in_function_body (node ):
158- definition_map [identifier .text ()].append (node )
139+ for identifier in _find_identifiers_made_by_node (child ):
140+ yield identifier , child
141+
159142
160- return definition_map
143+ def _find_identifiers_in_function_parameter (node : SgNode ) -> Iterable [SgNode ]:
144+ match node .kind ():
145+ case "default_parameter" | "typed_default_parameter" :
146+ if name := node .field ("name" ):
147+ yield name
148+ case "identifier" :
149+ yield node
150+ case _:
151+ yield from _find_identifiers_in_children (node )
161152
162153
163- def find_definitions_in_module (root : SgNode ) -> Iterable [list [SgNode ]]:
154+ def find_all_definitions_in_functions (root : SgNode ) -> Iterable [list [SgNode ]]:
164155 for function in root .find_all (kind = "function_definition" ):
165- yield from find_definitions_in_scope_grouped_by_name ( function ). values ( )
156+ definition_map = defaultdict ( list )
166157
158+ if parameters := function .field ("parameters" ):
159+ for parameter in parameters .children ():
160+ for identifier in _find_identifiers_in_function_parameter (parameter ):
161+ definition_map [identifier .text ()].append (parameter )
167162
168- def has_global_import_with_name (root : SgNode , name : str ) -> bool :
169- for import_statement in root .find_all (
170- {"rule" : {"any" : [{"kind" : "import_from_statement" }, {"kind" : "import_statement" }]}}
171- ):
172- if is_inside_inner_function_or_class (root , import_statement ):
173- continue
174- for identifier in find_identifiers_in_import (import_statement ):
175- if identifier .text () == name :
176- return True
177- return False
163+ for identifier , node in _find_identifiers_in_scope (function ):
164+ definition_map [identifier .text ()].append (node )
165+
166+ yield from definition_map .values ()
167+
168+
169+ def has_global_identifier_with_name (root : SgNode , name : str ) -> bool :
170+ return name in {identifier .text () for identifier , _ in _find_identifiers_in_scope (root )}
0 commit comments