Skip to content

Commit 8d0dd85

Browse files
authored
Support from typing import Final style (#28)
1 parent a18877b commit 8d0dd85

File tree

5 files changed

+283
-204
lines changed

5 files changed

+283
-204
lines changed

auto_typing_final/finder.py

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -28,62 +28,51 @@
2828
}
2929

3030

31-
def last_child_of_type(node: SgNode, type_: str) -> SgNode | None:
31+
def _get_last_child_of_type(node: SgNode, type_: str) -> SgNode | None:
3232
return last_child if (children := node.children()) and (last_child := children[-1]).kind() == type_ else None
3333

3434

35-
def find_identifiers_in_children(node: SgNode) -> Iterable[SgNode]:
35+
def _find_identifiers_in_children(node: SgNode) -> Iterable[SgNode]:
3636
for child in node.children():
3737
if child.kind() == "identifier":
3838
yield child
3939

4040

41-
def is_inside_inner_function(root: SgNode, node: SgNode) -> bool:
41+
def _is_inside_inner_function(root: SgNode, node: SgNode) -> bool:
4242
for ancestor in node.ancestors():
4343
if ancestor.kind() == "function_definition":
4444
return ancestor != root
4545
return False
4646

4747

48-
def is_inside_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
48+
def _is_inside_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
4949
for ancestor in node.ancestors():
5050
if ancestor.kind() in {"function_definition", "class_definition"}:
5151
return ancestor != root
5252
return False
5353

5454

55-
def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[SgNode]:
56-
match node.kind():
57-
case "default_parameter" | "typed_default_parameter":
58-
if name := node.field("name"):
59-
yield name
60-
case "identifier":
61-
yield node
62-
case _:
63-
yield from find_identifiers_in_children(node)
64-
65-
66-
def find_identifiers_in_import(node: SgNode) -> Iterable[SgNode]:
55+
def _find_identifiers_in_import_statement(node: SgNode) -> Iterable[SgNode]:
6756
match tuple((child.kind(), child) for child in node.children()):
6857
case (("from", _), _, ("import", _), *name_nodes) | (("import", _), *name_nodes):
6958
for kind, name_node in name_nodes:
7059
match kind:
7160
case "dotted_name":
72-
if identifier := last_child_of_type(name_node, "identifier"):
61+
if identifier := _get_last_child_of_type(name_node, "identifier"):
7362
yield identifier
7463
case "aliased_import":
7564
if alias := name_node.field("alias"):
7665
yield alias
7766

7867

