From ab013e532839527d0464e8bf54b51d47a26e58bf Mon Sep 17 00:00:00 2001
From: shiv
Date: Fri, 26 Jul 2024 07:56:48 -0400
Subject: [PATCH 1/2] Add data checkpointing capability

Introduce a comprehensive data checkpointing mechanism that saves
intermediate states during data generation. This allows for more granular
progress tracking and recovery. Checkpoints are saved periodically based on
the save_freq setting, preventing data loss and enabling resumption from the
last saved state in case of interruptions.

The system can resume from a saved checkpoint by comparing the generated
data with the seed data to identify and process missing data. Each
checkpoint is uniquely identified using UUIDs, ensuring distinct and
traceable save points.

Co-authored-by: shiv
Co-authored-by: Derek Higgins
Co-authored-by: Mark McLoughlin
Signed-off-by: Derek Higgins
Signed-off-by: Mark McLoughlin
---
 src/instructlab/sdg/checkpointing.py |  87 +++++++++++++++++++++++
 src/instructlab/sdg/generate_data.py |  12 +++-
 src/instructlab/sdg/pipeline.py      |  19 ++++-
 tests/test_checkpointing.py          | 102 +++++++++++++++++++++++++++
 tests/test_generate_data.py          |   4 ++
 5 files changed, 221 insertions(+), 3 deletions(-)
 create mode 100644 src/instructlab/sdg/checkpointing.py
 create mode 100644 tests/test_checkpointing.py

diff --git a/src/instructlab/sdg/checkpointing.py b/src/instructlab/sdg/checkpointing.py
new file mode 100644
index 00000000..a16d18d0
--- /dev/null
+++ b/src/instructlab/sdg/checkpointing.py
@@ -0,0 +1,87 @@
+# Standard
+import logging
+import uuid
+
+# Third Party
+from datasets import Dataset, concatenate_datasets, load_dataset
+from datasets.data_files import EmptyDatasetError
+
+# First Party
+from instructlab.sdg.utils import pandas
+
+logger = logging.getLogger(__name__)
+
+
+class Checkpointer:
+    def __init__(self, checkpoint_dir=None, save_freq=1):
+        self._checkpoint_dir = checkpoint_dir
+
+        self._save_freq = save_freq
+        self._cache = []
+
+    def checkpoint(self, dataset):
+        self._cache.append(dataset)
+        if len(self._cache) < self._save_freq:
+            return
+        self.save()
+        self._cache.clear()
+
+    def done(self):
+        if self._cache:
+            self.save()
+            self._cache.clear()
+
+    def save(self):
+        if self._checkpoint_dir is None:
+            return
+        checkpoint_id = uuid.uuid4().hex
+        checkpoint_file = (
+            f"{self._checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
+        )
+        logger.info(f"Saving checkpoint to {checkpoint_file}")
+        # Saves all the current records to new file in the checkpoint dir
+        concatenate_datasets(self._cache).to_json(
+            checkpoint_file, orient="records", lines=True
+        )
+
+    def load(self, dataset: Dataset) -> Dataset:
+        if self._checkpoint_dir is None:
+            return dataset, None
+
+        try:
+            pre_generated_data = load_dataset(
+                "json", data_dir=self._checkpoint_dir, split="train"
+            )
+        except EmptyDatasetError:
+            logger.info(
+                f"No existing checkpoints found in {self._checkpoint_dir}, generating from scratch"
+            )
+            return dataset, None
+
+        logger.info(
+            f"Loading existing checkpoints from {self._checkpoint_dir}, with {pre_generated_data.num_rows} rows"
+        )
+        seed_data = self._get_missing_data(dataset, pre_generated_data)
+        logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
+        return seed_data, pre_generated_data
+
+    def _get_missing_data(self, seed_data, generated_data):
+        # Get the common columns between the two datasets
+        common_columns = list(
+            set(seed_data.column_names) & set(generated_data.column_names)
+        )
+
+        # Extract the relevant data based on common columns
+        seed_data_common = seed_data.select_columns(common_columns)
+        generated_data_common = generated_data.select_columns(common_columns)
+
+        # Convert to Pandas DataFrames for easier comparison
+        seed_df = seed_data_common.to_pandas()
+        generated_df = generated_data_common.to_pandas()
+
+        # Identify missing rows
+        missing_df = seed_df[
+            ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
+        ]
+
+        return pandas.dataset_from_pandas_dataframe(missing_df)
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index 730c8b39..dfe866d4 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -5,6 +5,7 @@
 from importlib import resources
 from pathlib import Path
 from typing import Optional
+import dataclasses
 import json
 import os
 import time
@@ -181,6 +182,8 @@ def _context_init(
     model_family: str,
     model_id: str,
     num_instructions_to_generate: int,
+    checkpoint_dir: str,
+    save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
 ):
@@ -194,6 +197,8 @@ def _context_init(
         model_family=model_family,
         model_id=model_id,
         num_instructions_to_generate=num_instructions_to_generate,
+        checkpoint_dir=checkpoint_dir,
+        save_freq=save_freq,
         **extra_kwargs,
     )

@@ -284,6 +289,7 @@ def generate_data(
     client: Optional[openai.OpenAI] = None,
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
+    checkpoint_dir: Optional[str] = None,
 ) -> None:
     """Generate data for training and testing a model.

@@ -348,13 +354,17 @@ def generate_data(
         model_family,
         model_name,
         num_instructions_to_generate,
+        checkpoint_dir,
+        1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
     )

     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(ctx, pipeline)

-    mmlu_bench_pipe = mmlubench_pipe_init(ctx)
+    # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
+    mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
+    mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

     mixer = _mixer_init(ctx, output_dir, date_suffix)

diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index 1263c974..f01f9eb3 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -13,6 +13,7 @@
 import yaml

 # First Party
+from instructlab.sdg.checkpointing import Checkpointer
 from instructlab.sdg.utils import pandas

 # Local
@@ -61,6 +62,8 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
     model_id: str
     num_instructions_to_generate: int
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    checkpoint_dir: Optional[str] = None
+    save_freq: Optional[int] = 1
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None

@@ -129,6 +132,12 @@ def generate(self, dataset) -> Dataset:
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+
+        # The checkpointer allows us to resume from where we left off
+        # Saving the output of pipe instances along the way
+        checkpointer = Checkpointer(self.ctx.checkpoint_dir, self.ctx.save_freq)
+        dataset, pre_generated_data = checkpointer.load(dataset)
+
         # If not batching, simply delegate to _generate_single
         if not self.ctx.batching_enabled:
             logger.info("Running pipeline single-threaded")
@@ -142,6 +151,7 @@ def generate(self, dataset) -> Dataset:
             self.ctx.batch_size,
         )
         input_splits = self._split_dataset(dataset)
+        output_splits = []
         with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
             futures = [
                 executor.submit(self._generate_single, input_split)
@@ -150,8 +160,13 @@ def generate(self, dataset) -> Dataset:

             # Collect the results of each batch as they finish. This needs to
             # wait for them all, so the order of waiting doesn't matter
-            output_splits = [future.result() for future in futures]
-
+            for future in futures:
+                ds = future.result()
+                output_splits.append(ds)
+                checkpointer.checkpoint(ds)
+        checkpointer.done()
+        if pre_generated_data:
+            output_splits.append(pre_generated_data)
         return concatenate_datasets(output_splits)

     ## Implementation Details ##
diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py
new file mode 100644
index 00000000..7f9f7842
--- /dev/null
+++ b/tests/test_checkpointing.py
@@ -0,0 +1,102 @@
+# Standard
+import json
+import os
+
+# Third Party
+from datasets import Dataset
+import pytest
+
+# First Party
+from instructlab.sdg.checkpointing import Checkpointer
+
+
+def _add_bar(sample, add_value=100):
+    sample["bar"] = sample["foo"] + add_value
+    return sample
+
+
+def _populate_checkpoints(tmpdir, dataset, checkpoints_count):
+    for i in range(0, checkpoints_count):
+        checkpoint_dataset = dataset.select(range(i * 10, (i + 1) * 10))
+        checkpoint_dataset = checkpoint_dataset.map(
+            lambda x: _add_bar(x, add_value=100)
+        )
+        checkpoint_dataset.to_json(
+            os.path.join(tmpdir, f"data_checkpoint_abcde{i}.jsonl"),
+            orient="records",
+            lines=True,
+        )
+
+
+def _validate_checkpoints(tmpdir, expected_files_count, expected_length):
+    saved_files = os.listdir(tmpdir)
+    assert len(saved_files) == expected_files_count
+    assert all(f.startswith("data_checkpoint_") for f in saved_files)
+    assert all(f.endswith(".jsonl") for f in saved_files)
+
+    for f in saved_files:
+        with open(os.path.join(tmpdir, f), "r") as f:
+            l = list(f)
+            if isinstance(expected_length, list):
+                expected_length.remove(len(l))
+            else:
+                assert len(l) == expected_length
+            for s in l:
+                data = json.loads(s)
+                assert "foo" in data and "bar" in data
+
+
+@pytest.mark.parametrize(
+    "save_freq, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
+    [
+        (1, 10, 0, 0, 1, 10),
+        (1, 100, 1, 9, 10, 10),
+        (1, 100, 2, 8, 10, 10),
+        (3, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
+    ],
+)
+def test_checkpointing(
+    tmpdir,
+    save_freq,
+    dataset_size,
+    init_checkpoints,
+    splits,
+    final_checkpoints,
+    checkpoint_length,
+):
+    # Our initial dataset
+    dataset = Dataset.from_list([{"foo": i} for i in range(dataset_size)])
+
+    # Generate and save some checkpoints to disk
+    _populate_checkpoints(tmpdir, dataset, init_checkpoints)
+
+    # Load checkpoints, giving us the remaining dataset to process and
+    # the generated data loaded from the checkpoints
+    checkpointer = Checkpointer(checkpoint_dir=tmpdir, save_freq=save_freq)
+    dataset, pre_generated_data = checkpointer.load(dataset)
+
+    # When testing save_freq, we will have checkpoints of different lengths
+    if isinstance(checkpoint_length, list):
+        checkpoints_total = sum(checkpoint_length[:init_checkpoints])
+    else:
+        checkpoints_total = checkpoint_length * init_checkpoints
+
+    # Validate pre-generated data loaded from the checkpoints
+    assert len(dataset) == (dataset_size - checkpoints_total)
+    if init_checkpoints > 0:
+        assert len(pre_generated_data) == checkpoints_total
+
+    # Apply pipeline to the remaining dataset and save checkpoints
+    if splits:
+        for i in range(0, splits):
+            split = dataset.select(range(i * 10, (i + 1) * 10))
+            split = split.map(lambda x: _add_bar(x, add_value=100))
+            checkpointer.checkpoint(split)
+    else:
+        dataset = dataset.map(lambda x: _add_bar(x, add_value=10))
+        checkpointer.checkpoint(dataset)
+
+    checkpointer.done()
+
+    # Validate that all checkpoints are now saved to disk
+    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length)
diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py
index 8b47adf7..33d21e8f 100644
--- a/tests/test_generate_data.py
+++ b/tests/test_generate_data.py
@@ -19,6 +19,8 @@ def test_context_init_batch_size_optional():
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=None,
         batch_num_workers=None,
     )
@@ -32,6 +34,8 @@ def test_context_init_batch_size_optional():
         "mixtral",
         "foo.bar",
         1,
+        "/checkpoint/dir",
+        1,
         batch_size=20,
         batch_num_workers=32,
     )

From f5d22d74c2c8c99733dc1bc2c558b9456b1ffabb Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Fri, 26 Jul 2024 21:32:58 +0100
Subject: [PATCH 2/2] checkpointing: fix "missing data" check with removed columns

Fix the logic error where, if a pipeline removes a column that was present
in the original dataset, checkpointing caused that column not to be present
before the pipeline starts.

Add a test case to the unit test to cover this, but note I had to add an
additional column to ensure there is at least one column in common between
the original dataset and the checkpoint dataset.
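For illustration only (a sketch using the toy "idx"/"foo"/"bar" columns from
the updated unit test, not part of the change itself): the membership mask is
computed over the columns common to both datasets, but it is applied to the
full seed data, so a column the pipeline removes (here "foo") survives in the
rows handed back for reprocessing.

import pandas as pd

seed_df = pd.DataFrame({"idx": [0, 1, 2], "foo": [10, 11, 12]})
# Checkpointed output where the pipeline dropped "foo" and added "bar"
generated_df = pd.DataFrame({"idx": [0, 1], "bar": [110, 111]})

# Compare rows only on the columns both frames share ("idx" here)
common = list(set(seed_df.columns) & set(generated_df.columns))
missing_rows = ~seed_df[common].apply(tuple, 1).isin(
    generated_df[common].apply(tuple, 1)
)
# Index the full seed frame, so idx=2 is returned with its "foo" value intact
missing_df = seed_df[missing_rows]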
Signed-off-by: Mark McLoughlin
---
 src/instructlab/sdg/checkpointing.py |  5 ++--
 tests/test_checkpointing.py          | 37 +++++++++++++++++++---------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/src/instructlab/sdg/checkpointing.py b/src/instructlab/sdg/checkpointing.py
index a16d18d0..9095c557 100644
--- a/src/instructlab/sdg/checkpointing.py
+++ b/src/instructlab/sdg/checkpointing.py
@@ -80,8 +80,7 @@ def _get_missing_data(self, seed_data, generated_data):
         generated_df = generated_data_common.to_pandas()

         # Identify missing rows
-        missing_df = seed_df[
-            ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
-        ]
+        missing_rows = ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
+        missing_df = seed_data.to_pandas()[missing_rows]

         return pandas.dataset_from_pandas_dataframe(missing_df)
diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py
index 7f9f7842..9997e717 100644
--- a/tests/test_checkpointing.py
+++ b/tests/test_checkpointing.py
@@ -15,12 +15,14 @@ def _add_bar(sample, add_value=100):
     return sample


-def _populate_checkpoints(tmpdir, dataset, checkpoints_count):
+def _populate_checkpoints(tmpdir, dataset, checkpoints_count, remove_column):
     for i in range(0, checkpoints_count):
         checkpoint_dataset = dataset.select(range(i * 10, (i + 1) * 10))
         checkpoint_dataset = checkpoint_dataset.map(
             lambda x: _add_bar(x, add_value=100)
         )
+        if remove_column:
+            checkpoint_dataset = checkpoint_dataset.remove_columns("foo")
         checkpoint_dataset.to_json(
             os.path.join(tmpdir, f"data_checkpoint_abcde{i}.jsonl"),
             orient="records",
@@ -28,7 +30,7 @@ def _populate_checkpoints(tmpdir, dataset, checkpoints_count):
         )


-def _validate_checkpoints(tmpdir, expected_files_count, expected_length):
+def _validate_checkpoints(tmpdir, expected_files_count, expected_length, remove_column):
     saved_files = os.listdir(tmpdir)
     assert len(saved_files) == expected_files_count
     assert all(f.startswith("data_checkpoint_") for f in saved_files)
@@ -43,21 +45,27 @@ def _validate_checkpoints(tmpdir, expected_files_count, expected_length):
             for s in l:
                 data = json.loads(s)
-                assert "foo" in data and "bar" in data
+                if remove_column:
+                    assert "foo" not in data and "bar" in data
+                else:
+                    assert "foo" in data and "bar" in data


 @pytest.mark.parametrize(
-    "save_freq, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
+    "save_freq, remove_column, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
     [
-        (1, 10, 0, 0, 1, 10),
-        (1, 100, 1, 9, 10, 10),
-        (1, 100, 2, 8, 10, 10),
-        (3, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
+        (1, False, 10, 0, 0, 1, 10),
+        (1, True, 10, 0, 0, 1, 10),
+        (1, False, 100, 1, 9, 10, 10),
+        (1, True, 100, 1, 9, 10, 10),
+        (1, False, 100, 2, 8, 10, 10),
+        (3, False, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
     ],
 )
 def test_checkpointing(
     tmpdir,
     save_freq,
+    remove_column,
     dataset_size,
     init_checkpoints,
     splits,
@@ -65,16 +73,19 @@ def test_checkpointing(
     checkpoint_length,
 ):
     # Our initial dataset
-    dataset = Dataset.from_list([{"foo": i} for i in range(dataset_size)])
+    dataset = Dataset.from_list([{"idx": i, "foo": i} for i in range(dataset_size)])

     # Generate and save some checkpoints to disk
-    _populate_checkpoints(tmpdir, dataset, init_checkpoints)
+    _populate_checkpoints(tmpdir, dataset, init_checkpoints, remove_column)

     # Load checkpoints, giving us the remaining dataset to process and
     # the generated data loaded from the checkpoints
     checkpointer = Checkpointer(checkpoint_dir=tmpdir, save_freq=save_freq)
     dataset, pre_generated_data = checkpointer.load(dataset)

+    # Should be present, even if removed from the checkpoint (remove_column=True)
+    assert "foo" in dataset.features
+
     # When testing save_freq, we will have checkpoints of different lengths
     if isinstance(checkpoint_length, list):
         checkpoints_total = sum(checkpoint_length[:init_checkpoints])
@@ -91,12 +102,16 @@ def test_checkpointing(
         for i in range(0, splits):
             split = dataset.select(range(i * 10, (i + 1) * 10))
             split = split.map(lambda x: _add_bar(x, add_value=100))
+            if remove_column:
+                split = split.remove_columns("foo")
             checkpointer.checkpoint(split)
     else:
         dataset = dataset.map(lambda x: _add_bar(x, add_value=10))
+        if remove_column:
+            dataset = dataset.remove_columns("foo")
         checkpointer.checkpoint(dataset)

     checkpointer.done()

     # Validate that all checkpoints are now saved to disk
-    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length)
+    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length, remove_column)
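
For reference, a minimal usage sketch of the Checkpointer API introduced by
these patches, mirroring how Pipeline.generate drives it. The toy dataset,
the batch size of 10, and the temporary directory are illustrative
assumptions, not part of the patches.

import tempfile

from datasets import Dataset

from instructlab.sdg.checkpointing import Checkpointer

seed = Dataset.from_list([{"foo": i} for i in range(20)])
checkpoint_dir = tempfile.mkdtemp()

# save_freq=1 writes a data_checkpoint_<uuid>.jsonl file after every batch
checkpointer = Checkpointer(checkpoint_dir=checkpoint_dir, save_freq=1)

# On a rerun this returns only the rows not yet covered by checkpoints, plus
# the previously generated data recovered from them (None on a first run)
remaining, pre_generated = checkpointer.load(seed)

for start in range(0, len(remaining), 10):
    batch = remaining.select(range(start, min(start + 10, len(remaining))))
    generated = batch.map(lambda s: {**s, "bar": s["foo"] + 100})
    checkpointer.checkpoint(generated)  # cached, flushed every save_freq batches

checkpointer.done()  # flush whatever is still cached below save_freq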