Add data transformation capability to dataset retrieval step #385
Merged
11 commits
1a2c403 add dataset transformation (saum7800)
2aaaf74 add tests (saum7800)
0a39035 make PR revisions (saum7800)
b69a2f3 merge auto into normal demo (saum7800)
054f173 Merge branch 'main' into saumya_data_transform (neubig)
c3516a1 Merge branch 'main' into saumya_data_transform (saum7800)
db6a2de merge with main (saum7800)
c11fb09 merge reranking and transformation flows (saum7800)
b811479 update test (saum7800)
9b8f7dd verbose line print in demo (saum7800)
8f8a589 minor grammar changes (saum7800)
@@ -0,0 +1,8 @@
"""Import DatasetTransformer classes."""
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
    "PromptBasedDatasetTransformer",
    "DatasetTransformer",
)
@@ -0,0 +1,34 @@
"""An interface for dataset transformation."""

from __future__ import annotations  # noqa FI58

from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


class DatasetTransformer(ABC):
    """A class for transforming a given dataset to a desired format."""

    @abstractmethod
    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.Dataset:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: A dataset split.
            num_points_to_transform: Number of data points you wish to
                transform. Number must be greater than zero. If number is greater
                than size of dataset, whole dataset will be transformed. Ignored
                if data_transform is False.

        Returns:
            A single dataset split.
        """
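For reference, here is a minimal sketch of how a concrete subclass might satisfy an interface of this shape. It is not part of the PR. `ToyDatasetTransformer` and `UppercaseTransformer` are hypothetical names, a plain list of dicts stands in for `datasets.Dataset`, and a bare instruction string stands in for `PromptSpec`, so the snippet runs with the standard library alone.

```python
from abc import ABC, abstractmethod


class ToyDatasetTransformer(ABC):
    """Stand-in for the DatasetTransformer interface, using plain-Python types."""

    @abstractmethod
    def transform_data(
        self,
        instruction: str,
        dataset: list[dict],
        num_points_to_transform: int,
    ) -> list[dict]:
        """Transform at most num_points_to_transform rows."""


class UppercaseTransformer(ToyDatasetTransformer):
    """Hypothetical transformer: upper-cases the "text" field of each row."""

    def transform_data(self, instruction, dataset, num_points_to_transform):
        if num_points_to_transform <= 0:
            raise ValueError("num_points_to_transform must be greater than zero.")
        # Slicing past the end is safe, so an oversized request
        # simply transforms the whole dataset.
        rows = dataset[:num_points_to_transform]
        return [
            {"input_col": row["text"], "output_col": row["text"].upper()}
            for row in rows
        ]


rows = [{"text": "a"}, {"text": "b"}, {"text": "c"}]
transformed = UppercaseTransformer().transform_data("Uppercase the text.", rows, 2)
```

Because `transform_data` is abstract, instantiating the base class directly raises `TypeError`; only subclasses that implement the method can be constructed.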
@@ -0,0 +1,122 @@
"""A simple dataset transformer that uses a plan prompt and transform prompt."""
from __future__ import annotations

import asyncio
from collections.abc import Callable

import datasets

from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
    construct_prompt_for_plan,
    construct_prompt_for_transform_data,
)
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import (
    API_ERRORS,
    api_tools,
    get_formatted_logger,
    handle_api_error,
)
from prompt2model.utils.parse_responses import make_single_api_request, parse_json

logger = get_formatted_logger("DatasetTransformer")


class PromptBasedDatasetTransformer(DatasetTransformer):
    """Transform data based on a transform prompt."""

    def __init__(
        self,
        plan_prompt_fn: Callable[
            [str, list[dict], str], str
        ] = construct_prompt_for_plan,
        transform_prompt_fn: Callable[
            [str, dict, str, str], str
        ] = construct_prompt_for_transform_data,
    ):
        """Initialize the class."""
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        self.plan: str = ""

    def make_dataset_from_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Given a list of inputs and outputs, make a dataset."""
        if len(inputs) <= 0 or len(inputs) != len(outputs):
            raise ValueError("Length of inputs and outputs must be >0 and equal.")

        dataset_dict = {}
        dataset_dict["train"] = datasets.Dataset.from_dict(
            {"input_col": inputs, "output_col": outputs}
        )
        return datasets.DatasetDict(dataset_dict)

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset according to the prompt_spec and dataset."""
        # Step 1: ask the model for a transformation plan.
        plan_prompt = self.plan_prompt_fn(
            prompt_spec.instruction,
            dataset,
            prompt_spec.examples,
        )
        self.plan = make_single_api_request(plan_prompt)

        logger.info(f"Plan created. Plan: {self.plan}")

        inputs = []
        outputs = []

        required_keys = ["input", "output"]

        # Step 2: build one transform prompt per row, up to max_len rows.
        max_len = min(num_points_to_transform, len(dataset))
        len_count = 0
        transform_prompts = []
        for row in dataset:
            transform_prompt = self.transform_prompt_fn(
                prompt_spec.instruction,
                row,
                prompt_spec.examples,
                self.plan,
            )
            transform_prompts.append(transform_prompt)

            len_count += 1
            if len_count >= max_len:
                break

        async def generate_responses(transform_prompts):
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # Step 3: send the prompts as one batch. Default to no responses so
        # that a handled API error does not leave `responses` unbound.
        responses = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        # Step 4: keep only the responses that parse into the required keys.
        for response in responses:
            try:
                extraction = parse_json(response, required_keys, [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.make_dataset_from_samples(inputs, outputs)
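The parsing loop at the end of `transform_data` can be illustrated in isolation. The sketch below is not the PR's code: `extract_pairs` is a hypothetical helper, and the standard library's `json.loads` stands in for `parse_json`, whose exact behavior is not shown in this diff. It keeps only the responses that decode to an object with both required keys and skips the rest, which is one way the logged actual length can end up smaller than the requested length.

```python
import json
import logging

logger = logging.getLogger("DatasetTransformerSketch")

REQUIRED_KEYS = ("input", "output")


def extract_pairs(responses):
    """Return (inputs, outputs) built from the responses that parse cleanly."""
    inputs, outputs = [], []
    for response in responses:
        try:
            extraction = json.loads(response)
        except json.JSONDecodeError as e:
            logger.error("Error extracting from response: %s\nError: %s", response, e)
            continue
        # Skip anything that is not an object carrying every required key.
        if not isinstance(extraction, dict):
            continue
        if any(key not in extraction for key in REQUIRED_KEYS):
            continue
        # Values are stringified, mirroring the str(...) calls in the diff.
        inputs.append(str(extraction["input"]))
        outputs.append(str(extraction["output"]))
    return inputs, outputs


responses = [
    '{"input": "2 + 2", "output": 4}',
    "not json at all",
    '{"input": "missing output"}',
]
inputs, outputs = extract_pairs(responses)
```

If every response is skipped, both lists come back empty, and `make_dataset_from_samples` as written in the diff raises `ValueError` for empty inputs, so callers may want to handle that case explicitly.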
This is fine, but it suggests that this function is not well-typed...
I think I had to do it because certain ones had a `train` split that had to be accessed to get at the samples, whereas some had direct access to samples. Either way, I suggest we keep this as it is for now and resolve it in a future PR, since that is not the focus of this PR. I have tested that this works for now.
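The situation described in this reply (some retrieved datasets expose samples under a `train` split, others expose them directly) could be handled by a single normalizing helper. A minimal sketch, assuming plain-Python stand-ins: a dict keyed by split name plays the role of a `DatasetDict`, and a list of rows plays the role of a `Dataset`. `get_rows` is a hypothetical name and is not part of the PR.

```python
from collections.abc import Mapping


def get_rows(data, split="train"):
    """Return the rows of `data`, unwrapping the named split if there is one."""
    if isinstance(data, Mapping):
        if split not in data:
            raise KeyError(f"Split {split!r} not found; available: {sorted(data)}")
        return data[split]
    # Already a flat collection of rows: return it unchanged.
    return data


with_split = {"train": [{"text": "a"}, {"text": "b"}]}
direct = [{"text": "a"}, {"text": "b"}]
```

Funnelling both shapes through one function like this would let the downstream code accept a single, well-defined type, which is the concern raised in the review comment.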