Skip to content

Commit aa7e980

Browse files
authored
Handle chained assignments (#34)
1 parent b8a9e19 commit aa7e980

File tree

3 files changed

+48
-33
lines changed

3 files changed

+48
-33
lines changed

Justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ check-types:
1212
uv -q run mypy .
1313

1414
test *args:
15-
@.venv/bin/pytest -- {{ args }}
15+
@.venv/bin/pytest {{ args }}
1616

1717
publish-package:
1818
rm -rf dist/*

auto_typing_final/transform.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ class ImportConfig:
2727

2828

2929
@dataclass
30-
class AssignmentWithoutAnnotation:
30+
class EditableAssignmentWithoutAnnotation:
3131
node: SgNode
3232
left: str
3333
right: str
3434

3535

3636
@dataclass
37-
class AssignmentWithAnnotation:
37+
class EditableAssignmentWithAnnotation:
3838
node: SgNode
3939
left: str
4040
annotation: SgNode
@@ -46,12 +46,12 @@ class OtherDefinition:
4646
node: SgNode
4747

4848

49-
Definition = AssignmentWithoutAnnotation | AssignmentWithAnnotation | OtherDefinition
49+
Definition = EditableAssignmentWithoutAnnotation | EditableAssignmentWithAnnotation | OtherDefinition
5050

5151

5252
@dataclass
5353
class AddFinal:
54-
definition: Definition
54+
node: Definition
5555

5656

5757
@dataclass
@@ -62,37 +62,48 @@ class RemoveFinal:
6262
Operation = AddFinal | RemoveFinal
6363

6464

65-
def _make_operation_from_assignments_to_one_name(nodes: list[SgNode]) -> Operation:
66-
value_assignments: Final[list[Definition]] = []
65+
def _make_definition_from_definition_node(node: SgNode) -> Definition:
66+
if node.kind() != "assignment":
67+
return OtherDefinition(node)
68+
69+
match tuple((child.kind(), child) for child in node.children()):
70+
case (
71+
("identifier", left),
72+
("=", _),
73+
(right_kind, right),
74+
) if right_kind != "assignment" and not ((parent := node.parent()) and parent.kind() == "assignment"):
75+
return EditableAssignmentWithoutAnnotation(node=node, left=left.text(), right=right.text())
76+
case (
77+
("identifier", left),
78+
(":", _),
79+
("type", annotation),
80+
("=", _),
81+
(right_kind, right),
82+
) if right_kind != "assignment" and not ((parent := node.parent()) and parent.kind() == "assignment"):
83+
return EditableAssignmentWithAnnotation(
84+
node=node, left=left.text(), annotation=annotation, right=right.text()
85+
)
86+
case _:
87+
return OtherDefinition(node)
88+
89+
90+
def _make_operation_from_definitions_of_one_name(nodes: list[SgNode]) -> Operation:
91+
value_definitions: Final[list[Definition]] = []
6792
has_node_inside_loop = False
6893

6994
for node in nodes:
7095
if any(ancestor.kind() in {"for_statement", "while_statement"} for ancestor in node.ancestors()):
7196
has_node_inside_loop = True
72-
73-
if node.kind() == "assignment":
74-
match tuple((child.kind(), child) for child in node.children()):
75-
case (("identifier", left), ("=", _), (_, right)):
76-
value_assignments.append(
77-
AssignmentWithoutAnnotation(node=node, left=left.text(), right=right.text())
78-
)
79-
case (("identifier", left), (":", _), ("type", annotation), ("=", _), (_, right)):
80-
value_assignments.append(
81-
AssignmentWithAnnotation(node=node, left=left.text(), annotation=annotation, right=right.text())
82-
)
83-
case _:
84-
value_assignments.append(OtherDefinition(node))
85-
else:
86-
value_assignments.append(OtherDefinition(node))
97+
value_definitions.append(_make_definition_from_definition_node(node))
8798

8899
if has_node_inside_loop:
89-
return RemoveFinal(value_assignments)
100+
return RemoveFinal(value_definitions)
90101

91-
match value_assignments:
92-
case [assignment]:
93-
return AddFinal(assignment)
94-
case assignments:
95-
return RemoveFinal(assignments)
102+
match value_definitions:
103+
case [definition]:
104+
return AddFinal(definition)
105+
case definitions:
106+
return RemoveFinal(definitions)
96107

97108

98109
def _match_exact_identifier(node: SgNode, imports_result: ImportsResult, identifier_name: str) -> str | None:
@@ -141,9 +152,9 @@ def _make_changed_text_from_operation( # noqa: C901
141152
match operation:
142153
case AddFinal(assignment):
143154
match assignment:
144-
case AssignmentWithoutAnnotation(node, left, right):
155+
case EditableAssignmentWithoutAnnotation(node, left, right):
145156
yield node, f"{left}: {final_value} = {right}"
146-
case AssignmentWithAnnotation(node, left, annotation, right):
157+
case EditableAssignmentWithAnnotation(node, left, annotation, right):
147158
match _strip_identifier_from_type_annotation(annotation, imports_result, identifier_name):
148159
case None:
149160
yield node, f"{left}: {final_value}[{annotation.text()}] = {right}"
@@ -155,9 +166,9 @@ def _make_changed_text_from_operation( # noqa: C901
155166
case RemoveFinal(assignments):
156167
for assignment in assignments:
157168
match assignment:
158-
case AssignmentWithoutAnnotation(node, left, right):
169+
case EditableAssignmentWithoutAnnotation(node, left, right):
159170
yield node, f"{left} = {right}"
160-
case AssignmentWithAnnotation(node, left, annotation, right):
171+
case EditableAssignmentWithAnnotation(node, left, annotation, right):
161172
match _strip_identifier_from_type_annotation(annotation, imports_result, identifier_name):
162173
case _, "":
163174
yield node, f"{left} = {right}"
@@ -189,7 +200,7 @@ def make_replacements(root: SgNode, import_config: ImportConfig) -> MakeReplacem
189200
imports_result: Final = find_imports_of_identifier_in_scope(root, module_name="typing", identifier_name="Final")
190201

191202
for current_definitions in find_all_definitions_in_functions(root):
192-
operation = _make_operation_from_assignments_to_one_name(current_definitions)
203+
operation = _make_operation_from_definitions_of_one_name(current_definitions)
193204
edits = [
194205
Edit(node=node, new_text=new_text)
195206
for node, new_text in _make_changed_text_from_operation(

tests/test_main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,16 @@
3030
("(a, b) = t()", "(a, b) = t()"),
3131
("[a, b] = t()", "[a, b] = t()"),
3232
("[a] = t()", "[a] = t()"),
33+
("a = b = 1", "a = b = 1"),
34+
("a = b = c = 1", "a = b = c = 1"),
35+
("a = (b := 1)", "a: {} = (b := 1)"),
3336
# Remove annotation
3437
("a = 1\na: {}[int] = 2", "a = 1\na: int = 2"),
3538
("a = 1\na: {} = 2", "a = 1\na = 2"),
3639
("a: int = 1\na: {}[int] = 2", "a: int = 1\na: int = 2"),
3740
("a: int = 1\na: {} = 2", "a: int = 1\na = 2"),
3841
("a: {} = 1\na: {} = 2\na = 3\na: int = 4", "a = 1\na = 2\na = 3\na: int = 4"),
42+
("a: {} = b = 1", "a: {} = b = 1"),
3943
# Both
4044
("a = 1\nb = 2\nb: {}[int] = 3", "a: {} = 1\nb = 2\nb: int = 3"),
4145
],

0 commit comments

Comments
 (0)