-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add data transformation capability to dataset retrieval step #385
Changes from 6 commits
1a2c403
2aaaf74
0a39035
b69a2f3
054f173
c3516a1
db6a2de
c11fb09
b811479
9b8f7dd
8f8a589
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
from prompt2model.dataset_retriever.column_selection_prompt import ( | ||
construct_prompt_for_column_selection, | ||
) | ||
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer | ||
from prompt2model.prompt_parser import PromptSpec | ||
from prompt2model.utils import encode_text, retrieve_objects | ||
from prompt2model.utils.dataset_utils import get_dataset_size | ||
|
@@ -218,6 +219,84 @@ def canonicalize_dataset_using_columns( | |
) | ||
return datasets.DatasetDict(dataset_dict) | ||
|
||
def canonicalize_dataset_auto(
    self, dataset_name: str, prompt_spec: PromptSpec, num_transform: int = 3000
) -> datasets.DatasetDict:
    """Canonicalize a dataset into a suitable text-to-text format.

    Loads the first config of the dataset, flattens nested columns,
    automatically selects input/output columns, and transforms the rows
    into (input, output) pairs with a PromptBasedDatasetTransformer.

    Args:
        dataset_name: The name of the dataset to canonicalize.
        prompt_spec: A prompt whose instruction field we use to transform datasets.
        num_transform: Maximum number of rows to transform.

    Returns:
        A canonicalized dataset.

    Raises:
        ValueError: If the dataset has no `train` split, or the split is empty.
    """
    configs = datasets.get_dataset_config_names(dataset_name)
    # NOTE(review): always picks the first config — presumably acceptable
    # for retrieval; confirm if config choice should be smarter.
    chosen_config = configs[0]

    dataset = datasets.load_dataset(dataset_name, chosen_config).shuffle().flatten()

    if "train" not in dataset:
        # BUG FIX: the original string lacked the f-prefix, so the literal
        # text "{dataset_name}" was shown instead of the dataset's name.
        raise ValueError(f"{dataset_name} must contain a `train` split.")

    columns_mapping: dict[str, str] = {}
    counter: dict[str, int] = {}
    # Convert flattened columns like `answer.text` -> `answer_text`,
    # disambiguating collisions with a numeric suffix.
    for col in dataset["train"].column_names:
        new_col = col.replace(".", "_")
        if new_col in columns_mapping.values():
            counter[new_col] = counter.get(new_col, 0) + 1
            new_col = f"{new_col}_{counter[new_col]}"
        columns_mapping[col] = new_col
    dataset = dataset.rename_columns(columns_mapping)

    train_columns = dataset["train"].column_names
    train_columns_formatted = ", ".join(train_columns)
    dataset_description = dataset["train"].info.description

    if len(dataset["train"]) == 0:
        raise ValueError("train split is empty.")

    example_rows = json.dumps(dataset["train"][0], indent=4)

    self._print_divider()
    print(f"Loaded dataset. Example rows:\n{example_rows}\n")
    logger.info(f"Loaded dataset. Example rows:\n{example_rows}\n")

    input_columns, output_column = self.automatic_column_selection(
        prompt_spec.instruction,
        dataset_name,
        dataset_description,
        train_columns_formatted,
        dataset["train"][0],
    )

    # Remove columns not selected by automatic column selection.
    dataset = dataset.remove_columns(
        [
            col_name
            for col_name in train_columns
            if col_name not in input_columns + [output_column]
        ]
    )
    logger.info("Column selection completed")

    dataset_transformer = PromptBasedDatasetTransformer()
    canonicalized_dataset = dataset_transformer.transform_data(
        prompt_spec=prompt_spec,
        dataset=dataset["train"],
        num_transform=num_transform,
    )
    logger.info("Data transformation completed")

    example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)
    self._print_divider()
    print(f"Transformed dataset. Example rows:\n{example_rows}\n")
    logger.info(f"Transformed dataset. Example rows:\n{example_rows}\n")

    return canonicalized_dataset
