@@ -785,9 +785,9 @@ def _get_avg_pool2d_replacement(
785785 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
786786)
787787def _get_quantize_per_tensor_replacement (
788- node : Node , exported_program : ExportedProgram
788+ node : Node , dialect_pass : AtenToDialectPass
789789) -> DialectNodeSpec | None :
790- del exported_program
790+ del dialect_pass
791791 if not _is_quant_per_tensor_qualified (node ):
792792 return None
793793 return DialectNodeSpec (
@@ -799,9 +799,9 @@ def _get_quantize_per_tensor_replacement(
799799 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
800800)
801801def _get_dequantize_per_tensor_replacement (
802- node : Node , exported_program : ExportedProgram
802+ node : Node , dialect_pass : AtenToDialectPass
803803) -> DialectNodeSpec | None :
804- del exported_program
804+ del dialect_pass
805805 if not _is_quant_per_tensor_qualified (node ):
806806 return None
807807 return DialectNodeSpec (
@@ -811,9 +811,9 @@ def _get_dequantize_per_tensor_replacement(
811811
812812@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten .add .Tensor )
813813def _get_add_replacement (
814- node : Node , exported_program : ExportedProgram
814+ node : Node , dialect_pass : AtenToDialectPass
815815) -> DialectNodeSpec | None :
816- del exported_program
816+ del dialect_pass
817817 if not _has_qparams (node ):
818818 return None
819819
@@ -854,9 +854,9 @@ def _get_add_replacement(
854854
855855@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten .mul .Tensor )
856856def _get_mul_replacement (
857- node : Node , exported_program : ExportedProgram
857+ node : Node , dialect_pass : AtenToDialectPass
858858) -> DialectNodeSpec | None :
859- del exported_program
859+ del dialect_pass
860860 if not _has_qparams (node ):
861861 return None
862862
@@ -884,9 +884,9 @@ def _get_mul_replacement(
884884
885885@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten ._softmax .default )
886886def _get_softmax_replacement (
887- node : Node , exported_program : ExportedProgram
887+ node : Node , dialect_pass : AtenToDialectPass
888888) -> DialectNodeSpec | None :
889- del exported_program
889+ del dialect_pass
890890 if not _has_qparams (node ):
891891 return None
892892
@@ -934,9 +934,9 @@ def _get_softmax_replacement(
934934
935935@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten .max_pool2d .default )
936936def _get_max_pool2d_replacement (
937- node : Node , exported_program : ExportedProgram
937+ node : Node , dialect_pass : AtenToDialectPass
938938) -> DialectNodeSpec | None :
939- del exported_program
939+ del dialect_pass
940940 input_qparams = node .meta .get ("input_qparams" , {}).get (0 )
941941 cortex_m_meta = node .meta .get ("custom" , {}).get ("cortex_m" , {})
942942 if input_qparams is None or cortex_m_meta .get ("skip_quantized_max_pool2d" , False ):
@@ -999,9 +999,9 @@ def _get_max_pool2d_replacement(
999999
10001000@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten .minimum .default )
10011001def _get_minimum_replacement (
1002- node : Node , exported_program : ExportedProgram
1002+ node : Node , dialect_pass : AtenToDialectPass
10031003) -> DialectNodeSpec | None :
1004- del exported_program
1004+ del dialect_pass
10051005 input_tensor = _get_input_tensor_data (node )
10061006 if input_tensor .dtype not in (torch .int8 , torch .int32 ):
10071007 return None
@@ -1010,9 +1010,9 @@ def _get_minimum_replacement(
10101010
10111011@AtenToCortexMPass .register_dialect_substitution (exir_ops .edge .aten .maximum .default )
10121012def _get_maximum_replacement (
1013- node : Node , exported_program : ExportedProgram
1013+ node : Node , dialect_pass : AtenToDialectPass
10141014) -> DialectNodeSpec | None :
1015- del exported_program
1015+ del dialect_pass
10161016 input_tensor = _get_input_tensor_data (node )
10171017 if input_tensor .dtype != torch .int8 :
10181018 return None
@@ -1023,9 +1023,9 @@ def _get_maximum_replacement(
10231023 exir_ops .edge .aten .permute_copy .default
10241024)
10251025def _get_permute_replacement (
1026- node : Node , exported_program : ExportedProgram
1026+ node : Node , dialect_pass : AtenToDialectPass
10271027) -> DialectNodeSpec | None :
1028- del exported_program
1028+ del dialect_pass
10291029 input_tensor = _get_input_tensor_data (node )
10301030 if input_tensor .dtype != torch .int8 :
10311031 return None
@@ -1041,9 +1041,9 @@ def _get_permute_replacement(
10411041 exir_ops .edge .aten .constant_pad_nd .default
10421042)
10431043def _get_pad_replacement (
1044- node : Node , exported_program : ExportedProgram
1044+ node : Node , dialect_pass : AtenToDialectPass
10451045) -> DialectNodeSpec | None :
1046- del exported_program
1046+ del dialect_pass
10471047 input_qparams = node .meta .get ("input_qparams" , {})
10481048 if not input_qparams :
10491049 return None
0 commit comments