diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt
index 53e62512fe..8e9709d4fd 100644
--- a/requirements_with_jax_ai_image.txt
+++ b/requirements_with_jax_ai_image.txt
@@ -4,7 +4,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72
 flax>=0.11.0
 google-api-python-client
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
-grain[parquet]>=0.2.12
+grain[parquet]
 jaxtyping
 jsonlines
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 8aba5a988e..70fcc600f3 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -477,6 +477,11 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
 tokenize_eval_data: True # False if the dataset is pre-tokenized
 add_bos: True
 add_eos: True
+# If False, use chunking for long sequences instead of truncation.
+# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline.
+# See the TokenizeAndTrim and TokenizeAndChunk classes in
+# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details.
+use_truncation: True
 
 # Dataset
 per_device_batch_size: 12.0
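For review context, a minimal sketch of the behavior the new `use_truncation` flag selects (not part of the patch; the 12-token document and max_target_length=5 are hypothetical values):

# Illustration only: plain-Python stand-ins for the TokenizeAndTrim / TokenizeAndChunk paths wired up below.
token_ids = list(range(12))  # hypothetical tokenized document
max_target_length = 5

# use_truncation=True (existing behavior): keep the first max_target_length tokens, drop the rest.
trimmed = token_ids[:max_target_length]
# -> [0, 1, 2, 3, 4]

# use_truncation=False (new behavior): split the document into one training example per chunk.
chunks = [token_ids[i : i + max_target_length] for i in range(0, len(token_ids), max_target_length)]
# -> [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]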
diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py
index 1749fd8a54..b24e78896a 100644
--- a/src/MaxText/input_pipeline/_grain_data_processing.py
+++ b/src/MaxText/input_pipeline/_grain_data_processing.py
@@ -95,9 +95,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
   dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
 
   assert len(data_columns) == 1
-  rekey_dict = {"inputs": "text", "targets": "text"}
-  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
-  data_columns = ("inputs", "targets")
+  text_column = data_columns[0]
 
   tokenizer_model = tokenizer.build_tokenizer(
       config.tokenizer_path,
@@ -115,11 +113,28 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
     pad_id = -1
 
   if tokenize:
-    dataset = dataset.map(
+    if config.use_truncation:
+      dataset = dataset.map(
         _grain_tokenizer.TokenizeAndTrim(
-            data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
+            text_column, config.max_target_length, tokenizer_model
         )
-    )
+      )
+    else:
+      dataset = grain.experimental.WithOptionsIterDataset(
+          dataset,
+          options=grain.experimental.DatasetOptions()
+      )
+      dataset = grain.experimental.apply_transformations(
+          dataset,
+          _grain_tokenizer.TokenizeAndChunk(
+              text_column, config.max_target_length, tokenizer_model
+          )
+      )
+
+  data_columns = ("inputs", "targets")
+  rekey_dict = {col: text_column for col in data_columns}
+  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
+
   # Pack and Batch examples.
   batch_size = config.global_batch_size_to_load // jax.process_count()
   if config.packing:
@@ -173,7 +188,7 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
   if tokenize:
     dataset = dataset.map(
         _grain_tokenizer.TokenizeAndTrim(
-            data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
+            data_columns, config.max_target_length, tokenizer_model
        )
    )
 
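A small illustration of the relocated Rekey step above (not part of the patch): after tokenization, both model columns are derived from the single pretraining text column. The dict comprehension below is a stand-in for `_input_pipeline_utils.Rekey`; its exact semantics (e.g. whether the original column is dropped) are assumed, not taken from this diff.

# Illustration only: rekey_dict maps {new_column: existing_column}, as built in the pipeline above.
element = {"text": [4, 5, 6, 7, 8]}                 # hypothetical already-tokenized element
rekey_dict = {"inputs": "text", "targets": "text"}
rekeyed = {new_key: element[old_key] for new_key, old_key in rekey_dict.items()}
# -> {"inputs": [4, 5, 6, 7, 8], "targets": [4, 5, 6, 7, 8]}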
diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py
index 50155d8f7b..6e9177a945 100644
--- a/src/MaxText/input_pipeline/_grain_tokenizer.py
+++ b/src/MaxText/input_pipeline/_grain_tokenizer.py
@@ -24,35 +24,34 @@
 
 
 @dataclasses.dataclass
-class TokenizeAndTrim(grain.MapTransform):
-  """Tokenize and trim features to sequence length."""
+class TokenizerTransformBase:
+  """Base class for tokenizer transforms with common functionality."""
 
   # pylint: disable=attribute-defined-outside-init
   feature_names: str | Sequence[str]
   sequence_length: int | Sequence[int]
-  add_bos: bool
-  add_eos: bool
   tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer
 
   def __post_init__(self):
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+    # Convert single values to lists for consistent processing
     if isinstance(self.feature_names, str):
       self.feature_names = [self.feature_names]
     if isinstance(self.sequence_length, int):
       self.sequence_length = [self.sequence_length] * len(self.feature_names)
 
-  def map(self, element: dict[str, Any]) -> dict[str, Any]:
-    """Maps to each element."""
+  def _get_processor(self):
     if self._processor is None:
       with self._initialize_processor_lock:
-        if self._processor is None:  # Ensures only one thread initializes SPP.
+        if self._processor is None:  # Ensures only one thread initializes processor.
           self._processor = self.tokenizer
-    for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
-      text = element[feature_name]
-      token_ids = self._processor.encode(text)[:sequence_length]
-      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
-    return element
+    return self._processor
+
+  def _encode(self, text: str) -> list[int]:
+    """Common method to encode text using the tokenizer."""
+    processor = self._get_processor()
+    return processor.encode(text)
 
   def __getstate__(self):
     state = self.__dict__.copy()
@@ -64,3 +63,51 @@ def __setstate__(self, state):
     self.__dict__.update(state)
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+
+
+@dataclasses.dataclass
+class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
+  """Tokenize and trim features to sequence length."""
+
+  def __post_init__(self):
+    super().__post_init__()
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Maps to each element."""
+    for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True):
+      text = element[feature_name]
+      token_ids = self._encode(text)[:max_length]
+      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
+    return element
+
+
+@dataclasses.dataclass
+class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
+  """Tokenize and chunk features into multiple examples of sequence length."""
+
+  max_fan_out: int = 2048
+
+  def __post_init__(self):
+    super().__post_init__()
+    # TokenizeAndChunk only supports single feature for chunking
+    assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name"
+    assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length"
+    self.feature_name = self.feature_names[0]  # For backward compatibility
+    self.sequence_length = self.sequence_length[0]  # Convert back to int for chunking
+
+  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
+    text = element[self.feature_name]
+    chunk_size = self.sequence_length
+
+    token_ids = self._encode(text)
+
+    if not token_ids:
+      return []
+
+    output_elements = []
+    for start_idx in range(0, len(token_ids), chunk_size):
+      chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32)
+      new_element = {self.feature_name: chunk}
+      output_elements.append(new_element)
+
+    return output_elements
\ No newline at end of file
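One note on the new `max_fan_out` field above (not part of the patch): it is assumed here to be grain's FlatMapTransform cap on how many elements a single `flat_map` call may emit, so it bounds the longest document that can be fully chunked. A rough back-of-the-envelope check with a hypothetical config value:

# Illustration only, assuming max_fan_out caps the number of chunks produced per input element.
max_fan_out = 2048          # default on TokenizeAndChunk
max_target_length = 8192    # hypothetical value of config.max_target_length
print(max_fan_out * max_target_length)  # 16777216 tokens per source document before the cap is reached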
diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py
new file mode 100644
index 0000000000..965c0ffbb2
--- /dev/null
+++ b/tests/tokenizer_transform_test.py
@@ -0,0 +1,152 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Tests for tokenizer
+"""
+
+import unittest
+
+import grain.python as grain
+import numpy as np
+from MaxText.input_pipeline import _grain_tokenizer
+from MaxText.input_pipeline import _input_pipeline_utils
+from numpy.testing import assert_array_equal
+
+
+class MockTokenizer:
+  """
+  Mocks a tokenizer by splitting on space and mapping letters to simple ints.
+  e.g., "a b c" -> [1, 2, 3]
+  """
+  def encode(self, text: str) -> list[int]:
+    if not text:
+      return []
+    # Simple 'a'=1, 'b'=2, ... mapping
+    return [ord(c) - ord('a') + 1 for c in text.split(' ')]
+
+
+class TokenizerTransformTest(unittest.TestCase):
+  """Tests for chunking, trimming, and padding transformations."""
+
+  def setUp(self):
+    self.max_len = 5
+    self.pad_length = 7
+    self.pad_id = 0
+    self.feature_names = "text"
+    self.mock_tokenizer = MockTokenizer()
+    self.source_data = [
+        {"text": "a b c"},
+        {"text": "d e f g h i j"},
+        {"text": ""},
+        {"text": "k l m n o p q r s t"}
+    ]
+    self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset()
+
+  def test_tokenize_and_trim(self):
+    """Tests the 1:1 MapTransform (truncation) logic."""
+    trim_op = _grain_tokenizer.TokenizeAndTrim(
+        feature_names=self.feature_names,
+        sequence_length=self.max_len,
+        tokenizer=self.mock_tokenizer
+    )
+    trim_ds = self.base_ds.map(trim_op)
+    results = list(trim_ds)
+    self.assertEqual(len(results), len(self.source_data))
+    expected_inputs = [
+        np.array([1, 2, 3], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8], dtype=np.int32),
+        np.array([], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15], dtype=np.int32)
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_tokenize_and_chunk(self):
+    """Tests the 1:N FlatMapTransform (chunking) logic."""
+    chunk_op = _grain_tokenizer.TokenizeAndChunk(
+        feature_names=self.feature_names,
+        sequence_length=self.max_len,
+        tokenizer=self.mock_tokenizer
+    )
+    chunk_ds = self.base_ds.apply(chunk_op)
+    results = list(chunk_ds)
+    self.assertEqual(len(results), 5)
+    expected_inputs = [
+        np.array([1, 2, 3], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8], dtype=np.int32),
+        np.array([9, 10], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15], dtype=np.int32),
+        np.array([16, 17, 18, 19, 20], dtype=np.int32)
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_trim_and_pad_chaining(self):
+    """Tests chaining TokenizeAndTrim.map() -> PadOrTrimToMaxLength.map()"""
+    trim_op = _grain_tokenizer.TokenizeAndTrim(
+        feature_names=self.feature_names,
+        sequence_length=self.max_len,
+        tokenizer=self.mock_tokenizer
+    )
+    pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(
+        max_length=self.pad_length,
+        pad_id=self.pad_id
+    )
+    chained_ds = self.base_ds.map(trim_op).map(pad_op)
+    results = list(chained_ds)
+    self.assertEqual(len(results), len(self.source_data))
+    expected_inputs = [
+        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
+        np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32)
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_chunk_and_pad_chaining(self):
+    """Tests chaining TokenizeAndChunk.apply() -> PadOrTrimToMaxLength.map()"""
+    chunk_op = _grain_tokenizer.TokenizeAndChunk(
+        feature_names=self.feature_names,
+        sequence_length=self.max_len,
+        tokenizer=self.mock_tokenizer
+    )
+    pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(
+        max_length=self.pad_length,
+        pad_id=self.pad_id
+    )
+    chained_ds = self.base_ds.apply(chunk_op).map(pad_op)
+    results = list(chained_ds)
+    self.assertEqual(len(results), 5)
+    expected_inputs = [
+        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
+        np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
+        np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32),
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+
+if __name__ == "__main__":
+  unittest.main()