Dataset Reranking #383

Merged: 24 commits, Jan 12, 2024

Large diffs are not rendered by default.

17 changes: 1 addition & 16 deletions prompt2model/dataset_retriever/column_selection_prompt.py
@@ -2,8 +2,6 @@

 from __future__ import annotations  # noqa FI58

-import json
-
 METAPROMPT_BASE = """Your objective is to carefully analyze the task and the dataset mentioned, and decide whether the columns are relevant input, relevant output, irrelevant for the given task, or if it is ambiguous. There should be at most one output column. It is possible to have no relevant columns, in which case return the input and output column as empty lists. Answer in a json format, with the following keys: input, output, irrelevant, ambiguous"""  # noqa: E501
 METAPROMPT_EXAMPLES = [
     (
@@ -90,19 +88,6 @@
 ENDING_LINE = "After seeing these examples with the required columns, please provide the relevant columns for this context:"  # noqa: E501


-def truncate_row(example_row: dict, max_length=50) -> str:
-    """Truncate the row before displaying if it is too long."""
-    truncated_row = {}
-    for key in example_row.keys():
-        curr_row = json.dumps(example_row[key])
-        truncated_row[key] = (
-            curr_row
-            if len(curr_row) <= max_length - 3
-            else curr_row[:max_length] + "..."
-        )
-    return json.dumps(truncated_row)
-
-
 def build_input(
     instruction: str,
     dataset_name: str,
@@ -116,7 +101,7 @@ def build_input(
         dataset_name=dataset_name,
         dataset_description=dataset_description,
         dataset_columns=dataset_columns,
-        sample_row=truncate_row(sample_row),
+        sample_row=sample_row,
     )
     input_prompt = SINGLE_DEMONSTRATION_TEMPLATE.format(
         prompt=input_prompt, columns=""
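With this hunk, `build_input` no longer truncates the sample row itself: `truncate_row` (and the `json` import it needed) is deleted here and reappears in the dataset-index generation script further down, so the `sample_row` read from the index is expected to be an already-truncated JSON string. A minimal usage sketch, assuming the keyword arguments visible in this diff (the full signature of `build_input` is not shown here, and the dataset values are made up):

```python
# Sketch only: argument names are inferred from the diff; the dataset is hypothetical.
from prompt2model.dataset_retriever.column_selection_prompt import build_input

sample_row = '{"question": "What is the capital of France?", "answer": "Paris"}'  # pre-truncated upstream

prompt = build_input(
    instruction="Answer the question using the given context.",
    dataset_name="hypothetical/qa-dataset",
    dataset_description="A toy question-answering dataset.",
    dataset_columns="question, answer",
    sample_row=sample_row,  # passed through as-is; no truncate_row() call anymore
)
```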

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions prompt2model/dataset_retriever/dataset_index_file/preprocessing.py
@@ -0,0 +1,69 @@
"""Filtering out datasets before we do heavy processing on them."""
import json
from typing import Any

from huggingface_hub import list_datasets

# Constants
ALL_DATASETS_FILE = "all_datasets.json"
FILTERED_DATASETS_FILE = "filtered_datasets.json"
MIN_WORDS_IN_DESC = 4
MIN_DOWNLOADS = 10


def load_datasets(file_path: str, is_first_time=False) -> list[dict[str, Any]]:
    """Load dataset metadata from a JSON file, or fetch it from the Hub on the first run."""
    if is_first_time:
        # Fetch the full dataset list from HuggingFace Hub and cache it locally.
        all_datasets = list(list_datasets())
        with open(ALL_DATASETS_FILE, "w") as f:
            ds = json.dumps([ob.__dict__ for ob in all_datasets])
            f.write(ds)
        return all_datasets
    else:
        with open(file_path, "r") as file:
            return json.load(file)


def filter_datasets(datasets: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Filter datasets based on specific criteria."""
    filtered_datasets = []
    descr_none = descr_small = downloads_less = common_descr = 0
    unique_descriptions: set[str] = set()

    for dataset_info in datasets:
        description = dataset_info.get("description")

        if not description:
            descr_none += 1
            continue
        if len(description.split()) < MIN_WORDS_IN_DESC:
            descr_small += 1
            continue
        if dataset_info.get("downloads", 0) < MIN_DOWNLOADS:
            downloads_less += 1
            continue
        if description in unique_descriptions:
            common_descr += 1
            continue

        filtered_datasets.append(dataset_info)
        unique_descriptions.add(description)

    print(
        f"descr_none: {descr_none}, descr_small: {descr_small}, "
        f"downloads_less: {downloads_less}, common_descr: {common_descr}"
    )

    return filtered_datasets


def main():
    """Main function to load and filter datasets."""
    all_datasets = load_datasets(ALL_DATASETS_FILE)
    filtered_datasets = filter_datasets(all_datasets)
    with open(FILTERED_DATASETS_FILE, "w") as f:
        json.dump(filtered_datasets, f)


if __name__ == "__main__":
    main()
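To make the filtering thresholds above concrete, here is a small sketch that runs `filter_datasets` over made-up metadata entries (the ids, descriptions, and download counts are hypothetical; the dict shape mirrors what the script reads back from `all_datasets.json`):

```python
# Hypothetical metadata entries; only "description" and "downloads" matter for filtering.
toy_datasets = [
    {"id": "org/reviews", "description": "A corpus of labeled movie reviews.", "downloads": 120},
    {"id": "org/no-desc", "description": None, "downloads": 500},  # dropped: missing description
    {"id": "org/short", "description": "Tiny set", "downloads": 500},  # dropped: < MIN_WORDS_IN_DESC words
    {"id": "org/rare", "description": "A corpus nobody has downloaded yet.", "downloads": 3},  # dropped: < MIN_DOWNLOADS
    {"id": "org/dupe", "description": "A corpus of labeled movie reviews.", "downloads": 90},  # dropped: duplicate description
]

kept = filter_datasets(toy_datasets)
assert [d["id"] for d in kept] == ["org/reviews"]
```

Only the first entry survives; the printed counters report one drop for each of the other reasons.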
File renamed without changes.
@@ -0,0 +1,285 @@
from __future__ import annotations # noqa FI58

import argparse
import gc
import json
import multiprocessing
import sys
import threading
import time
from collections.abc import MutableMapping
from pathlib import Path
from typing import Any

import datasets
import requests

EXCLUDED_TAGS = [
    "arxiv",
    "region",
    "license",
    "size_categories",
    "language_creators",
]

parser = argparse.ArgumentParser()
parser.add_argument("--index", type=int, default=100)

parser.add_argument("--num_processes", type=int, default=4)


def get_dataset_validity(dataset_name, max_retries=5):
    """Check whether a dataset is loadable on HuggingFace, backing off on rate limits."""
    API_URL = f"https://datasets-server.huggingface.co/is-valid?dataset={dataset_name}"
    retries = 0
    backoff = 10

    while retries < max_retries:
        response = requests.get(API_URL)

        if response.status_code == 200:
            response = response.json()
            return (
                "preview" in response
                and "viewer" in response
                and response["preview"] & response["viewer"]
            )

        elif response.status_code == 429:
            # Rate limited: honor Retry-After if present, otherwise back off exponentially.
            retry_after = response.headers.get("Retry-After")
            wait = int(retry_after) if retry_after else backoff
            time.sleep(wait)
            backoff *= 2  # Exponential increase
            retries += 1
        else:
            # Handle other HTTP errors
            break

    return False


def replace_duplicate_columns(original_dataset_columns):
    """Utility function to remove duplicate columns, after flattening dataset."""
    columns_mapping: dict[str, str] = {}
    new_columns = []
    counter: dict[str, int] = {}
    # convert flattened columns like answer.text -> answer_text
    for col in original_dataset_columns:
        new_col = col.replace(".", "_")
        if new_col in columns_mapping.values():
            counter[new_col] = counter.get(new_col, 0) + 1
            new_col = f"{new_col}_{counter[new_col]}"
        columns_mapping[col] = new_col
        new_columns.append(new_col)
    return new_columns, columns_mapping


def fetch_first_row_with_timeout(dataset, timeout=30):
    """Fetch the first row of a streaming dataset, or give up after `timeout` seconds."""

    def fetch_sample_row(container):
        try:
            container.append(next(iter(dataset)))
        except Exception as e:
            container.append(e)

    result_container = []
    fetch_thread = threading.Thread(target=fetch_sample_row, args=(result_container,))
    fetch_thread.start()
    fetch_thread.join(timeout)

    if fetch_thread.is_alive():
        # Operation took too long
        return None

    return result_container[0]


def truncate_row(example_row: dict, max_length=50) -> str:
    """Truncate the row before displaying if it is too long."""
    truncated_row = {}
    for key in example_row.keys():
        curr_row = json.dumps(example_row[key])
        truncated_row[key] = (
            curr_row
            if len(curr_row) <= max_length - 3
            else curr_row[:max_length] + "..."
        )
    return json.dumps(truncated_row)


def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> dict:
    """Utility function to flatten Streaming HuggingFace datasets."""
    items: list[tuple[str, Any]] = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def get_dataset_configs(dataset_name):
    """Get all valid configs for a given dataset."""
    config_names = datasets.get_dataset_config_names(dataset_name)
    all_configs = {}
    for config_name in config_names:
        if "train" not in datasets.get_dataset_split_names(dataset_name, config_name):
            continue
        dataset = datasets.load_dataset(
            dataset_name,
            config_name,
            split="train",
            streaming=True,
            download_mode="force_redownload",
        )
        sample_rows = fetch_first_row_with_timeout(dataset, timeout=30)
        if not sample_rows:
            raise ValueError("no sample rows")
        sample_rows = flatten_dict(sample_rows)
        # Reject datasets whose rows contain image or datetime values.
        if any(
            "ImageFile" in sample_rows[key].__class__.__name__
            or "DateTime" in sample_rows[key].__class__.__name__
            for key in sample_rows
        ):
            raise ValueError("Image File")
        columns, columns_mapping = replace_duplicate_columns(sample_rows.keys())

        columns = ", ".join(columns)
        all_configs[config_name] = {
            "config_name": config_name,
            "sample_row": truncate_row(sample_rows),
            "columns": columns,
            "columns_mapping": columns_mapping,
            "dataset_description": dataset.info.description,
            "dataset_name": dataset_name,
        }

    return all_configs


def process_datasets(chunk, process_index):
    """Process through the chunk of datasets and get dataset info to store."""
    dataset_index = {}
    max_attempts = 3
    for z in range(len(chunk)):
        print(f"Process index: {process_index} : currently on {z} out of {len(chunk)}")

        attempt = 0
        delay = 10
        while attempt < max_attempts:
            try:
                dataset_info = chunk[z]
                dataset_name = dataset_info["id"]
                description = dataset_info["description"]

                is_gated = hasattr(dataset_info, "gated") and dataset_info["gated"]
                if hasattr(dataset_info, "disabled") and dataset_info["disabled"]:
                    raise ValueError("dataset is disabled")
                if not get_dataset_validity(dataset_name):
                    raise ValueError("dataset is not valid")

                all_configs = get_dataset_configs(dataset_name)

                filtered_tags = [
                    tag
                    for tag in dataset_info["tags"]
                    if not any(excluded_word in tag for excluded_word in EXCLUDED_TAGS)
                ]
                dataset_index[dataset_name] = {
                    "dataset_name": dataset_name,
                    "description": description,
                    "downloads": dataset_info["downloads"],
                    "configs": all_configs,
                    "tags": filtered_tags,
                    "is_gated": is_gated,
                }
                print(
                    f"""completed {z} out of {len(chunk)}, dataset is
                    {dataset_name}, and it has {len(all_configs)} configs in it"""
                )
                del all_configs, filtered_tags

                break
            except Exception as e:
                if "429 Client Error" in str(e):
                    # Rate limited downstream: wait and retry with exponential backoff.
                    time.sleep(delay)
                    delay *= 2
                    attempt += 1
                else:
                    print("Error processing +", dataset_info["id"], ": ", e)
                    break
            except SystemExit as e:
                print("Error processing +", dataset_info["id"], ": ", e)
                break
        gc.collect()

    return dataset_index


def worker(chunk, index, temp_folder):
    """Utility function for Multiprocessing."""
    try:
        result = process_datasets(chunk, index)
        temp_file = temp_folder / f"temp_{index}.json"
        with open(temp_file, "w") as f:
            json.dump(result, f)
    except:  # noqa: E722
        e = sys.exc_info()[0]

        print(f"Process {index} died because of {e}.")  # noqa: E501


def chunkify(lst, n):
    """Divide the input list into chunks of size n."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


if __name__ == "__main__":
    start_time = time.time()
    args = parser.parse_args()
    all_datasets_file = "processed_datasets.json"

    with open(all_datasets_file, "r") as f:
        all_datasets = json.load(f)

    # Split the dataset into num_processes chunks
    chunk_size = len(all_datasets) // args.num_processes
    chunks = list(chunkify(all_datasets, chunk_size))
    temp_folder = Path("temp_data_" + str(args.index))
    temp_folder.mkdir(exist_ok=True)
    final_folder = Path("final_folder")
    final_folder.mkdir(exist_ok=True)
    output_file = final_folder / f"final_{args.index}.json"

    # Setup multiprocessing
    processes = []
    for i, chunk in enumerate(chunks):
        p = multiprocessing.Process(target=worker, args=(chunk, i, temp_folder))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Combine results from temp files
    dataset_index = {}
    for temp_file in temp_folder.glob("temp_*.json"):
        with open(temp_file, "r", encoding="utf-8") as f:
            dataset_index.update(json.load(f))

    # Write the final result
    with open(output_file, "w+") as f:
        json.dump(dataset_index, f)

    # Optional: clean up temp files
    for temp_file in temp_folder.glob("temp_*.json"):
        temp_file.unlink()
    temp_folder.rmdir()

    end_time = time.time()
    print(
        f"Process took {end_time - start_time} seconds, {(end_time - start_time) / 60} mins"
    )
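To make the row-handling helpers above concrete, here is a small walkthrough sketch, assuming the helpers defined in this script (`flatten_dict`, `replace_duplicate_columns`, `truncate_row`) are in scope; the nested row is made up:

```python
# Hypothetical nested row, similar to what a streaming SQuAD-style config yields.
row = {"id": "42", "answers": {"text": ["Paris"], "answer_start": [17]}}

flat = flatten_dict(row)
# {'id': '42', 'answers.text': ['Paris'], 'answers.answer_start': [17]}

columns, mapping = replace_duplicate_columns(flat.keys())
# columns -> ['id', 'answers_text', 'answers_answer_start']
# mapping -> {'id': 'id', 'answers.text': 'answers_text',
#             'answers.answer_start': 'answers_answer_start'}

print(truncate_row(flat))
# Each value is JSON-encoded first; values longer than ~50 characters would be
# clipped and suffixed with "..." before being stored in the index.
```

The truncated JSON string produced here is what `get_dataset_configs` stores under `sample_row`, and after this PR it is passed to `build_input` unchanged.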

Large diffs are not rendered by default.
