Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data transformation capability to dataset retrieval step #385

Merged
merged 11 commits into from
Jan 15, 2024
2 changes: 2 additions & 0 deletions prompt2model/dataset_processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def _split_dataset_into_dataset_dict(
datasets.DatasetDict: A dictionary containing the `train`,
`val`, and `test` datasets.
"""
if "train" in dataset:
dataset = dataset["train"]
Comment on lines +134 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but it suggests that this function is not well-typed...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I had to do it because certain datasets had a train split that had to be accessed to reach the samples, whereas others allowed direct access to the samples. Either way, I suggest we keep this as it is for now and resolve it in a future PR, since that is not the focus of this PR. I have tested that this works for now.

num_of_examples = len(dataset)
train_num = int(train_proportion * num_of_examples)
val_num = int(val_proportion * num_of_examples)
Expand Down
110 changes: 83 additions & 27 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations # noqa FI58

import json
import logging
import os
import random
import urllib.request
Expand All @@ -18,13 +17,14 @@
from prompt2model.dataset_retriever.reranking_prompt import (
construct_prompt_for_dataset_reranking,
)
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects
from prompt2model.utils import encode_text, get_formatted_logger, retrieve_objects
from prompt2model.utils.dataset_utils import get_dataset_size
from prompt2model.utils.parse_responses import parse_prompt_to_fields

datasets.utils.logging.disable_progress_bar()
logger = logging.getLogger(__name__)
logger = get_formatted_logger("DescriptionDatasetRetriever")


class DescriptionDatasetRetriever(DatasetRetriever):
Expand Down Expand Up @@ -206,6 +206,7 @@ def get_all_dataset_infos(self, dataset_list: list[str]) -> dict:
The keys are dataset names and the values are dictionaries
with dataset information.
"""
dataset_info_dict = {}
for dataset_name in dataset_list:
if dataset_name not in self.reranking_datasets_infos:
continue
Expand All @@ -219,11 +220,8 @@ def get_all_dataset_infos(self, dataset_list: list[str]) -> dict:
curr_dataset["configs"] = dict(
random.sample(list(curr_dataset["configs"].items()), 5)
)
dataset_info_dict = {
dataset_name: self.reranking_datasets_infos[dataset_name]
for dataset_name in dataset_list
if dataset_name in self.reranking_datasets_infos
}
dataset_info_dict[dataset_name] = curr_dataset

return dataset_info_dict

@staticmethod
Expand Down Expand Up @@ -312,11 +310,12 @@ def canonicalize_dataset_by_cli(
)
self._print_divider()

dataset_info = self.get_all_dataset_infos([dataset_name])[dataset_name]
dataset_info = self.get_all_dataset_infos([dataset_name])
if len(dataset_info.keys()) == 0:
return None
dataset_info = dataset_info[dataset_name]["configs"][chosen_config]
if dataset_info is None:
return None
dataset_info = dataset_info["configs"][chosen_config]
assert dataset_info is not None
try:
input_columns, output_column = self.automatic_column_selection(
prompt_spec.instruction,
Expand Down Expand Up @@ -437,7 +436,11 @@ def rerank_datasets(self, dataset_list: list[str], prompt_spec: PromptSpec):
return dataset_info_dict[dataset_name]["configs"][config_name]

def canonicalize_dataset_automatically(
self, top_dataset_info: dict, task_instruction: str
self,
top_dataset_info: dict,
prompt_spec: PromptSpec,
auto_transform_data: bool = False,
num_points_to_transform: int = 10,
):
"""Automatically canonicalize dataset (instead of cli).

Expand All @@ -446,18 +449,29 @@ def canonicalize_dataset_automatically(
the top dataset information exists. If so, it proceeds to automatically
select the input and output columns based on the task instruction. The
dataset is then loaded, flattened, and renamed according to the columns
mapping. Finally, the dataset is canonicalized using the selected columns.
mapping. If auto_transform_data is true, num_points_to_transform points
from the dataset are transformed by an LLM to desired format according
to the prompt_spec, and transformed dataset is returned. If
auto_transform_data is false, the dataset is canonicalized using the
selected columns.

Args:
top_dataset_info: Contains info about the top-ranked dataset.
task_instruction: A string representing the instruction for the task,
used to guide column selection.
prompt_spec: prompt object storing the original task and examples.
auto_transform_data: Specifies whether a dataset is to be
transformed. Samples from the original dataset will be transformed
by an LLM to match a desired format as specified by prompt_spec.
num_points_to_transform: Number of data points you wish to
transform. Number must be greater than zero. If number is greater
than size of dataset, whole dataset will be transformed. ignored
if data_transform is False.

Returns:
The canonicalized dataset, or None if the dataset is invalid or
if column selection fails, or if any other error occurs
during the process.
"""
task_instruction = prompt_spec.instruction
if top_dataset_info is None:
logger.warning("None of the retrieved datasets were relevant.")
return None
Expand All @@ -472,34 +486,76 @@ def canonicalize_dataset_automatically(
except Exception as e:
logger.warning("Column selection failed: ", e)
return None
full_dataset = datasets.load_dataset(
top_dataset_info["dataset_name"], top_dataset_info["config_name"]
).flatten()
full_dataset = full_dataset.rename_columns(top_dataset_info["columns_mapping"])
canonicalized_dataset = self.canonicalize_dataset_using_columns(
full_dataset, input_columns, output_column
logger.info("Column selection completed")
full_dataset = (
datasets.load_dataset(
top_dataset_info["dataset_name"], top_dataset_info["config_name"]
)
.shuffle()
.flatten()
)
logger.info(f"Using dataset {top_dataset_info['dataset_name']}")
full_dataset = full_dataset.rename_columns(top_dataset_info["columns_mapping"])
logger.info("Dataset loaded")

if auto_transform_data:
# remove columns not selected by automatic column selection
full_dataset = full_dataset.remove_columns(
[
col_name
for col_name in full_dataset["train"].column_names
if col_name not in input_columns + [output_column]
]
)
logger.info("Unnecessary columns removed")

return canonicalized_dataset
dataset_transformer = PromptBasedDatasetTransformer()
canonicalized_dataset = dataset_transformer.transform_data(
prompt_spec=prompt_spec,
dataset=full_dataset["train"],
num_points_to_transform=num_points_to_transform,
)
logger.info("Data transformation completed")

example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)

logger.info(f"Transformed dataset. Example row:\n{example_rows}\n")

return canonicalized_dataset
else:
canonicalized_dataset = self.canonicalize_dataset_using_columns(
full_dataset, input_columns, output_column
)
logger.info(
f"No transformation. Using dataset {top_dataset_info['dataset_name']}"
) # noqa E501
return canonicalized_dataset

def retrieve_dataset_dict(
    self,
    prompt_spec: PromptSpec,
    auto_transform_data: bool = False,
    num_points_to_transform: int = 10,
) -> datasets.DatasetDict | None:
    """Select a dataset from a prompt using a dual-encoder retriever.

    Args:
        prompt_spec: prompt object storing the original task and examples.
        auto_transform_data: Specifies whether a dataset is to be
            transformed. Samples from the original dataset will be
            transformed by an LLM to match a desired format as specified
            by prompt_spec.
        num_points_to_transform: Number of data points you wish to
            transform. Number must be greater than zero. If number is
            greater than the size of the dataset, the whole dataset will
            be transformed. Ignored if auto_transform_data is False.

    Return:
        The most relevant dataset, canonicalized;
        or None if there are no relevant datasets.
    """
    sorted_list = self.retrieve_top_datasets(prompt_spec)
    logger.info(f"Top datasets retrieved. Top datasets: {sorted_list}")
    top_dataset_info = self.rerank_datasets(sorted_list, prompt_spec)
    # Report progress via the module logger (a stray debug print() was
    # removed here so verbosity is controlled in one place).
    logger.info(f"Rerank completed. Top dataset info: {top_dataset_info}")
    return self.canonicalize_dataset_automatically(
        top_dataset_info, prompt_spec, auto_transform_data, num_points_to_transform
    )
8 changes: 8 additions & 0 deletions prompt2model/dataset_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Import DatasetGenerator classes."""
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
"PromptBasedDatasetTransformer",
"DatasetTransformer",
)
34 changes: 34 additions & 0 deletions prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""An interface for dataset transformation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


class DatasetTransformer(ABC):
    """An interface for transforming a given dataset to a desired format.

    Concrete subclasses (e.g. PromptBasedDatasetTransformer) decide how
    individual data points are converted into the target input/output form.
    """

    @abstractmethod
    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.Dataset:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description)
                that describes the desired format of the transformed data.
            dataset: A dataset split whose rows are to be transformed.
            num_points_to_transform: Number of data points you wish to
                transform. Must be greater than zero. If it is greater
                than the size of the dataset, the whole dataset will be
                transformed.

        Returns:
            A single dataset split containing the transformed data points.
        """
122 changes: 122 additions & 0 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""A simple dataset transformer that uses a plan prompt and transform prompt."""
from __future__ import annotations

import asyncio
from collections.abc import Callable

import datasets

from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
construct_prompt_for_transform_data,
)
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import (
API_ERRORS,
api_tools,
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_responses import make_single_api_request, parse_json

logger = get_formatted_logger("DatasetTransformer")


class PromptBasedDatasetTransformer(DatasetTransformer):
    """Transform data based on an LLM-generated plan and transform prompt."""

    def __init__(
        self,
        plan_prompt_fn: Callable[
            [str, list[dict], str], str
        ] = construct_prompt_for_plan,
        transform_prompt_fn: Callable[
            [str, dict, str, str], str
        ] = construct_prompt_for_transform_data,
    ):
        """Initialize the class.

        Args:
            plan_prompt_fn: Builds the prompt that asks the LLM for a
                transformation plan.
            transform_prompt_fn: Builds the per-row prompt that asks the
                LLM to transform a single data point.
        """
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        # Most recent plan returned by the LLM; populated by transform_data.
        self.plan: str = ""

    def make_dataset_from_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Given parallel lists of inputs and outputs, make a dataset.

        Args:
            inputs: Transformed input strings.
            outputs: Transformed output strings, aligned with ``inputs``.

        Returns:
            A DatasetDict with a single "train" split containing
            "input_col" and "output_col" columns.

        Raises:
            ValueError: If the lists are empty or of unequal length.
        """
        if len(inputs) <= 0 or len(inputs) != len(outputs):
            raise ValueError("Length of inputs and outputs must be >0 and equal.")

        dataset_dict = {}
        dataset_dict["train"] = datasets.Dataset.from_dict(
            {"input_col": inputs, "output_col": outputs}
        )
        return datasets.DatasetDict(dataset_dict)

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset according to the prompt_spec and dataset.

        First asks the LLM for a transformation plan, then transforms up to
        num_points_to_transform rows in a batched, rate-limited API call and
        parses each response for "input"/"output" fields.

        Args:
            prompt_spec: A prompt spec (containing a system description)
                describing the desired format of the transformed data.
            dataset: A dataset split whose rows are to be transformed.
            num_points_to_transform: Maximum number of rows to transform;
                capped at the dataset size.

        Returns:
            A DatasetDict with one "train" split of transformed rows.
        """
        plan_prompt = self.plan_prompt_fn(
            prompt_spec.instruction,
            dataset,
            prompt_spec.examples,
        )
        self.plan = make_single_api_request(plan_prompt)

        logger.info(f"Plan created. Plan: {self.plan}")

        # Build one transform prompt per row, up to the requested cap.
        max_len = min(num_points_to_transform, len(dataset))
        transform_prompts = []
        for index, row in enumerate(dataset):
            if index >= max_len:
                break
            transform_prompts.append(
                self.transform_prompt_fn(
                    prompt_spec.instruction,
                    row,
                    prompt_spec.examples,
                    self.plan,
                )
            )

        async def generate_responses(transform_prompts):
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # Bug fix: `responses` was previously unbound when an API error was
        # caught and handle_api_error did not re-raise, causing a NameError
        # in the parsing loop below.
        responses = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        inputs = []
        outputs = []
        required_keys = ["input", "output"]
        for response in responses:
            try:
                extraction = parse_json(response, required_keys, [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        # Bug fix: the requested length was previously hard-coded as
        # "27,950" (a debug leftover); report the actual requested count.
        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.make_dataset_from_samples(inputs, outputs)
Loading