79-
def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa: C901, PLR0912
68+
def _find_identifiers_made_by_node(node: SgNode) -> Iterable[SgNode]: # noqa: C901, PLR0912
8069
match node.kind():
8170
case "assignment" | "augmented_assignment":
8271
if not (left := node.field("left")):
8372
return
8473
match left.kind():
8574
case "pattern_list" | "tuple_pattern":
86-
yield from find_identifiers_in_children(left)
75+
yield from _find_identifiers_in_children(left)
8776
case "identifier":
8877
yield left
8978
case "named_expression":
@@ -94,18 +83,18 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa
9483
yield name
9584
for function in node.find_all(kind="function_definition"):
9685
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
97-
if is_inside_inner_function(root=function, node=nonlocal_statement):
86+
if _is_inside_inner_function(root=function, node=nonlocal_statement):
9887
continue
99-
yield from find_identifiers_in_children(nonlocal_statement)
88+
yield from _find_identifiers_in_children(nonlocal_statement)
10089
case "function_definition":
10190
if name := node.field("name"):
10291
yield name
10392
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
104-
if is_inside_inner_function(root=node, node=nonlocal_statement):
93+
if _is_inside_inner_function(root=node, node=nonlocal_statement):
10594
continue
106-
yield from find_identifiers_in_children(nonlocal_statement)
95+
yield from _find_identifiers_in_children(nonlocal_statement)
10796
case "import_from_statement" | "import_statement":
108-
yield from find_identifiers_in_import(node)
97+
yield from _find_identifiers_in_import_statement(node)
10998
case "as_pattern":
11099
match tuple((child.kind(), child) for child in node.children()):
111100
case (
@@ -116,62 +105,66 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa
116105
case "keyword_pattern":
117106
match tuple((child.kind(), child) for child in node.children()):
118107
case (("identifier", _), ("=", _), ("dotted_name", alias)):
119-
if identifier := last_child_of_type(alias, "identifier"):
108+
if identifier := _get_last_child_of_type(alias, "identifier"):
120109
yield identifier
121110
case "list_pattern" | "tuple_pattern":
122111
for child in node.children():
123112
if (
124113
child.kind() == "case_pattern"
125-
and (dotted_name := last_child_of_type(child, "dotted_name"))
126-
and (identifier := last_child_of_type(dotted_name, "identifier"))
114+
and (dotted_name := _get_last_child_of_type(child, "dotted_name"))
115+
and (identifier := _get_last_child_of_type(dotted_name, "identifier"))
127116
):
128117
yield identifier
129118
case "splat_pattern" | "global_statement" | "nonlocal_statement":
130-
yield from find_identifiers_in_children(node)
119+
yield from _find_identifiers_in_children(node)
131120
case "dict_pattern":
132121
for child in node.children():
133122
if (
134123
child.kind() == "case_pattern"
135124
and (previous_child := child.prev())
136125
and previous_child.kind() == ":"
137-
and (dotted_name := last_child_of_type(child, "dotted_name"))
138-
and (identifier := last_child_of_type(dotted_name, "identifier"))
126+
and (dotted_name := _get_last_child_of_type(child, "dotted_name"))
127+
and (identifier := _get_last_child_of_type(dotted_name, "identifier"))
139128
):
140129
yield identifier
141130
case "for_statement":
142131
if left := node.field("left"):
143132
yield from left.find_all(kind="identifier")
144133

145134

146-
def find_definitions_in_scope_grouped_by_name(root: SgNode) -> dict[str, list[SgNode]]:
147-
definition_map = defaultdict(list)
148-
149-
if parameters := root.field("parameters"):
150-
for parameter in parameters.children():
151-
for identifier in find_identifiers_in_function_parameter(parameter):
152-
definition_map[identifier.text()].append(parameter)
153-
154-
for node in root.find_all(DEFINITION_RULE):
155-
if is_inside_inner_function_or_class(root, node) or node == root:
135+
def _find_identifiers_in_scope(node: SgNode) -> Iterable[tuple[SgNode, SgNode]]:
136+
for child in node.find_all(DEFINITION_RULE):
137+
if _is_inside_inner_function_or_class(node, child) or child == node:
156138
continue
157-
for identifier in find_identifiers_in_function_body(node):
158-
definition_map[identifier.text()].append(node)
139+
for identifier in _find_identifiers_made_by_node(child):
140+
yield identifier, child
141+
159142

160-
return definition_map
143+
def _find_identifiers_in_function_parameter(node: SgNode) -> Iterable[SgNode]:
144+
match node.kind():
145+
case "default_parameter" | "typed_default_parameter":
146+
if name := node.field("name"):
147+
yield name
148+
case "identifier":
149+
yield node
150+
case _:
151+
yield from _find_identifiers_in_children(node)
161152

162153

163-
def find_definitions_in_module(root: SgNode) -> Iterable[list[SgNode]]:
154+
def find_all_definitions_in_functions(root: SgNode) -> Iterable[list[SgNode]]:
164155
for function in root.find_all(kind="function_definition"):
165-
yield from find_definitions_in_scope_grouped_by_name(function).values()
156+
definition_map = defaultdict(list)
166157

158+
if parameters := function.field("parameters"):
159+
for parameter in parameters.children():
160+
for identifier in _find_identifiers_in_function_parameter(parameter):
161+
definition_map[identifier.text()].append(parameter)
167162

168-
def has_global_import_with_name(root: SgNode, name: str) -> bool:
169-
for import_statement in root.find_all(
170-
{"rule": {"any": [{"kind": "import_from_statement"}, {"kind": "import_statement"}]}}
171-
):
172-
if is_inside_inner_function_or_class(root, import_statement):
173-
continue
174-
for identifier in find_identifiers_in_import(import_statement):
175-
if identifier.text() == name:
176-
return True
177-
return False
163+
for identifier, node in _find_identifiers_in_scope(function):
164+
definition_map[identifier.text()].append(node)
165+
166+
yield from definition_map.values()
167+
168+
169+
def has_global_identifier_with_name(root: SgNode, name: str) -> bool:
170+
return name in {identifier.text() for identifier, _ in _find_identifiers_in_scope(root)}

auto_typing_final/lsp.py

Lines changed: 66 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from pygls import server
1010
from pygls.workspace import TextDocument
1111

12-
from auto_typing_final.finder import has_global_import_with_name
13-
from auto_typing_final.transform import AddFinal, AppliedOperation, make_operations_from_root
12+
from auto_typing_final.transform import IMPORT_MODES_TO_IMPORT_CONFIGS, AddFinal, Edit, ImportMode, make_replacements
1413

1514
LSP_SERVER = server.LanguageServer(name="auto-typing-final", version=version("auto-typing-final"), max_workers=5)
15+
IMPORT_CONFIG = IMPORT_MODES_TO_IMPORT_CONFIGS[ImportMode.typing_final]
1616

1717

1818
@attr.define
@@ -26,43 +26,40 @@ class DiagnosticData:
2626
fix: Fix
2727

2828

29-
def make_typing_import() -> lsp.TextEdit:
29+
def make_import_text_edit(import_text: str) -> lsp.TextEdit:
3030
return lsp.TextEdit(
3131
range=lsp.Range(start=lsp.Position(line=0, character=0), end=lsp.Position(line=0, character=0)),
32-
new_text="import typing\n",
32+
new_text=f"{import_text}\n",
3333
)
3434

3535

36-
def make_diagnostic_text_edits(applied_operation: AppliedOperation) -> Iterable[lsp.TextEdit]:
37-
for applied_edit in applied_operation.edits:
38-
node_range = applied_edit.node.range()
39-
yield lsp.TextEdit(
40-
range=lsp.Range(
41-
start=lsp.Position(line=node_range.start.line, character=node_range.start.column),
42-
end=lsp.Position(line=node_range.end.line, character=node_range.end.column),
43-
),
44-
new_text=applied_edit.edit.inserted_text,
45-
)
36+
def make_text_edit(edit: Edit) -> lsp.TextEdit:
37+
node_range = edit.node.range()
38+
return lsp.TextEdit(
39+
range=lsp.Range(
40+
start=lsp.Position(line=node_range.start.line, character=node_range.start.column),
41+
end=lsp.Position(line=node_range.end.line, character=node_range.end.column),
42+
),
43+
new_text=edit.new_text,
44+
)
4645

4746

4847
def make_diagnostics(source: str) -> Iterable[lsp.Diagnostic]:
49-
root = SgRoot(source, "python").root()
50-
has_import = has_global_import_with_name(root, "typing")
48+
result = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)
5149

