From 2c0512e1a321ca1ef8ba9550d34e73660367dfca Mon Sep 17 00:00:00 2001 From: bzantium Date: Tue, 16 Sep 2025 22:16:18 +0900 Subject: [PATCH 01/13] feat(input_pipeline): Add support for chunking long sequences instead of truncation --- src/MaxText/configs/base.yml | 1 + .../input_pipeline/_grain_data_processing.py | 17 +- .../input_pipeline/_grain_tokenizer.py | 61 +++++++ tests/tokenizer_transform_test.py | 160 ++++++++++++++++++ 4 files changed, 234 insertions(+), 5 deletions(-) create mode 100644 tests/tokenizer_transform_test.py diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index ad2575c257..2edc0dc360 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -458,6 +458,7 @@ tokenize_train_data: True # False if the dataset is pre-tokenized tokenize_eval_data: True # False if the dataset is pre-tokenized add_bos: True add_eos: True +use_truncation: True # If False, use chunking for long sequences instead of truncation # Dataset per_device_batch_size: 12.0 diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index b7adb6d93c..ae7d6e73e8 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -115,11 +115,18 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra pad_id = -1 if tokenize: - dataset = dataset.map( - _grain_tokenizer.TokenizeAndTrim( - data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model - ) - ) + if config.use_truncation: + dataset = dataset.map( + _grain_tokenizer.TokenizeAndTrim( + data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model + ) + ) + else: + dataset = dataset.apply( + _grain_tokenizer.TokenizeAndChunk( + data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model + ) + ) # Pack and Batch examples. if config.packing: diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py index 50155d8f7b..c7d7260c9d 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/MaxText/input_pipeline/_grain_tokenizer.py @@ -64,3 +64,64 @@ def __setstate__(self, state): self.__dict__.update(state) self._processor = None self._initialize_processor_lock = threading.Lock() + + +@dataclasses.dataclass +class TokenizeAndChunk(grain.experimental.FlatMapTransform): + """Tokenize and chunk features into multiple examples of sequence length.""" + + # pylint: disable=attribute-defined-outside-init + feature_names: str | Sequence[str] + sequence_length: int | Sequence[int] + add_bos: bool + add_eos: bool + tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer + max_fan_out: int = 2048 + + def __post_init__(self): + self._processor = None + self._initialize_processor_lock = threading.Lock() + if isinstance(self.feature_names, str): + self.feature_names = [self.feature_names] + if isinstance(self.sequence_length, int): + self.sequence_length = [self.sequence_length] * len(self.feature_names) + + def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]: + """Maps one element to a LIST of chunked elements.""" + if self._processor is None: + with self._initialize_processor_lock: + if self._processor is None: # Ensures only one thread initializes SPP. 
+ self._processor = self.tokenizer + + primary_feature_name = self.feature_names[0] + max_len = self.sequence_length[0] + text = element[primary_feature_name] + + token_ids = self._processor.encode(text) + + if not token_ids: + return [] + + output_elements = [] + + for i in range(0, len(token_ids), max_len): + chunk = token_ids[i : i + max_len] + chunk_np = np.asarray(chunk, dtype=np.int32) + new_element = element.copy() + + for feature_name in self.feature_names: + new_element[feature_name] = chunk_np + + output_elements.append(new_element) + return output_elements + + def __getstate__(self): + state = self.__dict__.copy() + del state["_processor"] + del state["_initialize_processor_lock"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._processor = None + self._initialize_processor_lock = threading.Lock() \ No newline at end of file diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py new file mode 100644 index 0000000000..3a995776c5 --- /dev/null +++ b/tests/tokenizer_transform_test.py @@ -0,0 +1,160 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Tests for tokenizer +""" + +import unittest + +import grain.python as grain +import numpy as np +from MaxText.input_pipeline import _grain_tokenizer +from MaxText.input_pipeline import _input_pipeline_utils +from numpy.testing import assert_array_equal + + +class MockTokenizer: + """ + Mocks a tokenizer by splitting on space and mapping letters to simple ints. + e.g., "a b c" -> [1, 2, 3] + """ + def encode(self, text: str) -> list[int]: + if not text: + return [] + # Simple 'a'=1, 'b'=2, ... 
mapping + return [ord(c) - ord('a') + 1 for c in text.split(' ')] + + +class TokenizerTransformTest(unittest.TestCase): + """Tests for chunking, trimming, and padding transformations.""" + + def setUp(self): + self.max_len = 5 + self.pad_length = 7 + self.pad_id = 0 + self.feature_names = ["inputs", "targets"] + self.mock_tokenizer = MockTokenizer() + self.source_data = [ + {"inputs": "a b c", "targets": "a b c"}, + {"inputs": "d e f g h i j", "targets": "d e f g h i j"}, + {"inputs": "", "targets": ""}, + {"inputs": "k l m n o p q r s t", "targets": "k l m n o p q r s t"} + ] + self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset() + + def test_tokenize_and_trim(self): + """Tests the 1:1 MapTransform (truncation) logic.""" + trim_op = _grain_tokenizer.TokenizeAndTrim( + feature_names=self.feature_names, + sequence_length=self.max_len, + add_bos=False, + add_eos=False, + tokenizer=self.mock_tokenizer + ) + trim_ds = self.base_ds.map(trim_op) + results = list(trim_ds) + self.assertEqual(len(results), len(self.source_data)) + expected_inputs = [ + np.array([1, 2, 3], dtype=np.int32), + np.array([4, 5, 6, 7, 8], dtype=np.int32), + np.array([], dtype=np.int32), + np.array([11, 12, 13, 14, 15], dtype=np.int32) + ] + result_inputs = [r["inputs"] for r in results] + self.assertEqual(len(result_inputs), len(expected_inputs)) + for res, exp in zip(result_inputs, expected_inputs): + assert_array_equal(res, exp) + + def test_tokenize_and_chunk(self): + """Tests the 1:N FlatMapTransform (chunking) logic.""" + chunk_op = _grain_tokenizer.TokenizeAndChunk( + feature_names=self.feature_names, + sequence_length=self.max_len, + add_bos=False, + add_eos=False, + tokenizer=self.mock_tokenizer + ) + chunk_ds = self.base_ds.apply(chunk_op) + results = list(chunk_ds) + self.assertEqual(len(results), 5) + expected_inputs = [ + np.array([1, 2, 3], dtype=np.int32), + np.array([4, 5, 6, 7, 8], dtype=np.int32), + np.array([9, 10], dtype=np.int32), + np.array([11, 12, 13, 14, 15], dtype=np.int32), + np.array([16, 17, 18, 19, 20], dtype=np.int32) + ] + result_inputs = [r["inputs"] for r in results] + self.assertEqual(len(result_inputs), len(expected_inputs)) + for res, exp in zip(result_inputs, expected_inputs): + assert_array_equal(res, exp) + + def test_trim_and_pad_chaining(self): + """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()""" + trim_op = _grain_tokenizer.TokenizeAndTrim( + feature_names=self.feature_names, + sequence_length=self.max_len, + add_bos=False, + add_eos=False, + tokenizer=self.mock_tokenizer + ) + pad_op = _input_pipeline_utils.PadToMaxLength( + max_length=self.pad_length, + pad_id=self.pad_id + ) + chained_ds = self.base_ds.map(trim_op).map(pad_op) + results = list(chained_ds) + self.assertEqual(len(results), len(self.source_data)) + expected_inputs = [ + np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32), + np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32), + np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32), + np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32) + ] + result_inputs = [r["inputs"] for r in results] + self.assertEqual(len(result_inputs), len(expected_inputs)) + for res, exp in zip(result_inputs, expected_inputs): + assert_array_equal(res, exp) + + def test_chunk_and_pad_chaining(self): + """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()""" + chunk_op = _grain_tokenizer.TokenizeAndChunk( + feature_names=self.feature_names, + sequence_length=self.max_len, + add_bos=False, + add_eos=False, + tokenizer=self.mock_tokenizer + ) + pad_op = 
_input_pipeline_utils.PadToMaxLength( + max_length=self.pad_length, + pad_id=self.pad_id + ) + chained_ds = self.base_ds.apply(chunk_op).map(pad_op) + results = list(chained_ds) + self.assertEqual(len(results), 5) + expected_inputs = [ + np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32), + np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32), + np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32), + np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32), + np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32), + ] + result_inputs = [r["inputs"] for r in results] + self.assertEqual(len(result_inputs), len(expected_inputs)) + for res, exp in zip(result_inputs, expected_inputs): + assert_array_equal(res, exp) + + +if __name__ == "__main__": + unittest.main() From 4138263627443572191aaa1bbe279969a5945252 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 18 Sep 2025 09:00:52 +0900 Subject: [PATCH 02/13] docs(config): Clarify use_truncation flag with implementation details --- src/MaxText/configs/base.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 2edc0dc360..2f65fcbe0a 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -458,7 +458,10 @@ tokenize_train_data: True # False if the dataset is pre-tokenized tokenize_eval_data: True # False if the dataset is pre-tokenized add_bos: True add_eos: True -use_truncation: True # If False, use chunking for long sequences instead of truncation +# If False, use chunking for long sequences instead of truncation. +# See the TokenizeAndTrim and TokenizeAndChunk classes in +# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details. +use_truncation: True # Dataset per_device_batch_size: 12.0 From b0b78c3324e1925289c9563453f2dfac8a4cb4d9 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 18 Sep 2025 09:01:05 +0900 Subject: [PATCH 03/13] refactor(input_pipeline): Decouple tokenization from rekeying and simplify transforms --- .../input_pipeline/_grain_data_processing.py | 10 +++-- .../input_pipeline/_grain_tokenizer.py | 39 +++++++------------ tests/tokenizer_transform_test.py | 26 ++++++------- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index ae7d6e73e8..16322ead4f 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -95,9 +95,9 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) assert len(data_columns) == 1 - rekey_dict = {"inputs": "text", "targets": "text"} - dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict)) + text_column = "text" data_columns = ("inputs", "targets") + rekey_dict = {col: text_column for col in data_columns} tokenizer_model = tokenizer.build_tokenizer( config.tokenizer_path, @@ -118,16 +118,18 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra if config.use_truncation: dataset = dataset.map( _grain_tokenizer.TokenizeAndTrim( - data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model + text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model ) ) else: dataset = dataset.apply( _grain_tokenizer.TokenizeAndChunk( - data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model 
+ text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model ) ) + dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict)) + # Pack and Batch examples. if config.packing: length_struct = {col: config.max_target_length for col in data_columns} diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py index c7d7260c9d..8e04786bba 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/MaxText/input_pipeline/_grain_tokenizer.py @@ -28,8 +28,8 @@ class TokenizeAndTrim(grain.MapTransform): """Tokenize and trim features to sequence length.""" # pylint: disable=attribute-defined-outside-init - feature_names: str | Sequence[str] - sequence_length: int | Sequence[int] + text_column: str + sequence_length: int add_bos: bool add_eos: bool tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer @@ -37,10 +37,6 @@ class TokenizeAndTrim(grain.MapTransform): def __post_init__(self): self._processor = None self._initialize_processor_lock = threading.Lock() - if isinstance(self.feature_names, str): - self.feature_names = [self.feature_names] - if isinstance(self.sequence_length, int): - self.sequence_length = [self.sequence_length] * len(self.feature_names) def map(self, element: dict[str, Any]) -> dict[str, Any]: """Maps to each element.""" @@ -48,10 +44,10 @@ def map(self, element: dict[str, Any]) -> dict[str, Any]: with self._initialize_processor_lock: if self._processor is None: # Ensures only one thread initializes SPP. self._processor = self.tokenizer - for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True): - text = element[feature_name] - token_ids = self._processor.encode(text)[:sequence_length] - element[feature_name] = np.asarray(token_ids, dtype=np.int32) + + text = element[self.text_column] + token_ids = self._processor.encode(text)[:self.sequence_length] + element[self.text_column] = np.asarray(token_ids, dtype=np.int32) return element def __getstate__(self): @@ -71,8 +67,8 @@ class TokenizeAndChunk(grain.experimental.FlatMapTransform): """Tokenize and chunk features into multiple examples of sequence length.""" # pylint: disable=attribute-defined-outside-init - feature_names: str | Sequence[str] - sequence_length: int | Sequence[int] + text_column: str + sequence_length: int add_bos: bool add_eos: bool tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer @@ -81,10 +77,6 @@ class TokenizeAndChunk(grain.experimental.FlatMapTransform): def __post_init__(self): self._processor = None self._initialize_processor_lock = threading.Lock() - if isinstance(self.feature_names, str): - self.feature_names = [self.feature_names] - if isinstance(self.sequence_length, int): - self.sequence_length = [self.sequence_length] * len(self.feature_names) def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]: """Maps one element to a LIST of chunked elements.""" @@ -93,25 +85,20 @@ def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]: if self._processor is None: # Ensures only one thread initializes SPP. 
self._processor = self.tokenizer - primary_feature_name = self.feature_names[0] - max_len = self.sequence_length[0] - text = element[primary_feature_name] + text = element.pop(self.text_column) + max_len = self.sequence_length token_ids = self._processor.encode(text) if not token_ids: return [] + token_ids = np.array(token_ids, dtype=np.int32) + output_elements = [] for i in range(0, len(token_ids), max_len): - chunk = token_ids[i : i + max_len] - chunk_np = np.asarray(chunk, dtype=np.int32) - new_element = element.copy() - - for feature_name in self.feature_names: - new_element[feature_name] = chunk_np - + new_element = {**element, self.text_column: token_ids[i : i + max_len]} output_elements.append(new_element) return output_elements diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py index 3a995776c5..7e98758c37 100644 --- a/tests/tokenizer_transform_test.py +++ b/tests/tokenizer_transform_test.py @@ -43,20 +43,20 @@ def setUp(self): self.max_len = 5 self.pad_length = 7 self.pad_id = 0 - self.feature_names = ["inputs", "targets"] + self.text_column = "text" self.mock_tokenizer = MockTokenizer() self.source_data = [ - {"inputs": "a b c", "targets": "a b c"}, - {"inputs": "d e f g h i j", "targets": "d e f g h i j"}, - {"inputs": "", "targets": ""}, - {"inputs": "k l m n o p q r s t", "targets": "k l m n o p q r s t"} + {"text": "a b c"}, + {"text": "d e f g h i j"}, + {"text": ""}, + {"text": "k l m n o p q r s t"} ] self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset() def test_tokenize_and_trim(self): """Tests the 1:1 MapTransform (truncation) logic.""" trim_op = _grain_tokenizer.TokenizeAndTrim( - feature_names=self.feature_names, + text_column=self.text_column, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -71,7 +71,7 @@ def test_tokenize_and_trim(self): np.array([], dtype=np.int32), np.array([11, 12, 13, 14, 15], dtype=np.int32) ] - result_inputs = [r["inputs"] for r in results] + result_inputs = [r["text"] for r in results] self.assertEqual(len(result_inputs), len(expected_inputs)) for res, exp in zip(result_inputs, expected_inputs): assert_array_equal(res, exp) @@ -79,7 +79,7 @@ def test_tokenize_and_trim(self): def test_tokenize_and_chunk(self): """Tests the 1:N FlatMapTransform (chunking) logic.""" chunk_op = _grain_tokenizer.TokenizeAndChunk( - feature_names=self.feature_names, + text_column=self.text_column, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -95,7 +95,7 @@ def test_tokenize_and_chunk(self): np.array([11, 12, 13, 14, 15], dtype=np.int32), np.array([16, 17, 18, 19, 20], dtype=np.int32) ] - result_inputs = [r["inputs"] for r in results] + result_inputs = [r["text"] for r in results] self.assertEqual(len(result_inputs), len(expected_inputs)) for res, exp in zip(result_inputs, expected_inputs): assert_array_equal(res, exp) @@ -103,7 +103,7 @@ def test_tokenize_and_chunk(self): def test_trim_and_pad_chaining(self): """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()""" trim_op = _grain_tokenizer.TokenizeAndTrim( - feature_names=self.feature_names, + text_column=self.text_column, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -122,7 +122,7 @@ def test_trim_and_pad_chaining(self): np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32), np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32) ] - result_inputs = [r["inputs"] for r in results] + result_inputs = [r["text"] for r in results] self.assertEqual(len(result_inputs), len(expected_inputs)) for res, exp in 
zip(result_inputs, expected_inputs): assert_array_equal(res, exp) @@ -130,7 +130,7 @@ def test_trim_and_pad_chaining(self): def test_chunk_and_pad_chaining(self): """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()""" chunk_op = _grain_tokenizer.TokenizeAndChunk( - feature_names=self.feature_names, + text_column=self.text_column, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -150,7 +150,7 @@ def test_chunk_and_pad_chaining(self): np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32), np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32), ] - result_inputs = [r["inputs"] for r in results] + result_inputs = [r["text"] for r in results] self.assertEqual(len(result_inputs), len(expected_inputs)) for res, exp in zip(result_inputs, expected_inputs): assert_array_equal(res, exp) From 755ad01d32c2c0cf3c0c204647c4efcc22b57f61 Mon Sep 17 00:00:00 2001 From: bzantium Date: Fri, 19 Sep 2025 11:35:50 +0900 Subject: [PATCH 04/13] Add comment explaining TokenizeAndChunk behavior - Added comment that TokenizeAndChunk removes all columns except the text_column --- src/MaxText/input_pipeline/_grain_data_processing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index 9cb80cf78d..1066748fe8 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -95,9 +95,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) assert len(data_columns) == 1 - text_column = "text" - data_columns = ("inputs", "targets") - rekey_dict = {col: text_column for col in data_columns} + text_column = data_columns[0] tokenizer_model = tokenizer.build_tokenizer( config.tokenizer_path, @@ -128,6 +126,8 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra ) ) + data_columns = ("inputs", "targets") + rekey_dict = {col: text_column for col in data_columns} dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict)) # Pack and Batch examples. 
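Note on the 1:N fan-out introduced above: TokenizeAndChunk keeps every token of a long record by emitting one element per max_target_length-sized slice, instead of discarding the tail the way TokenizeAndTrim does. A minimal sketch of that chunking arithmetic in plain Python (the token ids and length below are illustrative, not taken from the patch):

import numpy as np

token_ids = list(range(1, 13))  # pretend encode() returned 12 tokens
max_len = 5                     # stands in for config.max_target_length

# Same slicing loop as TokenizeAndChunk.flat_map: ceil(12 / 5) = 3 chunks.
chunks = [
    np.asarray(token_ids[i : i + max_len], dtype=np.int32)
    for i in range(0, len(token_ids), max_len)
]
# chunks -> [array([1, 2, 3, 4, 5]), array([6, 7, 8, 9, 10]), array([11, 12])]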
From 8334e3218c74e359fc6feebbd33e31f16bffb3ac Mon Sep 17 00:00:00 2001 From: bzantium Date: Fri, 19 Sep 2025 11:36:11 +0900 Subject: [PATCH 05/13] Update grain tokenizer implementation - Modified _grain_tokenizer.py with latest changes --- .../input_pipeline/_grain_tokenizer.py | 81 +++++++++---------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py index 8e04786bba..2398cd7698 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/MaxText/input_pipeline/_grain_tokenizer.py @@ -24,12 +24,10 @@ @dataclasses.dataclass -class TokenizeAndTrim(grain.MapTransform): - """Tokenize and trim features to sequence length.""" +class TokenizerTransformBase: + """Base class for tokenizer transforms with common functionality.""" # pylint: disable=attribute-defined-outside-init - text_column: str - sequence_length: int add_bos: bool add_eos: bool tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer @@ -38,17 +36,12 @@ def __post_init__(self): self._processor = None self._initialize_processor_lock = threading.Lock() - def map(self, element: dict[str, Any]) -> dict[str, Any]: - """Maps to each element.""" + def _get_processor(self): if self._processor is None: with self._initialize_processor_lock: if self._processor is None: # Ensures only one thread initializes SPP. self._processor = self.tokenizer - - text = element[self.text_column] - token_ids = self._processor.encode(text)[:self.sequence_length] - element[self.text_column] = np.asarray(token_ids, dtype=np.int32) - return element + return self._processor def __getstate__(self): state = self.__dict__.copy() @@ -63,52 +56,52 @@ def __setstate__(self, state): @dataclasses.dataclass -class TokenizeAndChunk(grain.experimental.FlatMapTransform): +class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform): + """Tokenize and trim features to sequence length.""" + # pylint: disable=attribute-defined-outside-init + feature_names: str | Sequence[str] + sequence_length: int | Sequence[int] + + def __post_init__(self): + super().__post_init__() + if isinstance(self.feature_names, str): + self.feature_names = [self.feature_names] + if isinstance(self.sequence_length, int): + self.sequence_length = [self.sequence_length] * len(self.feature_names) + + def map(self, element: dict[str, Any]) -> dict[str, Any]: + """Maps to each element.""" + processor = self._get_processor() + for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True): + text = element[feature_name] + token_ids = processor.encode(text)[:sequence_length] + element[feature_name] = np.asarray(token_ids, dtype=np.int32) + return element + + +@dataclasses.dataclass +class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform): """Tokenize and chunk features into multiple examples of sequence length.""" # pylint: disable=attribute-defined-outside-init - text_column: str + feature_name: str sequence_length: int - add_bos: bool - add_eos: bool - tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer max_fan_out: int = 2048 - def __post_init__(self): - self._processor = None - self._initialize_processor_lock = threading.Lock() - def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]: - """Maps one element to a LIST of chunked elements.""" - if self._processor is None: - with self._initialize_processor_lock: - if self._processor is None: # Ensures only one thread initializes 
SPP. - self._processor = self.tokenizer - - text = element.pop(self.text_column) + processor = self._get_processor() + text = element[self.feature_name] max_len = self.sequence_length - token_ids = self._processor.encode(text) + token_ids = processor.encode(text) if not token_ids: return [] - token_ids = np.array(token_ids, dtype=np.int32) - output_elements = [] - for i in range(0, len(token_ids), max_len): - new_element = {**element, self.text_column: token_ids[i : i + max_len]} + chunk = np.asarray(token_ids[i : i + max_len], dtype=np.int32) + new_element = {self.feature_name: chunk} output_elements.append(new_element) - return output_elements - - def __getstate__(self): - state = self.__dict__.copy() - del state["_processor"] - del state["_initialize_processor_lock"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._processor = None - self._initialize_processor_lock = threading.Lock() \ No newline at end of file + + return output_elements \ No newline at end of file From d4dfc192156e8c3b29249ae907d4cf435ff555d9 Mon Sep 17 00:00:00 2001 From: bzantium Date: Fri, 19 Sep 2025 11:36:27 +0900 Subject: [PATCH 06/13] Add documentation for use_truncation config - Added note that use_truncation=False is only available in grain's pretrain preprocessing pipeline --- src/MaxText/configs/base.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 1489221a78..853c6ec88d 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -459,6 +459,7 @@ tokenize_eval_data: True # False if the dataset is pre-tokenized add_bos: True add_eos: True # If False, use chunking for long sequences instead of truncation. +# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline. # See the TokenizeAndTrim and TokenizeAndChunk classes in # `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details. 
use_truncation: True From fb770c7d96a0a730067cb19b32113878035a2af3 Mon Sep 17 00:00:00 2001 From: bzantium Date: Sat, 20 Sep 2025 14:41:24 +0900 Subject: [PATCH 07/13] Refactor tokenizer classes to use common variables in base class - Move feature_names, sequence_length, add_bos, add_eos, and tokenizer to TokenizerTransformBase - Consolidate initialization logic in base class __post_init__ - Simplify TokenizeAndTrim and TokenizeAndChunk by removing duplicate parameters - Add common _encode method to eliminate code duplication - Maintain backward compatibility and specialized behavior for each class --- .../input_pipeline/_grain_tokenizer.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py index 2398cd7698..ed9b2d8073 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/MaxText/input_pipeline/_grain_tokenizer.py @@ -28,6 +28,8 @@ class TokenizerTransformBase: """Base class for tokenizer transforms with common functionality.""" # pylint: disable=attribute-defined-outside-init + feature_names: str | Sequence[str] + sequence_length: int | Sequence[int] add_bos: bool add_eos: bool tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer @@ -35,14 +37,24 @@ class TokenizerTransformBase: def __post_init__(self): self._processor = None self._initialize_processor_lock = threading.Lock() + # Convert single values to lists for consistent processing + if isinstance(self.feature_names, str): + self.feature_names = [self.feature_names] + if isinstance(self.sequence_length, int): + self.sequence_length = [self.sequence_length] * len(self.feature_names) def _get_processor(self): if self._processor is None: with self._initialize_processor_lock: - if self._processor is None: # Ensures only one thread initializes SPP. + if self._processor is None: # Ensures only one thread initializes processor. 
self._processor = self.tokenizer return self._processor + def _encode(self, text: str) -> list[int]: + """Common method to encode text using the tokenizer.""" + processor = self._get_processor() + return processor.encode(text) + def __getstate__(self): state = self.__dict__.copy() del state["_processor"] @@ -58,23 +70,15 @@ def __setstate__(self, state): @dataclasses.dataclass class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform): """Tokenize and trim features to sequence length.""" - # pylint: disable=attribute-defined-outside-init - feature_names: str | Sequence[str] - sequence_length: int | Sequence[int] def __post_init__(self): super().__post_init__() - if isinstance(self.feature_names, str): - self.feature_names = [self.feature_names] - if isinstance(self.sequence_length, int): - self.sequence_length = [self.sequence_length] * len(self.feature_names) def map(self, element: dict[str, Any]) -> dict[str, Any]: """Maps to each element.""" - processor = self._get_processor() - for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True): + for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True): text = element[feature_name] - token_ids = processor.encode(text)[:sequence_length] + token_ids = self._encode(text)[:max_length] element[feature_name] = np.asarray(token_ids, dtype=np.int32) return element @@ -83,24 +87,28 @@ def map(self, element: dict[str, Any]) -> dict[str, Any]: class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform): """Tokenize and chunk features into multiple examples of sequence length.""" - # pylint: disable=attribute-defined-outside-init - feature_name: str - sequence_length: int max_fan_out: int = 2048 + def __post_init__(self): + super().__post_init__() + # TokenizeAndChunk only supports single feature for chunking + assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name" + assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length" + self.feature_name = self.feature_names[0] # For backward compatibility + self.sequence_length = self.sequence_length[0] # Convert back to int for chunking + def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]: - processor = self._get_processor() text = element[self.feature_name] - max_len = self.sequence_length + chunk_size = self.sequence_length - token_ids = processor.encode(text) + token_ids = self._encode(text) if not token_ids: return [] output_elements = [] - for i in range(0, len(token_ids), max_len): - chunk = np.asarray(token_ids[i : i + max_len], dtype=np.int32) + for start_idx in range(0, len(token_ids), chunk_size): + chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32) new_element = {self.feature_name: chunk} output_elements.append(new_element) From 591376fff36749b3ec9aa016ed61f4c2eb2a467d Mon Sep 17 00:00:00 2001 From: bzantium Date: Tue, 23 Sep 2025 08:21:42 +0900 Subject: [PATCH 08/13] Update tokenizer transform test --- tests/tokenizer_transform_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py index 7e98758c37..550114144c 100644 --- a/tests/tokenizer_transform_test.py +++ b/tests/tokenizer_transform_test.py @@ -43,7 +43,7 @@ def setUp(self): self.max_len = 5 self.pad_length = 7 self.pad_id = 0 - self.text_column = "text" + self.feature_names = "text" self.mock_tokenizer = MockTokenizer() 
self.source_data = [ {"text": "a b c"}, @@ -56,7 +56,7 @@ def setUp(self): def test_tokenize_and_trim(self): """Tests the 1:1 MapTransform (truncation) logic.""" trim_op = _grain_tokenizer.TokenizeAndTrim( - text_column=self.text_column, + feature_names=self.feature_names, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -79,7 +79,7 @@ def test_tokenize_and_trim(self): def test_tokenize_and_chunk(self): """Tests the 1:N FlatMapTransform (chunking) logic.""" chunk_op = _grain_tokenizer.TokenizeAndChunk( - text_column=self.text_column, + feature_names=self.feature_names, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -103,7 +103,7 @@ def test_tokenize_and_chunk(self): def test_trim_and_pad_chaining(self): """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()""" trim_op = _grain_tokenizer.TokenizeAndTrim( - text_column=self.text_column, + feature_names=self.feature_names, sequence_length=self.max_len, add_bos=False, add_eos=False, @@ -130,7 +130,7 @@ def test_trim_and_pad_chaining(self): def test_chunk_and_pad_chaining(self): """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()""" chunk_op = _grain_tokenizer.TokenizeAndChunk( - text_column=self.text_column, + feature_names=self.feature_names, sequence_length=self.max_len, add_bos=False, add_eos=False, From 374fe9542b788707f78bfd2503d6e7166c54d6e5 Mon Sep 17 00:00:00 2001 From: bzantium Date: Tue, 23 Sep 2025 22:36:22 +0900 Subject: [PATCH 09/13] merge changes from latest commmits --- tests/tokenizer_transform_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py index 550114144c..919778e68f 100644 --- a/tests/tokenizer_transform_test.py +++ b/tests/tokenizer_transform_test.py @@ -101,7 +101,7 @@ def test_tokenize_and_chunk(self): assert_array_equal(res, exp) def test_trim_and_pad_chaining(self): - """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()""" + """Tests chaining TokenizeAndTrim.map() -> PadOrTrimToMaxLength.map()""" trim_op = _grain_tokenizer.TokenizeAndTrim( feature_names=self.feature_names, sequence_length=self.max_len, @@ -109,7 +109,7 @@ def test_trim_and_pad_chaining(self): add_eos=False, tokenizer=self.mock_tokenizer ) - pad_op = _input_pipeline_utils.PadToMaxLength( + pad_op = _input_pipeline_utils.PadOrTrimToMaxLength( max_length=self.pad_length, pad_id=self.pad_id ) @@ -128,7 +128,7 @@ def test_trim_and_pad_chaining(self): assert_array_equal(res, exp) def test_chunk_and_pad_chaining(self): - """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()""" + """Tests chaining TokenizeAndChunk.apply() -> PadOrTrimToMaxLength.map()""" chunk_op = _grain_tokenizer.TokenizeAndChunk( feature_names=self.feature_names, sequence_length=self.max_len, @@ -136,7 +136,7 @@ def test_chunk_and_pad_chaining(self): add_eos=False, tokenizer=self.mock_tokenizer ) - pad_op = _input_pipeline_utils.PadToMaxLength( + pad_op = _input_pipeline_utils.PadOrTrimToMaxLength( max_length=self.pad_length, pad_id=self.pad_id ) From 6176a7a7c567dd576bb9071df621ce8ba9d63361 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 25 Sep 2025 08:21:56 +0900 Subject: [PATCH 10/13] update grain version --- requirements_with_jax_ai_image.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index d2fedb45e1..c51947435f 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -4,7 
+4,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72 flax>=0.11.0 google-api-python-client google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip -grain[parquet]>=0.2.12 +grain[parquet] @ https://github.com/google/grain/archive/bab58eabf0c94d16002b5d82dda8d8320edd3c7b.zip jaxtyping jsonlines mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip From 570e36dd18413dd1c7b75db5d488c7dbf53c66c4 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 2 Oct 2025 10:53:54 +0900 Subject: [PATCH 11/13] use pypi grain version to use pre-compiled wheel and remove add_bos/eos since they are used at tokenizer itself not tokenizer trasform --- requirements_with_jax_ai_image.txt | 2 +- .../input_pipeline/_grain_data_processing.py | 19 ++++++++++++------- .../input_pipeline/_grain_tokenizer.py | 2 -- tests/tokenizer_transform_test.py | 4 ---- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 80356ce47e..8e9709d4fd 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -4,7 +4,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72 flax>=0.11.0 google-api-python-client google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip -grain[parquet] @ https://github.com/google/grain/archive/bab58eabf0c94d16002b5d82dda8d8320edd3c7b.zip +grain[parquet] jaxtyping jsonlines mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index 0a3687b692..b24e78896a 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -115,15 +115,20 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra if tokenize: if config.use_truncation: dataset = dataset.map( - _grain_tokenizer.TokenizeAndTrim( - text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model - ) + _grain_tokenizer.TokenizeAndTrim( + text_column, config.max_target_length, tokenizer_model + ) ) else: - dataset = dataset.apply( + dataset = grain.experimental.WithOptionsIterDataset( + dataset, + options=grain.experimental.DatasetOptions() + ) + dataset = grain.experimental.apply_transformations( + dataset, _grain_tokenizer.TokenizeAndChunk( - text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model - ) + text_column, config.max_target_length, tokenizer_model + ) ) data_columns = ("inputs", "targets") @@ -183,7 +188,7 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo if tokenize: dataset = dataset.map( _grain_tokenizer.TokenizeAndTrim( - data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model + data_columns, config.max_target_length, tokenizer_model ) ) diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/MaxText/input_pipeline/_grain_tokenizer.py index ed9b2d8073..6e9177a945 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/MaxText/input_pipeline/_grain_tokenizer.py @@ -30,8 +30,6 @@ class TokenizerTransformBase: # pylint: disable=attribute-defined-outside-init feature_names: str | Sequence[str] 
sequence_length: int | Sequence[int] - add_bos: bool - add_eos: bool tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer def __post_init__(self): diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py index 919778e68f..1d30270801 100644 --- a/tests/tokenizer_transform_test.py +++ b/tests/tokenizer_transform_test.py @@ -58,8 +58,6 @@ def test_tokenize_and_trim(self): trim_op = _grain_tokenizer.TokenizeAndTrim( feature_names=self.feature_names, sequence_length=self.max_len, - add_bos=False, - add_eos=False, tokenizer=self.mock_tokenizer ) trim_ds = self.base_ds.map(trim_op) @@ -105,8 +103,6 @@ def test_trim_and_pad_chaining(self): trim_op = _grain_tokenizer.TokenizeAndTrim( feature_names=self.feature_names, sequence_length=self.max_len, - add_bos=False, - add_eos=False, tokenizer=self.mock_tokenizer ) pad_op = _input_pipeline_utils.PadOrTrimToMaxLength( From b2bc5d3447c320eea93336341628b15ad23c8dba Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 2 Oct 2025 11:11:41 +0900 Subject: [PATCH 12/13] remove add_bos/eos from tokenizer transform --- tests/tokenizer_transform_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/tokenizer_transform_test.py b/tests/tokenizer_transform_test.py index 1d30270801..965c0ffbb2 100644 --- a/tests/tokenizer_transform_test.py +++ b/tests/tokenizer_transform_test.py @@ -79,8 +79,6 @@ def test_tokenize_and_chunk(self): chunk_op = _grain_tokenizer.TokenizeAndChunk( feature_names=self.feature_names, sequence_length=self.max_len, - add_bos=False, - add_eos=False, tokenizer=self.mock_tokenizer ) chunk_ds = self.base_ds.apply(chunk_op) @@ -128,8 +126,6 @@ def test_chunk_and_pad_chaining(self): chunk_op = _grain_tokenizer.TokenizeAndChunk( feature_names=self.feature_names, sequence_length=self.max_len, - add_bos=False, - add_eos=False, tokenizer=self.mock_tokenizer ) pad_op = _input_pipeline_utils.PadOrTrimToMaxLength( From cc90eb0c2d6e1a2bc60a9f3e6b20d4dc7445bbf2 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 2 Oct 2025 22:28:53 +0900 Subject: [PATCH 13/13] make as it is --- requirements_with_jax_ai_image.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 8e9709d4fd..53e62512fe 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -4,7 +4,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72 flax>=0.11.0 google-api-python-client google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip -grain[parquet] +grain[parquet]>=0.2.12 jaxtyping jsonlines mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
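For reference, a minimal end-to-end sketch of the chunking path as it stands at the end of this series, written in the style of tests/tokenizer_transform_test.py; the MockTokenizer and the sample record are illustrative stand-ins, and a real pipeline would instead set use_truncation: False and pass the SentencePiece/HF tokenizer built in _grain_data_processing.py:

import grain.python as grain

from MaxText.input_pipeline import _grain_tokenizer


class MockTokenizer:
  """Toy tokenizer from the tests: "d e f" -> [4, 5, 6]."""

  def encode(self, text: str) -> list[int]:
    if not text:
      return []
    return [ord(c) - ord("a") + 1 for c in text.split(" ")]


ds = grain.MapDataset.source([{"text": "d e f g h i j"}]).to_iter_dataset()

chunk_op = _grain_tokenizer.TokenizeAndChunk(
    feature_names="text",   # single text column; Rekey later maps it to inputs/targets
    sequence_length=5,      # stands in for config.max_target_length
    tokenizer=MockTokenizer(),
)

# FlatMapTransform fan-out: the 7-token record becomes two examples instead of being truncated.
for example in ds.apply(chunk_op):
  print(example["text"])    # [4 5 6 7 8], then [9 10]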