From 9b9989d17e5d10754bf49f52af9132159c165ec4 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Fri, 17 Jan 2025 14:03:18 -0800 Subject: [PATCH] condense --- src/levanter/data/text.py | 52 --------------------------------------- 1 file changed, 52 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c512a9e12..74bbfaa3b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1346,58 +1346,6 @@ def preprocess_chat_example_for_packing( } -def mk_chat_sft_packed_dataset( - config: ChatUrlDataSourceConfig, - tokenizer: PreTrainedTokenizerBase, - Pos: hax.Axis, - *, - max_segments_per_example: int = 4, -) -> AsyncDataset[LmExample]: - """Creates a packed dataset from chat data for more efficient training.""" - source = config.get_shard_source("train") - if source is None: - raise ValueError("No training data source found") - - # Check if we need to manually append EOS - input_ids = tokenizer("hi there")["input_ids"] - should_append_eos = input_ids[-1] != tokenizer.eos_token_id - - # First process into cacheable format - dataset = source.map_batches( - lambda ex: preprocess_chat_example_for_packing(ex, tokenizer, should_append_eos), - batch_size=128, - num_cpus=num_cpus_used_by_tokenizer(tokenizer), - output_exemplar={ - "input_ids": np.zeros(0, dtype=np.int32), - "prompt_length": np.zeros(0, dtype=np.int32) - } - ) - - # Cache the processed data - cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache( - config.cache_dir, - await_finished=True - ) - - # Convert cached dictionaries to PromptCompletions and pack them - def prepare_and_pack(examples: list[dict]) -> list[LmExample]: - completions = [ - PromptCompletion( - ids=ex["input_ids"].tolist(), - prompt_length=int(ex["prompt_length"]) - ) for ex in examples - ] - return list(pack_prompt_completions( - Pos=Pos, - sequences=completions, - pad_token=tokenizer.pad_token_id, - max_segments_per_example=max_segments_per_example, - )) - - # Pack the examples - return cached_dataset.map_batches(prepare_and_pack) - - def datasource_from_chat_jsonl( urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" ) -> "ShardedDataSource[dict]":