Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data transformation capability to dataset retrieval step #385

Merged
merged 11 commits into from
Jan 15, 2024
2 changes: 2 additions & 0 deletions prompt2model/dataset_processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def _split_dataset_into_dataset_dict(
datasets.DatasetDict: A dictionary containing the `train`,
`val`, and `test` datasets.
"""
if "train" in dataset:
dataset = dataset["train"]
Comment on lines +134 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but it suggests that this function is not well-typed...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I had to do this because some datasets require accessing a `train` split to reach their samples, whereas others expose the samples directly. Either way, I suggest we keep this as-is for now and resolve it in a future PR, since it is not the focus of this one. I have tested that this works for now.

num_of_examples = len(dataset)
train_num = int(train_proportion * num_of_examples)
val_num = int(val_proportion * num_of_examples)
Expand Down
104 changes: 104 additions & 0 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from prompt2model.dataset_retriever.column_selection_prompt import (
construct_prompt_for_column_selection,
)
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects
from prompt2model.utils.dataset_utils import get_dataset_size
Expand Down Expand Up @@ -218,6 +219,92 @@ def canonicalize_dataset_using_columns(
)
return datasets.DatasetDict(dataset_dict)

    def canonicalize_dataset_auto(
        self, dataset_name: str, prompt_spec: PromptSpec, num_transform: int = 3000
    ) -> datasets.DatasetDict | None:
        """Canonicalize a dataset into a suitable text-to-text format.

        Loads the first config of `dataset_name`, flattens nested columns,
        selects input/output columns automatically, and transforms up to
        `num_transform` rows using `PromptBasedDatasetTransformer`.

        Args:
            dataset_name: The name of the dataset to canonicalize.
            prompt_spec: Prompt spec whose instruction and examples guide
                column selection and data transformation.
            num_transform: Maximum number of rows to transform.

        Returns:
            A canonicalized dataset, or None if the dataset lacks a `train`
            split or fails column selection / data transformation — meaning
            the caller should try another dataset instead.
        """
        configs = datasets.get_dataset_config_names(dataset_name)
        # NOTE(review): always picks the first available config — presumably a
        # simplification; confirm this is acceptable for multi-config datasets.
        chosen_config = configs[0]

        dataset = datasets.load_dataset(dataset_name, chosen_config).shuffle().flatten()

        if "train" not in dataset:
            logger.error(f"{dataset_name} must contain a `train` split.")
            return None

        columns_mapping: dict[str, str] = {}
        counter: dict[str, int] = {}
        # convert flattened columns like answer.text -> answer_text,
        # suffixing with a counter when two flattened names collide
        for col in dataset["train"].column_names:
            new_col = col.replace(".", "_")
            if new_col in columns_mapping.values():
                counter[new_col] = counter.get(new_col, 0) + 1
                new_col = f"{new_col}_{counter[new_col]}"
            columns_mapping[col] = new_col
        dataset = dataset.rename_columns(columns_mapping)

        train_columns = dataset["train"].column_names
        train_columns_formatted = ", ".join(train_columns)
        dataset_description = dataset["train"].info.description

        if len(dataset["train"]) == 0:
            raise ValueError("train split is empty.")

        example_rows = json.dumps(dataset["train"][0], indent=4)

        self._print_divider()
        print(f"Loaded dataset. Example rows:\n{example_rows}\n")
        logger.info(f"Loaded dataset. Example rows:\n{example_rows}\n")

        try:
            input_columns, output_column = self.automatic_column_selection(
                prompt_spec.instruction,
                dataset_name,
                dataset_description,
                train_columns_formatted,
                dataset["train"][0],
            )
        except RuntimeError:
            logger.error(f"{dataset_name} failed at column selection. Try another!")
            return None  # Returning None means that the dataset chosen didn't work,
            # and we would rather generate a dataset.

        # remove columns not selected by automatic column selection
        dataset = dataset.remove_columns(
            [
                col_name
                for col_name in train_columns
                if col_name not in input_columns + [output_column]
            ]
        )
        logger.info("Column selection completed")

        try:
            dataset_transformer = PromptBasedDatasetTransformer()
            canonicalized_dataset = dataset_transformer.transform_data(
                prompt_spec=prompt_spec,
                dataset=dataset["train"],
                num_transform=num_transform,
            )
        except RuntimeError:
            logger.error(f"{dataset_name} failed at data transformation. Try another!")
            return None
        logger.info("Data transformation completed")

        example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)
        self._print_divider()
        print(f"Transformed dataset. Example rows:\n{example_rows}\n")
        logger.info(f"Transformed dataset. Example rows:\n{example_rows}\n")

        return canonicalized_dataset

def canonicalize_dataset_by_cli(
self, dataset_name: str, prompt_spec
) -> datasets.DatasetDict:
Expand Down Expand Up @@ -341,16 +428,33 @@ def retrieve_top_datasets(
def retrieve_dataset_dict(
self,
prompt_spec: PromptSpec,
data_transform: bool = False,
num_transform: int = 3000,
) -> datasets.DatasetDict | None:
"""Select a dataset from a prompt using a dual-encoder retriever.

Args:
prompt_spec: A prompt whose instruction field we use to retrieve datasets.
data_transform: Whether to transform the dataset or not.
num_transform: Number to transform. ignored if data_transform is False.

Return:
A list of relevant datasets dictionaries.
"""
sorted_list = self.retrieve_top_datasets(prompt_spec)
if data_transform:
for dataset in sorted_list:
print(f"Trying {dataset.name}")
canonicalized_dataset = self.canonicalize_dataset_auto(
dataset.name, prompt_spec, num_transform
)
if canonicalized_dataset is not None:
print(f"{dataset.name} successful")
return canonicalized_dataset
print(f"{dataset.name} failed")

return None

top_dataset_name = self.choose_dataset_by_cli(sorted_list)
if top_dataset_name is None:
return None
Expand Down
8 changes: 8 additions & 0 deletions prompt2model/dataset_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Import DatasetTransformer classes."""
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
    "PromptBasedDatasetTransformer",
    "DatasetTransformer",
)
28 changes: 28 additions & 0 deletions prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""An interface for dataset transformation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


