diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 1f3707280..db2824446 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -40,14 +40,14 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" - return (self.input_col == other.input_col and - self.output_col == other.output_col and + return (self.input_col == other.input_col and + self.output_col == other.output_col and self.explain_col == other.explain_col) def __lt__(self, other) -> bool: """Example less than.""" - return (self.input_col < other.input_col - or self.output_col < other.output_col + return (self.input_col < other.input_col + or self.output_col < other.output_col or self.explain_col < other.explain_col) @@ -263,8 +263,8 @@ def apply_multi_vote_filtering( filtered_examples.append( Example( - input_str, - random.choice(output_explain_map[final_output]), + input_str, + random.choice(output_explain_map[final_output]), final_output ) )