Skip to content

Commit 387815d

Browse files
committed
Add merge datasets tutorial
Signed-off-by: asolergi-nv <asolergibert@nvidia.com>
1 parent 9a2c972 commit 387815d

2 files changed

Lines changed: 299 additions & 0 deletions

File tree

tutorials/text/megatron-tokenizer/README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,25 @@ Following the header, the file contains **20 bytes per document** structured as
5252
- **8 bytes**: The document index
5353

5454
For more details about Megatron's DataLoading solution and tokenization pipeline refer to [`megatron.core.datasets`](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets) README.
55+
56+
## Merge multiple file prefixes
57+
58+
The `MegatronTokenizerWriter` will generate one file prefix per worker. For large-scale training runs with Megatron-LM, it is recommended to consolidate multiple file prefixes into larger files (typically 5-100GB each) to reduce the number of file accesses during training.
59+
60+
To address this, we provide the `merge_datasets.py` script in this tutorial directory. This script is a simplified version of the [`tools/merge_datasets.py`](https://github.com/NVIDIA/Megatron-LM/blob/main/tools/merge_datasets.py) script found in the Megatron-LM repository, with the key advantage that it does not require any Megatron dependencies.
61+
62+
### Usage
63+
64+
The script takes a directory containing multiple file prefixes and merges them into a single output prefix:
65+
66+
```bash
67+
python tutorials/text/megatron-tokenizer/merge_datasets.py \
68+
--input-dir /path/to/tokenized/files \
69+
--output-prefix /path/to/output/merged
70+
```
71+
72+
**Arguments:**
73+
- `--input-dir`: Path to the directory containing all the tokenized file prefixes (`.bin` and `.idx` pairs) to merge
74+
- `--output-prefix`: Path and filename prefix for the merged output (will create `<output-prefix>.bin` and `<output-prefix>.idx`)
75+
76+
The script automatically detects all valid file prefix pairs in the input directory and merges them in sorted order, producing a single consolidated file prefix ready for use with Megatron-LM.
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""
2+
Simplified version of the tools/merge_datasets.py script from the Megatron-LM library.
3+
"""
4+
5+
import argparse
6+
import gc
7+
import os
8+
import shutil
9+
import struct
10+
from collections.abc import Iterable
11+
from types import TracebackType
12+
13+
import numpy as np
14+
15+
_INDEX_HEADER = b"MMIDIDX\x00\x00"


def extract_index_contents(idx_path: str) -> tuple[np.ndarray, np.ndarray, type[np.number]]:
    """Extract the index contents from the index file

    Args:
        idx_path (str): The path to the index file

    Returns:
        Tuple[np.ndarray, np.ndarray, Type[np.number]]: The sequence lengths, document indices
        and dtype of the index file
    """
    with open(idx_path, "rb") as stream:
        header = stream.read(9)
        assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}"  # noqa: S101

        version = struct.unpack("<Q", stream.read(8))[0]
        assert version == 1, f"bad version, cannot read: {idx_path}"  # noqa: S101

        # Numeric dtype code written by _IndexWriter: 4 -> int32, anything else -> uint16
        code = struct.unpack("<B", stream.read(1))[0]
        dtype = np.int32 if code == 4 else np.uint16  # noqa: PLR2004

        sequence_count = struct.unpack("<Q", stream.read(8))[0]
        document_count = struct.unpack("<Q", stream.read(8))[0]

        # Read the payload from the already-open stream instead of memory-mapping the
        # whole file: the previous memmap-based approach kept the file descriptor and
        # mapping alive through the returned array views until garbage collection.
        sequence_lengths = np.frombuffer(
            stream.read(sequence_count * np.dtype(np.int32).itemsize), dtype=np.int32
        )

        # The sequence pointers are not needed for merging (they are recomputed from the
        # sequence lengths when the merged index is written), so just skip over them.
        stream.seek(sequence_count * np.dtype(np.int64).itemsize, os.SEEK_CUR)

        document_indices = np.frombuffer(
            stream.read(document_count * np.dtype(np.int64).itemsize), dtype=np.int64
        )

    return sequence_lengths, document_indices, dtype
62+
63+
64+
class _IndexWriter:
    """Simplified version of the _IndexWriter class from the Megatron-LM library.

    Object class to write the index (.idx) file

    Args:
        idx_path (str): The path to the index file

        dtype (Type[np.number]): The dtype of the index file
    """

    def __init__(self, idx_path: str, dtype: type[np.number]) -> None:
        self.idx_path = idx_path
        self.dtype = dtype

    def __enter__(self) -> "_IndexWriter":
        """Enter the context introduced by the 'with' keyword

        Returns:
            _IndexWriter: The instance
        """
        # Closed in __exit__
        self.idx_writer = open(self.idx_path, "wb")  # noqa: SIM115
        # fixed 9-byte magic identifying the file format
        self.idx_writer.write(_INDEX_HEADER)
        # fixed format version (always 1)
        self.idx_writer.write(struct.pack("<Q", 1))
        # the numeric code for the dtype (4 -> int32, 8 -> uint16)
        self.idx_writer.write(struct.pack("<B", 4 if self.dtype == np.int32 else 8))
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit the context introduced by the 'with' keyword

        Args:
            exc_type (Optional[Type[BaseException]]): Exception type

            exc_val (Optional[BaseException]): Exception value

            exc_tb (Optional[TracebackType]): Exception traceback object

        Returns:
            Optional[bool]: Whether to silence the exception
        """
        self.idx_writer.close()
        return None

    def write(
        self,
        sequence_lengths: Iterable[int | np.integer],
        document_indices: Iterable[int | np.integer],
    ) -> None:
        """Write the index (.idx) file

        Args:
            sequence_lengths (Iterable[int]): The length of each sequence

            document_indices (Iterable[int]): The sequence indices demarcating the end of each document
        """
        # Materialize the iterables up front: both are consumed more than once below
        # (pointer computation, len(), and serialization), which would exhaust a generator.
        sequence_lengths = list(sequence_lengths)
        document_indices = list(document_indices)

        sequence_pointers = self._sequence_pointers(sequence_lengths)

        # the number of sequences in the dataset
        sequence_count = len(sequence_lengths)
        self.idx_writer.write(struct.pack("<Q", sequence_count))

        # the number of documents in the dataset
        document_count = len(document_indices)
        self.idx_writer.write(struct.pack("<Q", document_count))

        # the number of tokens per sequence
        self.idx_writer.write(np.array(sequence_lengths, dtype=np.int32).tobytes(order="C"))

        # the byte offsets for all sequences
        self.idx_writer.write(np.array(sequence_pointers, dtype=np.int64).tobytes(order="C"))

        # the sequence indices marking the end of each document
        self.idx_writer.write(np.array(document_indices, dtype=np.int64).tobytes(order="C"))

    def _sequence_pointers(self, sequence_lengths: Iterable[int | np.integer]) -> list[int]:
        """Build the sequence pointers per the sequence lengths and dtype size

        Args:
            sequence_lengths (Iterable[int]): The length of each sequence

        Returns:
            List[int]: The pointer to the beginning of each sequence
        """
        # Bytes per token for the dtype codes this script supports (int32 / uint16)
        itemsize = np.int64(4 if self.dtype == np.int32 else 2)
        curr_ptr = np.int64(0)
        list_ptr = []
        for length in sequence_lengths:
            list_ptr.append(curr_ptr.item())
            curr_ptr += length * itemsize
        return list_ptr
162+
163+
164+
class IndexedDatasetBuilder:
    """Simplified version of the IndexedDatasetBuilder class from the Megatron-LM library.

    Builder class for the IndexedDataset class

    Args:
        bin_path (str): The path to the data (.bin) file

        dtype (Type[np.number]): The dtype of the index file. All merged datasets must
            share this dtype.
    """

    def __init__(self, bin_path: str, dtype: type[np.number]) -> None:
        # Kept open across add_index calls so data files can be appended; closed in finalize()
        self.data_file = open(bin_path, "wb")  # noqa: SIM115
        self.dtype = dtype

        # Per-sequence token counts accumulated over all merged datasets
        self.sequence_lengths = []
        # Sequence indices marking document boundaries; always starts at 0
        self.document_indices = [0]

    def add_index(self, path_prefix: str) -> None:
        """Add an entire IndexedDataset to the dataset

        Args:
            path_prefix (str): The index (.idx) and data (.bin) prefix
        """
        # Concatenate index
        sequence_lengths, document_indices, dtype = extract_index_contents(path_prefix + ".idx")
        assert dtype == self.dtype  # noqa: S101

        # Shift the incoming document indices by the number of sequences already merged;
        # drop the leading 0 since self.document_indices already carries the start marker
        offset = len(self.sequence_lengths)
        self.sequence_lengths.extend(sequence_lengths)
        self.document_indices.extend((offset + document_indices)[1:])

        # Free up memory to make space for new indices
        del sequence_lengths, document_indices
        gc.collect()

        # Concatenate data
        with open(path_prefix + ".bin", "rb") as f:
            shutil.copyfileobj(f, self.data_file)

    def finalize(self, idx_path: str) -> None:
        """Clean up and write the index (.idx) file

        Args:
            idx_path (str): The path to the index file
        """
        self.data_file.close()
        with _IndexWriter(idx_path, self.dtype) as writer:
            writer.write(self.sequence_lengths, self.document_indices)
214+
215+
216+
def get_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace: Parsed arguments with ``input_dir`` and ``output_prefix``
    """
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title="input data")
    group.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Path to directory containing all document files to merge",
    )

    group = parser.add_argument_group(title="output data")
    group.add_argument(
        "--output-prefix",
        type=str,
        required=True,
        help="Path to merged output file prefix",
    )

    args = parser.parse_args()

    assert os.path.isdir(args.input_dir), f"ERROR: {args.input_dir} is not a directory or does not exist"  # noqa: S101

    # A bare prefix such as "merged" has an empty dirname; treat it as the current
    # working directory instead of failing the existence check on "".
    output_dir = os.path.dirname(args.output_prefix) or "."
    assert os.path.isdir(output_dir), (  # noqa: S101
        f"ERROR: {output_dir} is not a directory or does not exist"
    )

    return args
244+
245+
246+
def main(input_dir: str, output_prefix: str) -> None:
    """Merge every (.bin, .idx) file prefix pair in ``input_dir`` into a single pair.

    Args:
        input_dir (str): Directory containing the tokenized file prefixes to merge

        output_prefix (str): Path prefix for the merged ``<output_prefix>.bin`` and
            ``<output_prefix>.idx`` output files
    """
    prefixes = set()
    for basename in os.listdir(input_dir):
        prefix, ext = os.path.splitext(basename)

        # Ignore unrelated files (logs, notes, ...) that may live in the same directory;
        # only .bin/.idx files participate in the merge.
        if ext not in (".bin", ".idx"):
            continue

        if prefix in prefixes:
            continue

        if not os.path.isfile(os.path.join(input_dir, basename)):
            continue

        # Every prefix must come as a complete (.bin, .idx) pair
        ext_pair = ".bin" if ext == ".idx" else ".idx"
        assert os.path.isfile(os.path.join(input_dir, prefix) + ext_pair), (  # noqa: S101
            f"ERROR: {ext_pair} file not provided for {os.path.join(input_dir, prefix)}"
        )

        prefixes.add(prefix)

    # Without this guard an empty input directory would crash below with an
    # AttributeError on builder.finalize (builder would still be None)
    assert prefixes, f"ERROR: no (.bin, .idx) file prefix pairs found in {input_dir}"  # noqa: S101

    builder = None
    for prefix in sorted(prefixes):
        if builder is None:
            # The first (sorted) dataset determines the token dtype of the merged output
            _, _, dtype = extract_index_contents(os.path.join(input_dir, prefix + ".idx"))
            builder = IndexedDatasetBuilder(output_prefix + ".bin", dtype=dtype)

        builder.add_index(os.path.join(input_dir, prefix))

    builder.finalize(output_prefix + ".idx")
273+
274+
275+
if __name__ == "__main__":
    # Parse CLI arguments, then run the merge with them.
    cli_args = get_args()
    main(input_dir=cli_args.input_dir, output_prefix=cli_args.output_prefix)

0 commit comments

Comments
 (0)