Skip to content

Commit dd00104

Browse files
committed
Mypy fixes
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: I9f94f018d95ec58a4a12022679ddd66340344fa0
1 parent 4fb243a commit dd00104

1 file changed

Lines changed: 20 additions & 20 deletions

File tree

backends/cortex_m/passes/aten_to_cortex_m_pass.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -785,9 +785,9 @@ def _get_avg_pool2d_replacement(
785785
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
786786
)
787787
def _get_quantize_per_tensor_replacement(
788-
node: Node, exported_program: ExportedProgram
788+
node: Node, dialect_pass: AtenToDialectPass
789789
) -> DialectNodeSpec | None:
790-
del exported_program
790+
del dialect_pass
791791
if not _is_quant_per_tensor_qualified(node):
792792
return None
793793
return DialectNodeSpec(
@@ -799,9 +799,9 @@ def _get_quantize_per_tensor_replacement(
799799
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
800800
)
801801
def _get_dequantize_per_tensor_replacement(
802-
node: Node, exported_program: ExportedProgram
802+
node: Node, dialect_pass: AtenToDialectPass
803803
) -> DialectNodeSpec | None:
804-
del exported_program
804+
del dialect_pass
805805
if not _is_quant_per_tensor_qualified(node):
806806
return None
807807
return DialectNodeSpec(
@@ -811,9 +811,9 @@ def _get_dequantize_per_tensor_replacement(
811811

812812
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.add.Tensor)
813813
def _get_add_replacement(
814-
node: Node, exported_program: ExportedProgram
814+
node: Node, dialect_pass: AtenToDialectPass
815815
) -> DialectNodeSpec | None:
816-
del exported_program
816+
del dialect_pass
817817
if not _has_qparams(node):
818818
return None
819819

@@ -854,9 +854,9 @@ def _get_add_replacement(
854854

855855
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.mul.Tensor)
856856
def _get_mul_replacement(
857-
node: Node, exported_program: ExportedProgram
857+
node: Node, dialect_pass: AtenToDialectPass
858858
) -> DialectNodeSpec | None:
859-
del exported_program
859+
del dialect_pass
860860
if not _has_qparams(node):
861861
return None
862862

@@ -884,9 +884,9 @@ def _get_mul_replacement(
884884

885885
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten._softmax.default)
886886
def _get_softmax_replacement(
887-
node: Node, exported_program: ExportedProgram
887+
node: Node, dialect_pass: AtenToDialectPass
888888
) -> DialectNodeSpec | None:
889-
del exported_program
889+
del dialect_pass
890890
if not _has_qparams(node):
891891
return None
892892

@@ -934,9 +934,9 @@ def _get_softmax_replacement(
934934

935935
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.max_pool2d.default)
936936
def _get_max_pool2d_replacement(
937-
node: Node, exported_program: ExportedProgram
937+
node: Node, dialect_pass: AtenToDialectPass
938938
) -> DialectNodeSpec | None:
939-
del exported_program
939+
del dialect_pass
940940
input_qparams = node.meta.get("input_qparams", {}).get(0)
941941
cortex_m_meta = node.meta.get("custom", {}).get("cortex_m", {})
942942
if input_qparams is None or cortex_m_meta.get("skip_quantized_max_pool2d", False):
@@ -999,9 +999,9 @@ def _get_max_pool2d_replacement(
999999

10001000
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.minimum.default)
10011001
def _get_minimum_replacement(
1002-
node: Node, exported_program: ExportedProgram
1002+
node: Node, dialect_pass: AtenToDialectPass
10031003
) -> DialectNodeSpec | None:
1004-
del exported_program
1004+
del dialect_pass
10051005
input_tensor = _get_input_tensor_data(node)
10061006
if input_tensor.dtype not in (torch.int8, torch.int32):
10071007
return None
@@ -1010,9 +1010,9 @@ def _get_minimum_replacement(
10101010

10111011
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.maximum.default)
10121012
def _get_maximum_replacement(
1013-
node: Node, exported_program: ExportedProgram
1013+
node: Node, dialect_pass: AtenToDialectPass
10141014
) -> DialectNodeSpec | None:
1015-
del exported_program
1015+
del dialect_pass
10161016
input_tensor = _get_input_tensor_data(node)
10171017
if input_tensor.dtype != torch.int8:
10181018
return None
@@ -1023,9 +1023,9 @@ def _get_maximum_replacement(
10231023
exir_ops.edge.aten.permute_copy.default
10241024
)
10251025
def _get_permute_replacement(
1026-
node: Node, exported_program: ExportedProgram
1026+
node: Node, dialect_pass: AtenToDialectPass
10271027
) -> DialectNodeSpec | None:
1028-
del exported_program
1028+
del dialect_pass
10291029
input_tensor = _get_input_tensor_data(node)
10301030
if input_tensor.dtype != torch.int8:
10311031
return None
@@ -1041,9 +1041,9 @@ def _get_permute_replacement(
10411041
exir_ops.edge.aten.constant_pad_nd.default
10421042
)
10431043
def _get_pad_replacement(
1044-
node: Node, exported_program: ExportedProgram
1044+
node: Node, dialect_pass: AtenToDialectPass
10451045
) -> DialectNodeSpec | None:
1046-
del exported_program
1046+
del dialect_pass
10471047
input_qparams = node.meta.get("input_qparams", {})
10481048
if not input_qparams:
10491049
return None

0 commit comments

Comments
 (0)