52-
for applied_operation in make_operations_from_root(root):
53-
if isinstance(applied_operation.operation, AddFinal):
50+
for replacement in result.replacements:
51+
if replacement.operation_type == AddFinal:
5452
fix_message = f"{LSP_SERVER.name}: Add typing.Final"
5553
diagnostic_message = "Missing typing.Final"
5654
else:
5755
fix_message = f"{LSP_SERVER.name}: Remove typing.Final"
5856
diagnostic_message = "Unexpected typing.Final"
5957

60-
fix = Fix(message=fix_message, text_edits=list(make_diagnostic_text_edits(applied_operation)))
61-
62-
if isinstance(applied_operation.operation, AddFinal) and not has_import:
63-
fix.text_edits.append(make_typing_import())
58+
fix = Fix(message=fix_message, text_edits=[make_text_edit(edit) for edit in replacement.edits])
59+
if result.import_text:
60+
fix.text_edits.append(make_import_text_edit(result.import_text))
6461

65-
for applied_edit in applied_operation.edits:
62+
for applied_edit in replacement.edits:
6663
node_range = applied_edit.node.range()
6764
yield lsp.Diagnostic(
6865
range=lsp.Range(
@@ -77,36 +74,26 @@ def make_diagnostics(source: str) -> Iterable[lsp.Diagnostic]:
7774

7875

7976
def make_fixall_text_edits(source: str) -> Iterable[lsp.TextEdit]:
80-
root = SgRoot(source, "python").root()
81-
has_import = has_global_import_with_name(root, "typing")
82-
has_add_final_operation = False
83-
84-
for applied_operation in make_operations_from_root(root):
85-
if isinstance(applied_operation.operation, AddFinal):
86-
has_add_final_operation = True
87-
yield from make_diagnostic_text_edits(applied_operation)
88-
89-
if has_add_final_operation and not has_import:
90-
yield make_typing_import()
91-
92-
93-
def make_quickfix_action(diagnostic: lsp.Diagnostic, text_document: TextDocument) -> lsp.CodeAction:
94-
data = cattrs.structure(diagnostic.data, DiagnosticData)
95-
return lsp.CodeAction(
96-
title=data.fix.message,
97-
kind=lsp.CodeActionKind.QuickFix,
98-
data=text_document.uri,
99-
edit=lsp.WorkspaceEdit(
100-
document_changes=[
101-
lsp.TextDocumentEdit(
102-
text_document=lsp.OptionalVersionedTextDocumentIdentifier(
103-
uri=text_document.uri, version=text_document.version
104-
),
105-
edits=data.fix.text_edits, # type: ignore[arg-type]
106-
)
107-
]
108-
),
109-
diagnostics=[diagnostic],
77+
result = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)
78+
79+
for replacement in result.replacements:
80+
for edit in replacement.edits:
81+
yield make_text_edit(edit)
82+
83+
if result.import_text:
84+
yield make_import_text_edit(result.import_text)
85+
86+
87+
def make_workspace_edit(text_document: TextDocument, text_edits: list[lsp.TextEdit]) -> lsp.WorkspaceEdit:
88+
return lsp.WorkspaceEdit(
89+
document_changes=[
90+
lsp.TextDocumentEdit(
91+
text_document=lsp.OptionalVersionedTextDocumentIdentifier(
92+
uri=text_document.uri, version=text_document.version
93+
),
94+
edits=text_edits, # type: ignore[arg-type]
95+
)
96+
]
11097
)
11198

11299

@@ -141,11 +128,31 @@ def code_action(params: lsp.CodeActionParams) -> list[lsp.CodeAction] | None:
141128

142129
if lsp.CodeActionKind.QuickFix in requested_kinds:
143130
text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri)
144-
actions.extend(
145-
make_quickfix_action(diagnostic=diagnostic, text_document=text_document)
146-
for diagnostic in params.context.diagnostics
147-
if diagnostic.source == LSP_SERVER.name
148-
)
131+
our_diagnostics = [
132+
diagnostic for diagnostic in params.context.diagnostics if diagnostic.source == LSP_SERVER.name
133+
]
134+
135+
for diagnostic in our_diagnostics:
136+
data = cattrs.structure(diagnostic.data, DiagnosticData)
137+
actions.append(
138+
lsp.CodeAction(
139+
title=data.fix.message,
140+
kind=lsp.CodeActionKind.QuickFix,
141+
edit=make_workspace_edit(text_document=text_document, text_edits=data.fix.text_edits),
142+
diagnostics=[diagnostic],
143+
)
144+
)
145+
146+
if our_diagnostics:
147+
actions.append(
148+
lsp.CodeAction(
149+
title=f"{LSP_SERVER.name}: Fix All",
150+
kind=lsp.CodeActionKind.QuickFix,
151+
data=params.text_document.uri,
152+
edit=None,
153+
diagnostics=params.context.diagnostics,
154+
)
155+
)
149156

150157
if lsp.CodeActionKind.SourceFixAll in requested_kinds:
151158
actions.append(
@@ -154,7 +161,7 @@ def code_action(params: lsp.CodeActionParams) -> list[lsp.CodeAction] | None:
154161
kind=lsp.CodeActionKind.SourceFixAll,
155162
data=params.text_document.uri,
156163
edit=None,
157-
diagnostics=[],
164+
diagnostics=params.context.diagnostics,
158165
),
159166
)
160167

@@ -163,9 +170,6 @@ def code_action(params: lsp.CodeActionParams) -> list[lsp.CodeAction] | None:
163170

164171
@LSP_SERVER.feature(lsp.CODE_ACTION_RESOLVE)
165172
def resolve_code_action(params: lsp.CodeAction) -> lsp.CodeAction:
166-
if params.kind != lsp.CodeActionKind.SourceFixAll:
167-
return params
168-
169173
text_document = LSP_SERVER.workspace.get_text_document(cast(str, params.data))
170174
params.edit = lsp.WorkspaceEdit(
171175
document_changes=[

0 commit comments

Comments
 (0)