
Commit 5c242b4

Add feature to slice, subsample and split dataset (#161)

1 parent b51b597 · commit 5c242b4

25 files changed: +900 −124 lines

.gitignore

Lines changed: 4 additions & 0 deletions

@@ -112,3 +112,7 @@ lightning_logs
 
 # Ruff
 .ruff_cache/
+
+
+# status.json file
+status.json

README.md

Lines changed: 38 additions & 2 deletions

@@ -121,6 +121,7 @@ dataloader = StreamingDataLoader(dataset)
 # Key Features
 
 - [Multi-GPU / Multi-Node Support](#multi-gpu--multi-node-support)
+- [Subsample and split your datasets](#subsample-and-split-your-datasets)
 - [Access any item](#access-any-item)
 - [Use any data transforms](#use-any-data-transforms)
 - [The Map Operator](#the-map-operator)
@@ -131,6 +132,7 @@ dataloader = StreamingDataLoader(dataset)
 - [Configure Cache Size Limit](#configure-cache-size-limit)
 - [On-Prem Optimizations](#on-prem-optimizations)
 
+
 ## Multi-GPU / Multi-Node Support
 
 The `StreamingDataset` and `StreamingDataLoader` automatically make sure each rank receives the same quantity of varied batches of data, so it works out of the box with your favorite frameworks ([PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), or [PyTorch](https://pytorch.org/docs/stable/index.html)) to do distributed training.
@@ -139,6 +141,41 @@ Here you can see an illustration showing how the Streaming Dataset works with multi node.
 
 ![An illustration showing how the Streaming Dataset works with multi node.](https://pl-flash-data.s3.amazonaws.com/streaming_dataset.gif)
 
+## Subsample and split your datasets
+
+You can easily split your dataset with `train_test_split`.
+
+```python
+from litdata import StreamingDataset, train_test_split
+
+dataset = StreamingDataset("s3://my-bucket/my-data") # the data is stored in the cloud
+
+print(len(dataset)) # display the length of your dataset
+# out: 100,000
+
+train_dataset, val_dataset, test_dataset = train_test_split(dataset, splits=[0.3, 0.2, 0.5])
+
+print(len(train_dataset))
+# out: 30,000
+
+print(len(val_dataset))
+# out: 20,000
+
+print(len(test_dataset))
+# out: 50,000
+```
+
+Or simply subsample it:
+
+```python
+from litdata import StreamingDataset
+
+dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.01) # the data is stored in the cloud
+
+print(len(dataset)) # display the length of your dataset
+# out: 1,000
+```
+
 ## Access any item
 
 Access the data you need, whenever you need it, regardless of where it is stored.
@@ -209,8 +246,7 @@ Easily experiment with dataset mixtures using the `CombinedStreamingDataset` class.
 As an example, this mixture of [Slimpajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) & [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata) was used in the [TinyLLAMA](https://github.com/jzhang38/TinyLlama) project to pretrain a 1.1B Llama model on 3 trillion tokens.
 
 ```python
-from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
-from litdata.streaming.item_loader import TokensLoader
+from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader, TokensLoader
 from tqdm import tqdm
 import os
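Putting the two new options together with the existing loader: a minimal sketch (the bucket path is a placeholder, the split fractions are arbitrary, and it assumes the splits returned by `train_test_split` plug directly into `StreamingDataLoader`, as the README section above suggests):

```python
from litdata import StreamingDataset, StreamingDataLoader, train_test_split

# Placeholder location of an optimized litdata dataset.
dataset = StreamingDataset("s3://my-bucket/my-data")

# Split into train/val/test fractions (80/10/10, chosen arbitrarily).
train_ds, val_ds, test_ds = train_test_split(dataset, splits=[0.8, 0.1, 0.1])

# Each split behaves like a streaming dataset, so it can feed the loader as usual.
train_loader = StreamingDataLoader(train_ds, batch_size=64)

for batch in train_loader:
    ...  # training step
```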

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 torch
 filelock
-numpy
+numpy < 2.0.0
 boto3
 requests

src/litdata/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -17,14 +17,18 @@
 from litdata.streaming.combined import CombinedStreamingDataset
 from litdata.streaming.dataloader import StreamingDataLoader
 from litdata.streaming.dataset import StreamingDataset
+from litdata.streaming.item_loader import TokensLoader
+from litdata.utilities.train_test_split import train_test_split
 
 __all__ = [
     "StreamingDataset",
     "CombinedStreamingDataset",
     "StreamingDataLoader",
+    "TokensLoader",
     "map",
     "optimize",
     "walk",
+    "train_test_split",
 ]
 if RequirementCache("lightning_sdk"):
     from lightning_sdk import Machine  # noqa: F401
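With these re-exports, both helpers are reachable from the package root instead of their submodules. A minimal sketch (the dataset path is a placeholder, `block_size=2048` is an arbitrary value, and it assumes `StreamingDataset` accepts an `item_loader` argument as used with token datasets):

```python
from litdata import StreamingDataset, TokensLoader, train_test_split

# Placeholder dataset location; block_size is an arbitrary example value.
dataset = StreamingDataset(
    "s3://my-bucket/my-tokenized-data",
    item_loader=TokensLoader(block_size=2048),
)

# The splits list is assumed to accept any number of fractions summing to 1.
train_ds, val_ds = train_test_split(dataset, splits=[0.9, 0.1])
```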

src/litdata/constants.py

Lines changed: 1 addition & 0 deletions

@@ -63,3 +63,4 @@
 
 _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
 _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None))
+_ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0")))

src/litdata/processing/data_processor.py

Lines changed: 2 additions & 1 deletion

@@ -35,6 +35,7 @@
 from litdata.constants import (
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
+    _ENABLE_STATUS,
     _INDEX_FILENAME,
     _IS_IN_STUDIO,
     _LIGHTNING_CLOUD_AVAILABLE,
@@ -995,7 +996,7 @@ def run(self, data_recipe: DataRecipe) -> None:
             if current_total == num_items:
                 break
 
-            if _IS_IN_STUDIO and node_rank == 0:
+            if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
                 with open("status.json", "w") as f:
                     json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)
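Status reporting is now opt-in via the environment flag, which is evaluated once when the constants module is imported. A minimal sketch of opting in, assuming the usual `optimize(fn, inputs, output_dir, ...)` call shape; the recipe function, input paths and chunk size are placeholders:

```python
import os

# _ENABLE_STATUS is evaluated when litdata.constants is imported and parsed with
# bool(int(...)), so set it before importing litdata and use "1"/"0", not "true"/"false".
os.environ["ENABLE_STATUS_REPORT"] = "1"

from litdata import optimize  # noqa: E402


def tokenize(filepath):
    # Hypothetical recipe returning the items to write into the optimized dataset.
    ...


if __name__ == "__main__":
    # Placeholder arguments. When this runs inside a Lightning Studio, node rank 0
    # will now also write a status.json file containing the progress percentage.
    optimize(fn=tokenize, inputs=["data/a.txt", "data/b.txt"], output_dir="optimized_dataset", chunk_bytes="64MB")
```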

src/litdata/streaming/cache.py

Lines changed: 8 additions & 2 deletions

@@ -18,7 +18,7 @@
 from litdata.constants import (
     _INDEX_FILENAME,
 )
-from litdata.streaming.item_loader import BaseItemLoader
+from litdata.streaming.item_loader import BaseItemLoader, Interval
 from litdata.streaming.reader import BinaryReader
 from litdata.streaming.resolver import Dir, _resolve_dir
 from litdata.streaming.sampler import ChunkedIndex
@@ -34,6 +34,8 @@ class Cache:
     def __init__(
         self,
         input_dir: Optional[Union[str, Dir]],
+        subsampled_files: Optional[List[str]] = None,
+        region_of_interest: Optional[List[Tuple[int, int]]] = None,
         compression: Optional[str] = None,
         chunk_size: Optional[int] = None,
         chunk_bytes: Optional[Union[int, str]] = None,
@@ -46,6 +48,8 @@ def __init__(
 
         Arguments:
             input_dir: The path to where the chunks will be or are stored.
+            subsampled_files: List of subsampled chunk files loaded from the `input_dir/index.json` file.
+            region_of_interest: List of (start, end) tuples defining the region of interest for each chunk.
             compression: The name of the algorithm to reduce the size of the chunks.
             chunk_bytes: The maximum number of bytes within a chunk.
             chunk_size: The maximum number of items within a chunk.
@@ -67,6 +71,8 @@ def __init__(
         )
         self._reader = BinaryReader(
             self._cache_dir,
+            subsampled_files=subsampled_files,
+            region_of_interest=region_of_interest,
             max_cache_size=_convert_bytes_to_int(max_cache_size) if isinstance(max_cache_size, str) else max_cache_size,
             remote_input_dir=input_dir.url,
             compression=compression,
@@ -138,7 +144,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
     def __len__(self) -> int:
         return self._reader.get_length()
 
-    def get_chunk_intervals(self) -> List[Tuple[int, int]]:
+    def get_chunk_intervals(self) -> List[Interval]:
         return self._reader.get_chunk_intervals()
 
     def _get_chunk_index_from_index(self, index: int) -> int:
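The new `Cache` arguments simply forward the subsampling metadata to `BinaryReader`. For illustration only, a sketch of a direct `Cache` construction (normally done by the streaming-dataset internals rather than user code), with made-up chunk filenames, item ranges, and a local directory assumed to already contain `index.json` and its chunks:

```python
from litdata.streaming.cache import Cache

# Hypothetical values: two chunk files kept after subsampling and, for each one,
# the (start, end) item range that should remain visible.
cache = Cache(
    input_dir="/data/optimized_dataset",
    subsampled_files=["chunk-0-0.bin", "chunk-0-1.bin"],
    region_of_interest=[(0, 512), (128, 1024)],
    chunk_bytes="64MB",
)

print(len(cache))                   # item count after subsampling / ROI filtering
print(cache.get_chunk_intervals())  # now returns List[Interval] instead of List[Tuple[int, int]]
```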

src/litdata/streaming/config.py

Lines changed: 55 additions & 8 deletions

@@ -18,7 +18,7 @@
 from litdata.constants import _INDEX_FILENAME
 from litdata.streaming.compression import _COMPRESSORS, Compressor
 from litdata.streaming.downloader import get_downloader_cls
-from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
+from litdata.streaming.item_loader import BaseItemLoader, Interval, PyTreeLoader, TokensLoader
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import tree_unflatten, treespec_loads
@@ -31,6 +31,8 @@ def __init__(
         serializers: Dict[str, Serializer],
         remote_dir: Optional[str],
         item_loader: Optional[BaseItemLoader] = None,
+        subsampled_files: Optional[List[str]] = None,
+        region_of_interest: Optional[List[Tuple[int, int]]] = None,
     ) -> None:
         """The ChunksConfig reads the index files associated with a chunked dataset and enables mapping an index to
         its chunk.
@@ -40,24 +42,34 @@ def __init__(
             serializers: The serializers used to serialize and deserialize the chunks.
             remote_dir: The path to a remote folder where the data are located.
                 The scheme needs to be added to the path.
+            subsampled_files: List of subsampled chunk files loaded from the `input_dir/index.json` file.
+            region_of_interest: List of (start, end) tuples defining the region of interest for each chunk.
 
         """
         self._cache_dir = cache_dir
-        self._intervals: List[Tuple[int, int]] = []
+        self._intervals: List[Interval] = []
         self._config = None
-        self._chunks = []
+        self._chunks = None
         self._remote_dir = remote_dir
         self._item_loader = item_loader or PyTreeLoader()
 
         with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
             data = json.load(f)
+            _original_chunks = data["chunks"]
             self._config = data["config"]
             self._validate_item_loader()
-            self._chunks.extend(data["chunks"])
+
+            assert _original_chunks is not None
+
+            if subsampled_files is None:
+                self._chunks = _original_chunks
+            else:
+                self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
 
         self._config["data_spec"] = treespec_loads(self._config["data_spec"])
 
-        self._item_loader.setup(self._config, self._chunks, serializers)
+        assert self._chunks is not None
+        self._item_loader.setup(self._config, self._chunks, serializers, region_of_interest)
         self._intervals = self._item_loader.generate_intervals()
         self._length = self._intervals[-1][-1]
         self._downloader = None
@@ -87,6 +99,7 @@ def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) -> None:
         self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion
 
     def download_chunk_from_index(self, chunk_index: int) -> None:
+        assert self._chunks is not None
         chunk_filename = self._chunks[chunk_index]["filename"]
 
         local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
@@ -124,7 +137,7 @@ def try_decompress(self, local_chunkpath: str) -> None:
             f.write(data)
 
     @property
-    def intervals(self) -> List[Tuple[int, int]]:
+    def intervals(self) -> List[Interval]:
         if self._intervals is None:
             raise RuntimeError("The intervals should be defined.")
         return self._intervals
@@ -133,6 +146,7 @@ def intervals(self) -> List[Tuple[int, int]]:
     def num_bytes(self) -> int:
         if self._config is None:
             raise RuntimeError("The config should be defined.")
+        assert self._chunks is not None
         return sum(c["chunk_bytes"] for c in self._chunks)
 
     @property
@@ -167,14 +181,15 @@ def config(self) -> Dict[str, Any]:
 
     def _get_chunk_index_from_index(self, index: int) -> int:
         for chunk_index, internal in enumerate(self._intervals):
-            if internal[0] <= index < internal[1]:
+            if internal[0] <= index < internal[-1]:
                 return chunk_index
         raise ValueError(
             f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
         )
 
     def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
         """Find the associated chunk metadata."""
+        assert self._chunks is not None
         chunk = self._chunks[index.chunk_index]
 
         local_chunkpath = os.path.join(self._cache_dir, chunk["filename"])
@@ -188,6 +203,7 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
 
     def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
         """Retrieves the associated chunk_index for a given chunk filename."""
+        assert self._chunks is not None
         for chunk_index, chunk in enumerate(self._chunks):
             if chunk["filename"] == chunk_filename:
                 return chunk_index
@@ -200,6 +216,8 @@ def load(
         serializers: Dict[str, Serializer],
         remote_dir: Optional[str] = None,
         item_loader: Optional[BaseItemLoader] = None,
+        subsampled_files: Optional[List[str]] = None,
+        region_of_interest: Optional[List[Tuple[int, int]]] = None,
     ) -> Optional["ChunksConfig"]:
         cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)
 
@@ -210,7 +228,7 @@ def load(
         if not os.path.exists(cache_index_filepath):
             return None
 
-        return ChunksConfig(cache_dir, serializers, remote_dir, item_loader)
+        return ChunksConfig(cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest)
 
     def __len__(self) -> int:
         return self._length
@@ -223,3 +241,32 @@ def _validate_item_loader(self) -> None:
             and not isinstance(self._item_loader, TokensLoader)
         ):
             raise ValueError("Please, use Cache(..., item_loader=TokensLoader(block_size=...))")
+
+
+def load_subsampled_chunks(subsampled_files: List[str], original_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Load the chunk metadata corresponding to the provided subsampled files."""
+    _subsampled_chunks: List[Dict[str, Any]] = [{} for _ in range(len(subsampled_files))]
+
+    assert len(_subsampled_chunks) == len(subsampled_files)
+
+    filename_dict = {}
+
+    # Map each subsampled filename to its index so the requested order is preserved
+    for index, filename in enumerate(subsampled_files):
+        filename_dict[filename] = index
+
+    for curr_chunk in original_chunks:
+        if curr_chunk["filename"] in filename_dict:
+            idx = filename_dict[curr_chunk["filename"]]
+            _subsampled_chunks[idx] = curr_chunk
+
+    # If any entry of _subsampled_chunks is still empty, some elements in
+    # subsampled_files were not actually part of the original chunks, so raise an error
+    if any(not _subsampled_chunk for _subsampled_chunk in _subsampled_chunks):
+        raise ValueError(
+            "Mismatch in subsampled files and the chunks loaded",
+            "Make sure subsampled chunks are actually part of the original chunks",
        )
+
+    return _subsampled_chunks
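To make the new module-level helper concrete, a short usage sketch with made-up chunk metadata: `load_subsampled_chunks` returns the chunk dicts reordered to match the subsampled file list and raises when a requested file is not part of the original chunks.

```python
from litdata.streaming.config import load_subsampled_chunks

# Made-up chunk metadata, in the shape stored under "chunks" in index.json.
original_chunks = [
    {"filename": "chunk-0-0.bin", "chunk_size": 1024, "chunk_bytes": 4096},
    {"filename": "chunk-0-1.bin", "chunk_size": 1024, "chunk_bytes": 4096},
    {"filename": "chunk-0-2.bin", "chunk_size": 512, "chunk_bytes": 2048},
]

# Keep only two chunks; the result follows the order of `subsampled_files`.
subset = load_subsampled_chunks(["chunk-0-2.bin", "chunk-0-0.bin"], original_chunks)
print([c["filename"] for c in subset])
# ['chunk-0-2.bin', 'chunk-0-0.bin']

# A filename that is not part of the original chunks raises a ValueError.
try:
    load_subsampled_chunks(["does-not-exist.bin"], original_chunks)
except ValueError as err:
    print(err)
```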