class DatasetTransformer(ABC):
    """A class for transforming a given dataset to the required format."""

    @abstractmethod
    def transform_data(
        self, prompt_spec: PromptSpec, dataset: datasets.Dataset, num_transform: int
    ) -> datasets.DatasetDict:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: A dataset split.
            num_transform: number of data points you wish to transform.

        Returns:
            A DatasetDict containing the transformed data.
            NOTE(review): annotated as DatasetDict to match the concrete
            implementation (PromptBasedDatasetTransformer), whose result is
            indexed as `result["train"]` by callers — confirm against any
            other implementations.
        """
123 changes: 123 additions & 0 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""A simple dataset transformer that uses a plan prompt and transform prompt."""

from __future__ import annotations

import asyncio

import datasets

from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
construct_prompt_for_transform_data,
)
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import (
API_ERRORS,
api_tools,
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_json_responses import (
extract_response,
make_request_from_prompt,
)

logger = get_formatted_logger("DatasetTransformer")


class PromptBasedDatasetTransformer(DatasetTransformer):
    """A class for transforming given dataset to required format using a plan and a transform prompt."""  # noqa E501

    def __init__(
        self,
        plan_prompt_fn=construct_prompt_for_plan,
        transform_prompt_fn=construct_prompt_for_transform_data,
    ):
        """Initialize the class.

        Args:
            plan_prompt_fn: Builds the prompt used to generate a
                transformation "plan" from the task description, dataset,
                and examples.
            transform_prompt_fn: Builds the per-row prompt used to transform
                a single dataset row according to the plan.
        """
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        # The generated plan is stored after transform_data() runs so callers
        # can inspect it.
        self.plan: str | None = None

    def canonicalize_dataset_using_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Canonicalize a dataset into a suitable text-to-text format.

        Args:
            inputs: Input strings, one per transformed row.
            outputs: Output strings, parallel to `inputs`.

        Returns:
            A DatasetDict with a single `train` split containing
            `input_col` and `output_col` columns.
        """
        assert len(inputs) > 0
        assert len(inputs) == len(outputs)

        dataset_dict = {}
        dataset_dict["train"] = datasets.Dataset.from_dict(
            {"input_col": inputs, "output_col": outputs}
        )
        return datasets.DatasetDict(dataset_dict)

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset into the required format according to the prompt_spec and dataset.""" # noqa E501
        # 1. Use the prompt_spec and an example row from the dataset to create a "plan" for the data transformation. # noqa E501
        plan_prompt = self.plan_prompt_fn(
            task_description=prompt_spec.instruction,
            dataset=dataset,
            example=prompt_spec.examples,
        )
        self.plan = make_request_from_prompt(plan_prompt)

        print(f"plan: {self.plan}")

        # 2. Use the prompt_spec and the plan to transform each row of the dataset into the required format. # noqa E501
        inputs = []
        outputs = []

        required_keys = ["input", "output"]

        max_len = min(num_transform, len(dataset))
        transform_prompts = []
        # Check the bound BEFORE building each prompt so num_transform == 0
        # yields no prompts (the old post-increment check built one anyway).
        for index, row in enumerate(dataset):
            if index >= max_len:
                break
            transform_prompts.append(
                self.transform_prompt_fn(
                    task_description=prompt_spec.instruction,
                    dataset_row=row,
                    example=prompt_spec.examples,
                    plan=self.plan,
                )
            )

        async def generate_responses(transform_prompts):
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # Pre-initialize so a handled API error leaves an empty list instead
        # of an unbound name (the extraction loop below would otherwise raise
        # NameError when handle_api_error does not re-raise).
        responses = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        for response in responses:
            try:
                extraction = extract_response(response, required_keys, [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        # Log the actual requested length (was a hard-coded "21,814").
        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.canonicalize_dataset_using_samples(inputs, outputs)
Loading