Skip to content

Commit

Permalink
Update prompt_based.py to fix the length issues in the PR
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBK authored Dec 5, 2023
1 parent 82f954a commit d0cd15c
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions prompt2model/dataset_generator/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ class Example:

def __eq__(self, other) -> bool:
"""Example equality."""
return self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col
return (self.input_col == other.input_col and
self.output_col == other.output_col and
self.explain_col == other.explain_col)

def __lt__(self, other) -> bool:
"""Example less than."""
return self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col
return (self.input_col < other.input_col
or self.output_col < other.output_col
or self.explain_col < other.explain_col)


class PromptBasedDatasetGenerator(DatasetGenerator):
Expand Down Expand Up @@ -170,7 +174,9 @@ def construct_prompt(
)
for example in random_examples:
low_quality_example_string += (
f'input="{example.input_col}"\nexplanation="{example.explain_col}"\noutput="{example.output_col}"\n'
f'input="{example.input_col}"\n'
f'explanation="{example.explain_col}"\n'
f'output="{example.output_col}"\n'
)
# To increase the diversity of the prompt to DatasetGenerator, create three
# prompt templates, COMPLEX, MIDDLE, and SIMPLE. The COMPLEX template
Expand Down Expand Up @@ -255,7 +261,13 @@ def apply_multi_vote_filtering(
most_frequent_outputs.sort(key=len)
final_output = most_frequent_outputs[0]

filtered_examples.append(Example(input_str, random.choice(output_explain_map[final_output]),final_output))
filtered_examples.append(
Example(
input_str,
random.choice(output_explain_map[final_output]),
final_output
)
)
return filtered_examples

def compute_batch_size(self, num_examples: int, generated_dataset_size: int) -> int:
Expand Down

0 comments on commit d0cd15c

Please sign in to comment.