diff --git a/src/instructlab/sdg/checkpointing.py b/src/instructlab/sdg/checkpointing.py
new file mode 100644
index 00000000..9095c557
--- /dev/null
+++ b/src/instructlab/sdg/checkpointing.py
@@ -0,0 +1,86 @@
+# Standard
+import logging
+import uuid
+
+# Third Party
+from datasets import Dataset, concatenate_datasets, load_dataset
+from datasets.data_files import EmptyDatasetError
+
+# First Party
+from instructlab.sdg.utils import pandas
+
+logger = logging.getLogger(__name__)
+
+
+class Checkpointer:
+    def __init__(self, checkpoint_dir=None, save_freq=1):
+        self._checkpoint_dir = checkpoint_dir
+
+        self._save_freq = save_freq
+        self._cache = []
+
+    def checkpoint(self, dataset):
+        self._cache.append(dataset)
+        if len(self._cache) < self._save_freq:
+            return
+        self.save()
+        self._cache.clear()
+
+    def done(self):
+        if self._cache:
+            self.save()
+            self._cache.clear()
+
+    def save(self):
+        if self._checkpoint_dir is None:
+            return
+        checkpoint_id = uuid.uuid4().hex
+        checkpoint_file = (
+            f"{self._checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
+        )
+        logger.info(f"Saving checkpoint to {checkpoint_file}")
+        # Saves all the currently cached records to a new file in the checkpoint dir
+        concatenate_datasets(self._cache).to_json(
+            checkpoint_file, orient="records", lines=True
+        )
+
+    def load(self, dataset: Dataset) -> "tuple[Dataset, Dataset | None]":
+        if self._checkpoint_dir is None:
+            return dataset, None
+
+        try:
+            pre_generated_data = load_dataset(
+                "json", data_dir=self._checkpoint_dir, split="train"
+            )
+        except EmptyDatasetError:
+            logger.info(
+                f"No existing checkpoints found in {self._checkpoint_dir}, generating from scratch"
+            )
+            return dataset, None
+
+        logger.info(
+            f"Loading existing checkpoints from {self._checkpoint_dir}, with {pre_generated_data.num_rows} rows"
+        )
+        seed_data = self._get_missing_data(dataset, pre_generated_data)
+        logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
+        return seed_data, pre_generated_data
+
+    def _get_missing_data(self, seed_data, generated_data):
+        # Get the common columns between the two datasets
+        common_columns = list(
+            set(seed_data.column_names) & set(generated_data.column_names)
+        )
+
+        # Extract the relevant data based on common columns
+        seed_data_common = seed_data.select_columns(common_columns)
+        generated_data_common = generated_data.select_columns(common_columns)
+
+        # Convert to pandas DataFrames for easier comparison
+        seed_df = seed_data_common.to_pandas()
+        generated_df = generated_data_common.to_pandas()
+
+        # Identify seed rows that are missing from the generated data
+        missing_rows = ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
+
+        missing_df = seed_data.to_pandas()[missing_rows]
+        return pandas.dataset_from_pandas_dataframe(missing_df)
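To make the resume flow concrete, here is a minimal usage sketch; it is not part of the diff, and the seed rows and temporary directory are illustrative assumptions:

```python
# Illustrative sketch, not part of the diff: checkpoint two processed rows,
# then resume and let load() split the seed data into remaining vs. recovered.
import tempfile

from datasets import Dataset

from instructlab.sdg.checkpointing import Checkpointer

with tempfile.TemporaryDirectory() as ckpt_dir:
    seed = Dataset.from_list([{"foo": i} for i in range(4)])

    # First "run": generate output for the first two rows, flushing every record
    first = Checkpointer(checkpoint_dir=ckpt_dir, save_freq=1)
    first.checkpoint(
        seed.select(range(2)).map(lambda s: {"foo": s["foo"], "bar": s["foo"] + 100})
    )
    first.done()

    # Second "run": rows already on disk are filtered out of the seed data
    second = Checkpointer(checkpoint_dir=ckpt_dir, save_freq=1)
    remaining, pre_generated = second.load(seed)
    print(remaining.num_rows, pre_generated.num_rows)  # 2 2
```

Note that `load()` only compares the columns the seed data and the checkpoints have in common, so checkpoints may carry extra generated columns (like `bar` here) without breaking resumption.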
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index 730c8b39..dfe866d4 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -5,6 +5,7 @@
 from importlib import resources
 from pathlib import Path
 from typing import Optional
+import dataclasses
 import json
 import os
 import time
@@ -181,6 +182,8 @@ def _context_init(
     model_family: str,
     model_id: str,
     num_instructions_to_generate: int,
+    checkpoint_dir: str,
+    save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
 ):
@@ -194,6 +197,8 @@ def _context_init(
         model_family=model_family,
         model_id=model_id,
         num_instructions_to_generate=num_instructions_to_generate,
+        checkpoint_dir=checkpoint_dir,
+        save_freq=save_freq,
         **extra_kwargs,
     )
@@ -284,6 +289,7 @@ def generate_data(
     client: Optional[openai.OpenAI] = None,
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
+    checkpoint_dir: Optional[str] = None,
 ) -> None:
     """Generate data for training and testing a model.
@@ -348,13 +354,17 @@ def generate_data(
         model_family,
         model_name,
         num_instructions_to_generate,
+        checkpoint_dir,
+        1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
     )

     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(ctx, pipeline)

-    mmlu_bench_pipe = mmlubench_pipe_init(ctx)
+    # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
+    mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
+    mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

     mixer = _mixer_init(ctx, output_dir, date_suffix)
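The `dataclasses.replace` call above is what keeps the mmlubench pipeline from picking up the main pipeline's checkpoints: it produces a copy of the context with one field overridden. A minimal sketch of that pattern, where the toy `Ctx` class is a stand-in rather than the real `PipelineContext`:

```python
# Toy stand-in for PipelineContext to show the dataclasses.replace pattern:
# replace() returns a new instance with the named field overridden, leaving
# the original context (and its checkpoint_dir) untouched.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Ctx:
    model_id: str
    checkpoint_dir: Optional[str] = None


ctx = Ctx(model_id="foo.bar", checkpoint_dir="/tmp/ckpts")
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)

assert ctx.checkpoint_dir == "/tmp/ckpts"  # original context untouched
assert mmlu_ctx.checkpoint_dir is None  # copy has checkpointing disabled
```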
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index 1263c974..f01f9eb3 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -13,6 +13,7 @@
 import yaml

 # First Party
+from instructlab.sdg.checkpointing import Checkpointer
 from instructlab.sdg.utils import pandas

 # Local
@@ -61,6 +62,8 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     model_id: str
     num_instructions_to_generate: int
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    checkpoint_dir: Optional[str] = None
+    save_freq: Optional[int] = 1
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None
@@ -129,6 +132,12 @@ def generate(self, dataset) -> Dataset:
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+
+        # The checkpointer allows us to resume from where we left off,
+        # saving the output of pipe instances along the way
+        checkpointer = Checkpointer(self.ctx.checkpoint_dir, self.ctx.save_freq)
+        dataset, pre_generated_data = checkpointer.load(dataset)
+
         # If not batching, simply delegate to _generate_single
         if not self.ctx.batching_enabled:
             logger.info("Running pipeline single-threaded")
@@ -142,6 +151,7 @@ def generate(self, dataset) -> Dataset:
             self.ctx.batch_size,
         )
         input_splits = self._split_dataset(dataset)
+        output_splits = []
        with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
             futures = [
                 executor.submit(self._generate_single, input_split)
@@ -150,8 +160,13 @@ def generate(self, dataset) -> Dataset:
                 for input_split in input_splits
             ]

             # Collect the results of each batch as they finish. This needs to
             # wait for them all, so the order of waiting doesn't matter
-            output_splits = [future.result() for future in futures]
-
+            for future in futures:
+                ds = future.result()
+                output_splits.append(ds)
+                checkpointer.checkpoint(ds)
+        checkpointer.done()
+        if pre_generated_data:
+            output_splits.append(pre_generated_data)
         return concatenate_datasets(output_splits)
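Under batching, each finished split is handed to the checkpointer as it completes, so an interrupted run loses at most `save_freq` batches of work. A standalone sketch of that flush cadence, with made-up one-row splits and a temporary directory:

```python
# Standalone sketch of the flush cadence (illustrative data and directory):
# with save_freq=2, a checkpoint file is written after every second completed
# split, and done() flushes whatever is still cached at the end of the run.
import tempfile

from datasets import Dataset

from instructlab.sdg.checkpointing import Checkpointer

splits = [Dataset.from_list([{"foo": i, "bar": i + 100}]) for i in range(5)]

with tempfile.TemporaryDirectory() as ckpt_dir:
    checkpointer = Checkpointer(checkpoint_dir=ckpt_dir, save_freq=2)
    for ds in splits:
        checkpointer.checkpoint(ds)  # flushes to disk after splits 2 and 4
    checkpointer.done()              # flushes the remaining fifth split
```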
## Tests ##

diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py
new file mode 100644
index 00000000..9997e717
--- /dev/null
+++ b/tests/test_checkpointing.py
@@ -0,0 +1,117 @@
+# Standard
+import json
+import os
+
+# Third Party
+from datasets import Dataset
+import pytest
+
+# First Party
+from instructlab.sdg.checkpointing import Checkpointer
+
+
+def _add_bar(sample, add_value=100):
+    sample["bar"] = sample["foo"] + add_value
+    return sample
+
+
+def _populate_checkpoints(tmpdir, dataset, checkpoints_count, remove_column):
+    for i in range(0, checkpoints_count):
+        checkpoint_dataset = dataset.select(range(i * 10, (i + 1) * 10))
+        checkpoint_dataset = checkpoint_dataset.map(
+            lambda x: _add_bar(x, add_value=100)
+        )
+        if remove_column:
+            checkpoint_dataset = checkpoint_dataset.remove_columns("foo")
+        checkpoint_dataset.to_json(
+            os.path.join(tmpdir, f"data_checkpoint_abcde{i}.jsonl"),
+            orient="records",
+            lines=True,
+        )
+
+
+def _validate_checkpoints(tmpdir, expected_files_count, expected_length, remove_column):
+    saved_files = os.listdir(tmpdir)
+    assert len(saved_files) == expected_files_count
+    assert all(f.startswith("data_checkpoint_") for f in saved_files)
+    assert all(f.endswith(".jsonl") for f in saved_files)
+
+    for file_name in saved_files:
+        with open(os.path.join(tmpdir, file_name), "r") as f:
+            lines = list(f)
+            if isinstance(expected_length, list):
+                expected_length.remove(len(lines))
+            else:
+                assert len(lines) == expected_length
+            for line in lines:
+                data = json.loads(line)
+                if remove_column:
+                    assert "foo" not in data and "bar" in data
+                else:
+                    assert "foo" in data and "bar" in data
+
+
+@pytest.mark.parametrize(
+    "save_freq, remove_column, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
+    [
+        (1, False, 10, 0, 0, 1, 10),
+        (1, True, 10, 0, 0, 1, 10),
+        (1, False, 100, 1, 9, 10, 10),
+        (1, True, 100, 1, 9, 10, 10),
+        (1, False, 100, 2, 8, 10, 10),
+        (3, False, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
+    ],
+)
+def test_checkpointing(
+    tmpdir,
+    save_freq,
+    remove_column,
+    dataset_size,
+    init_checkpoints,
+    splits,
+    final_checkpoints,
+    checkpoint_length,
+):
+    # Our initial dataset
+    dataset = Dataset.from_list([{"idx": i, "foo": i} for i in range(dataset_size)])
+
+    # Generate and save some checkpoints to disk
+    _populate_checkpoints(tmpdir, dataset, init_checkpoints, remove_column)
+
+    # Load checkpoints, giving us the remaining dataset to process and
+    # the generated data loaded from the checkpoints
+    checkpointer = Checkpointer(checkpoint_dir=tmpdir, save_freq=save_freq)
+    dataset, pre_generated_data = checkpointer.load(dataset)
+
+    # "foo" should be present, even if removed from the checkpoints (remove_column=True)
+    assert "foo" in dataset.features
+
+    # When testing save_freq, we will have checkpoints of different lengths
+    if isinstance(checkpoint_length, list):
+        checkpoints_total = sum(checkpoint_length[:init_checkpoints])
+    else:
+        checkpoints_total = checkpoint_length * init_checkpoints
+
+    # Validate the pre-generated data loaded from the checkpoints
+    assert len(dataset) == (dataset_size - checkpoints_total)
+    if init_checkpoints > 0:
+        assert len(pre_generated_data) == checkpoints_total
+
+    # Apply the pipeline to the remaining dataset and save checkpoints
+    if splits:
+        for i in range(0, splits):
+            split = dataset.select(range(i * 10, (i + 1) * 10))
+            split = split.map(lambda x: _add_bar(x, add_value=100))
+            if remove_column:
+                split = split.remove_columns("foo")
+            checkpointer.checkpoint(split)
+    else:
+        dataset = dataset.map(lambda x: _add_bar(x, add_value=10))
+        if remove_column:
+            dataset = dataset.remove_columns("foo")
+        checkpointer.checkpoint(dataset)
+
+    checkpointer.done()
+
+    # Validate that all checkpoints are now saved to disk
+    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length, remove_column)
diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py
index 8b47adf7..33d21e8f 100644
--- a/tests/test_generate_data.py
+++ b/tests/test_generate_data.py
@@ -19,6 +19,8 @@ def test_context_init_batch_size_optional():
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=None,
         batch_num_workers=None,
     )
@@ -32,6 +34,8 @@
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=20,
         batch_num_workers=32,
     )
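A reading aid for the last parametrized case: with `save_freq=3`, two pre-populated checkpoints of 10 rows, and 8 splits of 10 rows, the checkpointer flushes 30 rows after the third and sixth splits and `done()` flushes the trailing 20, giving the expected file lengths `[10, 10, 30, 30, 20]`. The `remove_column` cases, in turn, pass because `_get_missing_data` compares only the columns the seed and checkpoint datasets share. That technique in isolation, using plain pandas on toy frames (not the library's API):

```python
# Toy pandas frames showing the row-diffing trick _get_missing_data uses:
# compare rows as tuples over the shared columns only, then keep the seed
# rows that never appear in the generated data.
import pandas as pd

seed_df = pd.DataFrame({"foo": [0, 1, 2, 3]})
generated_df = pd.DataFrame({"foo": [0, 1], "bar": [100, 101]})

common = sorted(set(seed_df.columns) & set(generated_df.columns))  # ["foo"]
missing = ~seed_df[common].apply(tuple, 1).isin(generated_df[common].apply(tuple, 1))

print(seed_df[missing])  # the rows with foo == 2 and foo == 3
```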