Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data transformation capability to dataset retrieval step #385

Merged
merged 11 commits into from
Jan 15, 2024
2 changes: 2 additions & 0 deletions prompt2model/dataset_processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def _split_dataset_into_dataset_dict(
datasets.DatasetDict: A dictionary containing the `train`,
`val`, and `test` datasets.
"""
if "train" in dataset:
dataset = dataset["train"]
Comment on lines +134 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but it suggests that this function is not well-typed...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I had to do it because certain datasets have a `train` split that must be accessed to reach the samples, whereas others expose the samples directly. Either way, I suggest we keep this as it is for now and resolve it in a future PR, since that is not the focus of this PR. I have tested that this works for now.

num_of_examples = len(dataset)
train_num = int(train_proportion * num_of_examples)
val_num = int(val_proportion * num_of_examples)
Expand Down
101 changes: 101 additions & 0 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from prompt2model.dataset_retriever.column_selection_prompt import (
construct_prompt_for_column_selection,
)
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects
from prompt2model.utils.dataset_utils import get_dataset_size
Expand Down Expand Up @@ -218,6 +219,84 @@ def canonicalize_dataset_using_columns(
)
return datasets.DatasetDict(dataset_dict)

def canonicalize_dataset_auto(
    self, dataset_name: str, prompt_spec: PromptSpec, num_transform: int = 3000
) -> datasets.DatasetDict:
    """Canonicalize a dataset into a suitable text-to-text format.

    Loads the first config of ``dataset_name``, flattens nested columns,
    automatically selects input/output columns, and transforms up to
    ``num_transform`` rows using a prompt-based dataset transformer.

    Args:
        dataset_name: The name of the dataset to canonicalize.
        prompt_spec: A prompt whose instruction field we use to transform
            the dataset.
        num_transform: Maximum number of rows to transform.

    Returns:
        A canonicalized dataset.

    Raises:
        ValueError: If the dataset lacks a `train` split or the split is
            empty.
    """
    configs = datasets.get_dataset_config_names(dataset_name)
    # NOTE(review): the first config is picked arbitrarily — confirm this is
    # acceptable for multi-config datasets.
    chosen_config = configs[0]

    dataset = datasets.load_dataset(dataset_name, chosen_config).shuffle().flatten()

    if "train" not in dataset:
        # Bug fix: the original message was missing the f-string prefix, so
        # the literal text "{dataset_name}" was raised instead of the name.
        raise ValueError(f"{dataset_name} must contain a `train` split.")

    columns_mapping: dict[str, str] = {}
    counter: dict[str, int] = {}
    # Flattening produces dotted column names (e.g. answer.text); convert
    # them to underscore-separated names, disambiguating collisions with a
    # numeric suffix.
    for col in dataset["train"].column_names:
        new_col = col.replace(".", "_")
        if new_col in columns_mapping.values():
            counter[new_col] = counter.get(new_col, 0) + 1
            new_col = f"{new_col}_{counter[new_col]}"
        columns_mapping[col] = new_col
    dataset = dataset.rename_columns(columns_mapping)

    train_columns = dataset["train"].column_names
    train_columns_formatted = ", ".join(train_columns)
    dataset_description = dataset["train"].info.description

    if len(dataset["train"]) == 0:
        raise ValueError("train split is empty.")

    example_rows = json.dumps(dataset["train"][0], indent=4)

    self._print_divider()
    print(f"Loaded dataset. Example rows:\n{example_rows}\n")
    logger.info(f"Loaded dataset. Example rows:\n{example_rows}\n")

    input_columns, output_column = self.automatic_column_selection(
        prompt_spec.instruction,
        dataset_name,
        dataset_description,
        train_columns_formatted,
        dataset["train"][0],
    )

    # Drop every column that automatic column selection did not pick.
    dataset = dataset.remove_columns(
        [
            col_name
            for col_name in train_columns
            if col_name not in input_columns + [output_column]
        ]
    )
    logger.info("Column selection completed")

    dataset_transformer = PromptBasedDatasetTransformer()
    canonicalized_dataset = dataset_transformer.transform_data(
        prompt_spec=prompt_spec,
        dataset=dataset["train"],
        num_transform=num_transform,
    )
    logger.info("Data transformation completed")

    example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)
    self._print_divider()
    print(f"Transformed dataset. Example rows:\n{example_rows}\n")
    logger.info(f"Transformed dataset. Example rows:\n{example_rows}\n")

    return canonicalized_dataset

def canonicalize_dataset_by_cli(
self, dataset_name: str, prompt_spec
) -> datasets.DatasetDict:
Expand Down Expand Up @@ -341,16 +420,38 @@ def retrieve_top_datasets(
def retrieve_dataset_dict(
self,
prompt_spec: PromptSpec,
data_transform: bool = False,
num_transform: int = 3000,
) -> datasets.DatasetDict | None:
"""Select a dataset from a prompt using a dual-encoder retriever.

Args:
prompt_spec: A prompt whose instruction field we use to retrieve datasets.
data_transform: Whether to transform the dataset or not.
num_transform: Number to transform. ignored if data_transform is False.

Return:
A list of relevant datasets dictionaries.
"""
sorted_list = self.retrieve_top_datasets(prompt_spec)
if data_transform:
for dataset in sorted_list:
print(f"Trying {dataset.name}")
try:
canonicalized_dataset = self.canonicalize_dataset_auto(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that now we're conflating two concepts:

  • canonicalization (which I thought of as just column selection)
  • transformation (which may make major changes to the dataset's contents, thus not really making it more canonical, according to the definition of that word)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have merged so that reranking happens separately, and column selection plus transformation (if the user specifies it) happen separately. We can discuss whether any renaming is required in your re-review.

dataset.name, prompt_spec, num_transform
)
except Exception as e:
print(f"{dataset.name} failed")
logger.error(f"{dataset.name} failed with {e}")
continue

if canonicalized_dataset is not None:
print(f"{dataset.name} successful")
return canonicalized_dataset

return None

top_dataset_name = self.choose_dataset_by_cli(sorted_list)
if top_dataset_name is None:
return None
Expand Down
8 changes: 8 additions & 0 deletions prompt2model/dataset_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Import DatasetGenerator classes."""
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
"PromptBasedDatasetTransformer",
"DatasetTransformer",
)
28 changes: 28 additions & 0 deletions prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""An interface for dataset transformation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