|
||
def canonicalize_dataset_by_cli( | ||
self, dataset_name: str, prompt_spec | ||
) -> datasets.DatasetDict: | ||
|
@@ -341,16 +420,38 @@ def retrieve_top_datasets( | |
def retrieve_dataset_dict( | ||
self, | ||
prompt_spec: PromptSpec, | ||
data_transform: bool = False, | ||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_transform: int = 3000, | ||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> datasets.DatasetDict | None: | ||
"""Select a dataset from a prompt using a dual-encoder retriever. | ||
|
||
Args: | ||
prompt_spec: A prompt whose instruction field we use to retrieve datasets. | ||
data_transform: Whether to transform the dataset or not. | ||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_transform: Number to transform. ignored if data_transform is False. | ||
|
||
Return: | ||
A list of relevant datasets dictionaries. | ||
""" | ||
sorted_list = self.retrieve_top_datasets(prompt_spec) | ||
if data_transform: | ||
for dataset in sorted_list: | ||
print(f"Trying {dataset.name}") | ||
try: | ||
canonicalized_dataset = self.canonicalize_dataset_auto( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that now we're conflating two concepts:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have merged such that reranking happens separately and colsel+ (if user specifies) transformation separately. we can discuss if any renaming is required in your re-review. |
||
dataset.name, prompt_spec, num_transform | ||
) | ||
except Exception as e: | ||
print(f"{dataset.name} failed") | ||
logger.error(f"{dataset.name} failed with {e}") | ||
continue | ||
|
||
if canonicalized_dataset is not None: | ||
print(f"{dataset.name} successful") | ||
return canonicalized_dataset | ||
|
||
return None | ||
|
||
top_dataset_name = self.choose_dataset_by_cli(sorted_list) | ||
if top_dataset_name is None: | ||
return None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Import DatasetTransformer classes."""
# BUG FIX: the module docstring previously said "DatasetGenerator classes",
# a copy-paste error — this package exports dataset *transformer* classes.
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
    "PromptBasedDatasetTransformer",
    "DatasetTransformer",
)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""An interface for dataset transformation.""" | ||
|
||
from __future__ import annotations # noqa FI58 | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
import datasets | ||
|
||
from prompt2model.prompt_parser import PromptSpec | ||
|
||
|
||
class DatasetTransformer(ABC):
    """An abstract interface for transforming a dataset into a required text-to-text format."""

    @abstractmethod
    def transform_data(
        self, prompt_spec: PromptSpec, dataset: datasets.Dataset, num_transform: int
    ) -> datasets.DatasetDict:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: A dataset split.
            num_transform: Number of data points you wish to transform.

        Returns:
            A DatasetDict holding the transformed data.
            (BUG FIX: the return annotation previously said `datasets.Dataset`,
            but the concrete implementation, PromptBasedDatasetTransformer,
            is annotated with and returns a `datasets.DatasetDict`.)
        """
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""A simple dataset transformer that uses a plan prompt and transform prompt.""" | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from collections.abc import Callable | ||
|
||
import datasets | ||
|
||
from prompt2model.dataset_transformer.base import DatasetTransformer | ||
from prompt2model.dataset_transformer.prompt_template import ( | ||
construct_prompt_for_plan, | ||
construct_prompt_for_transform_data, | ||
) | ||
from prompt2model.prompt_parser import PromptSpec | ||
from prompt2model.utils import ( | ||
API_ERRORS, | ||
api_tools, | ||
get_formatted_logger, | ||
handle_api_error, | ||
) | ||
from prompt2model.utils.parse_json_responses import ( | ||
extract_response, | ||
make_request_from_prompt, | ||
) | ||
|
||
logger = get_formatted_logger("DatasetTransformer") | ||
|
||
|
||
class PromptBasedDatasetTransformer(DatasetTransformer):
    """Transform data based on a transform prompt."""

    def __init__(
        self,
        plan_prompt_fn: Callable[
            [str, list[dict], str], str
        ] = construct_prompt_for_plan,
        transform_prompt_fn: Callable[
            [str, dict, str, str], str
        ] = construct_prompt_for_transform_data,
    ):
        """Initialize the class.

        Args:
            plan_prompt_fn: Builds the prompt used to generate a transform plan.
            transform_prompt_fn: Builds the per-row prompt used to transform data.
        """
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        # Most recent plan produced by transform_data; empty until then.
        self.plan: str = ""

    def canonicalize_dataset_using_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Canonicalize paired samples into a text-to-text DatasetDict.

        Args:
            inputs: Input strings; must be non-empty.
            outputs: Output strings; must align one-to-one with `inputs`.

        Returns:
            A DatasetDict with a single "train" split containing
            "input_col" and "output_col" columns.

        Raises:
            ValueError: If `inputs` is empty or the lengths differ.
        """
        if not inputs or len(inputs) != len(outputs):
            raise ValueError("Length of inputs and outputs must be >0 and equal.")
        return datasets.DatasetDict(
            {
                "train": datasets.Dataset.from_dict(
                    {"input_col": inputs, "output_col": outputs}
                )
            }
        )

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset according to the prompt_spec and dataset.

        First asks the API for a transformation plan, then issues one
        transform request per row (up to `num_transform` rows) and collects
        the extracted (input, output) pairs.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: The dataset split to transform.
            num_transform: Maximum number of rows to transform.

        Returns:
            A DatasetDict with a "train" split of transformed pairs.
        """
        plan_prompt = self.plan_prompt_fn(
            prompt_spec.instruction,
            dataset,
            prompt_spec.examples,
        )
        self.plan = make_request_from_prompt(plan_prompt)

        print(f"plan: {self.plan}")

        # Build one transform prompt per row, capped at num_transform rows.
        # (Also fixes an off-by-one: the original appended a row before
        # checking the cap, so num_transform=0 still transformed one row.)
        max_len = min(num_transform, len(dataset))
        transform_prompts = []
        for row_index, row in enumerate(dataset):
            if row_index >= max_len:
                break
            transform_prompts.append(
                self.transform_prompt_fn(
                    prompt_spec.instruction,
                    row,
                    prompt_spec.examples,
                    self.plan,
                )
            )

        async def generate_responses(transform_prompts):
            # Batch the completions to respect the API rate limit.
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # BUG FIX: initialize `responses` so a caught API error below does not
        # leave it unbound, which previously raised NameError in the loop.
        responses: list = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        inputs: list[str] = []
        outputs: list[str] = []
        required_keys = ["input", "output"]
        for response in responses:
            try:
                extraction = extract_response(response, required_keys, [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        # BUG FIX: the requested length was hard-coded as "25,740" (leftover
        # from a debugging run); report the actual requested count instead.
        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.canonicalize_dataset_using_samples(inputs, outputs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine, but it suggests that this function is not well-typed...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I had to do it because certain ones had a
train
split that had to be accessed to access samples, whereas some had direct access to samples. either way, I suggest we keep this as it is for now and resolve in a future PR, since that is not the focus for the PR. have tested that this works for now