From 3a534ddcc349c155a06921b8fe31e2872a9dc7c8 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Fri, 17 Jan 2025 14:12:08 -0800
Subject: [PATCH] fix too many tokens

Tokenized chat sequences could still exceed Pos.size even with tokenizer
truncation enabled. Pass the Pos axis through to
preprocess_chat_example_for_packing, keep only the last Pos.size tokens of
any over-long sequence, and clamp the stored prompt length so it never
exceeds the truncated sequence.

---
 src/levanter/data/text.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 74bbfaa3b..0af3ea026 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -1318,7 +1318,8 @@ def datasource_from_chat_jsonl(
 def preprocess_chat_example_for_packing(
     batch,
     tokenizer: PreTrainedTokenizerBase,
-    should_append_eos: bool
+    should_append_eos: bool,
+    Pos: hax.Axis,
 ) -> dict:
     """Preprocesses chat examples into a cacheable format for packing."""
     input_ids_list = []
@@ -1336,8 +1337,16 @@ def preprocess_chat_example_for_packing(
         full_sequence = example["input"] + target
         full_ids = tokenizer(full_sequence, truncation=True)["input_ids"]
 
+        # Take last Pos.size tokens if sequence is too long
+        if len(full_ids) > Pos.size:
+            full_ids = full_ids[-Pos.size:]
+            # Adjust prompt length if it would now be longer than sequence
+            input_len = min(len(input_ids), len(full_ids))
+        else:
+            input_len = len(input_ids)
+
         input_ids_list.append(np.array(full_ids, dtype=np.int32))
-        prompt_lengths.append(len(input_ids))
+        prompt_lengths.append(input_len)
 
     # Return a dictionary of numpy arrays that can be cached
     return {
@@ -1416,7 +1425,12 @@ def mk_chat_sft_packed_dataset(
 
     # First process into cacheable format
     dataset = source.map_batches(
-        lambda ex: preprocess_chat_example_for_packing(ex, tokenizer, should_append_eos),
+        lambda ex: preprocess_chat_example_for_packing(
+            ex,
+            tokenizer,
+            should_append_eos,
+            Pos=Pos
+        ),
         batch_size=128,
         num_cpus=num_cpus_used_by_tokenizer(tokenizer),
         output_exemplar={