class DatasetTransformer(ABC):
    """A class for transforming a given dataset to the required format."""

    @abstractmethod
    def transform_data(
        self, prompt_spec: PromptSpec, dataset: datasets.Dataset, num_transform: int
    ) -> datasets.DatasetDict:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: A dataset split.
            num_transform: Number of data points you wish to transform.

        Returns:
            A `datasets.DatasetDict` containing the transformed data.
            (The implementation, `PromptBasedDatasetTransformer`, returns a
            `DatasetDict` with a `train` split, and callers index the result
            with `["train"]` — the previous `datasets.Dataset` annotation
            was inaccurate.)
        """
125 changes: 125 additions & 0 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""A simple dataset transformer that uses a plan prompt and transform prompt."""
from __future__ import annotations

import asyncio
from collections.abc import Callable

import datasets

from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
construct_prompt_for_transform_data,
)
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import (
API_ERRORS,
api_tools,
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_json_responses import (
extract_response,
make_request_from_prompt,
)

logger = get_formatted_logger("DatasetTransformer")


class PromptBasedDatasetTransformer(DatasetTransformer):
    """Transform data based on a transform prompt."""

    def __init__(
        self,
        plan_prompt_fn: Callable[
            [str, list[dict], str], str
        ] = construct_prompt_for_plan,
        transform_prompt_fn: Callable[
            [str, dict, str, str], str
        ] = construct_prompt_for_transform_data,
    ):
        """Initialize the transformer.

        Args:
            plan_prompt_fn: Builds the prompt used to generate a
                transformation plan from the instruction, dataset, and
                examples.
            transform_prompt_fn: Builds the per-row prompt used to
                transform a single dataset row.
        """
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        # The generated plan is stored after transform_data runs so callers
        # can inspect it.
        self.plan: str = ""

    def canonicalize_dataset_using_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Canonicalize a dataset into a suitable text-to-text format.

        Args:
            inputs: Transformed input strings.
            outputs: Transformed output strings, parallel to `inputs`.

        Returns:
            A `DatasetDict` with a single `train` split holding
            `input_col`/`output_col` columns.

        Raises:
            ValueError: If the lists are empty or of unequal length.
        """
        if len(inputs) <= 0 or len(inputs) != len(outputs):
            raise ValueError("Length of inputs and outputs must be >0 and equal.")

        dataset_dict = {}
        dataset_dict["train"] = datasets.Dataset.from_dict(
            {"input_col": inputs, "output_col": outputs}
        )
        return datasets.DatasetDict(dataset_dict)

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset according to the prompt_spec and dataset.

        Args:
            prompt_spec: A prompt spec (containing a system description).
            dataset: A dataset split.
            num_transform: Number of data points you wish to transform.

        Returns:
            A `DatasetDict` with a `train` split of transformed examples.
        """
        # Step 1: ask the model for a transformation plan, then reuse that
        # plan in every per-row transform prompt.
        plan_prompt = self.plan_prompt_fn(
            prompt_spec.instruction,
            dataset,
            prompt_spec.examples,
        )
        self.plan = make_request_from_prompt(plan_prompt)

        print(f"plan: {self.plan}")

        inputs: list[str] = []
        outputs: list[str] = []

        required_keys = ["input", "output"]

        # Step 2: build one transform prompt per row, up to num_transform rows.
        max_len = min(num_transform, len(dataset))
        len_count = 0
        transform_prompts = []
        for row in dataset:
            transform_prompt = self.transform_prompt_fn(
                prompt_spec.instruction,
                row,
                prompt_spec.examples,
                self.plan,
            )
            transform_prompts.append(transform_prompt)

            len_count += 1
            if len_count >= max_len:
                break

        async def generate_responses(transform_prompts):
            # Fan the prompts out to the API agent as one rate-limited batch.
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # Bug fix: `responses` was previously assigned only inside the try
        # block, so an API error caused a NameError in the loop below.
        # Default to an empty list so a failed batch yields no examples
        # (handle_api_error may still log or re-raise).
        responses: list = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        # Step 3: parse each response, skipping rows that fail extraction.
        for response in responses:
            try:
                extraction = extract_response(response, required_keys, [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        # Bug fix: the requested length was hard-coded as "22,608" (a debug
        # leftover); log the actual number of rows requested instead.
        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.canonicalize_dataset_using_samples(inputs, outputs)
Loading