Adds Ability to Sub-Sample Data for Data Constrained Scaling Law Experiments #872

Merged
merged 4 commits on Feb 1, 2025
54 changes: 54 additions & 0 deletions src/levanter/data/dataset.py
@@ -122,6 +122,9 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[U]":
    def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
        return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

    def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
        return SlicedAsyncDataset(self, start_index, end_index)

    def shuffle(self, key: PRNGKey):
        import levanter.data.permutation as permutation

@@ -375,6 +378,57 @@ def _call_fn(self, index, item):
        return self.fn(item, *self._extra_args, **kwargs)


class SlicedAsyncDataset(AsyncDataset[U]):
    def __init__(
        self,
        dataset: AsyncDataset[U],
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
    ):
        super().__init__()
        if start_index is None:
            start_index = 0
        if end_index is not None and start_index > end_index:
            raise ValueError("End index must come after start index.")

        self.start_index = start_index
        self.end_index = end_index
        self.dataset = dataset
        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)
    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
        shifted_indices = [(index + self.start_index) for index in indices]
        max_index = max(shifted_indices)

        # end_index is exclusive, so any shifted index at or past it is out of range
        if self.end_index is not None and max_index >= self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")

        return await self.dataset.get_batch(shifted_indices)

    async def async_len(self) -> int:
        underlying_length = await self.dataset.async_len()
        if self.end_index is None:
            return underlying_length - self.start_index
        else:
            return self.end_index - self.start_index

    async def final_length_is_known(self) -> bool:
        underlying_is_known = await self.dataset.final_length_is_known()
        return underlying_is_known and self.end_index is not None

    def is_finite(self) -> bool:
        return self.dataset.is_finite() and self.end_index is not None

    async def current_len(self) -> Optional[int]:
        underlying_length = await self.dataset.current_len()
        if self.end_index is not None:
            return self.end_index - self.start_index
        elif underlying_length is not None:
            return underlying_length - self.start_index
        else:
            return underlying_length


class BatchMappedAsyncDataset(AsyncDataset[U]):
    """
    A dataset that applies a function to each batch of items in the dataset.
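To make the new API concrete, here is a minimal sketch of slice_dataset's semantics. ListBackedDataset is a hypothetical toy class written against the AsyncDataset interface visible in this diff; it is not part of Levanter. The inclusive-start, exclusive-end behavior follows from async_len and get_batch above.

# A hedged illustration, not Levanter API: ListBackedDataset exists only
# to exercise slice_dataset against the interface shown in this diff.
import asyncio
from typing import Optional, Sequence

from levanter.data.dataset import AsyncDataset


class ListBackedDataset(AsyncDataset[int]):
    def __init__(self, items: Sequence[int]):
        super().__init__()
        self.items = list(items)

    async def async_len(self) -> int:
        return len(self.items)

    async def final_length_is_known(self) -> bool:
        return True

    def is_finite(self) -> bool:
        return True

    async def current_len(self) -> Optional[int]:
        return len(self.items)

    async def get_batch(self, indices: Sequence[int]) -> Sequence[int]:
        return [self.items[i] for i in indices]


async def demo():
    ds = ListBackedDataset(list(range(100)))
    sliced = ds.slice_dataset(start_index=10, end_index=20)  # keeps items 10..19
    assert await sliced.async_len() == 10
    assert await sliced.get_batch([0, 1, 2]) == [10, 11, 12]  # indices shift by start_index


asyncio.run(demo())

Note that the slice is lazy: SlicedAsyncDataset only shifts indices at get_batch time, so no data is copied.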
22 changes: 22 additions & 0 deletions src/levanter/data/text.py
@@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig):
    """ Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """

    stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)

    # Configuration for Simulated Epoching
    target_budget: Optional[int] = None
    experiment_budget: Optional[int] = None

    mixture_block_size: int = 2048
    """ Block size for deterministic mixing """
@@ -1226,6 +1231,23 @@ def shuffle_ds(ds, key):
                out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
            token_datasets = out_token_datasets

        if (
            self.experiment_budget is not None and self.target_budget is not None
        ) and self.experiment_budget > self.target_budget:
            raise ValueError(
                f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >"
                f" {self.target_budget}"
            )
        if self.experiment_budget is not None and self.target_budget is not None:
            simulated_data_ratio = self.experiment_budget / self.target_budget
            sliced_token_datasets: Dict[str, TokenSeqDataset] = {}
            for name, ds in token_datasets.items():
                # Note(Will): this blocks on datasets being fully processed even for small simulated
                # runs, which makes simulating data size slightly latency-inducing, but I think that's ok
                true_length_of_dataset = len(ds.as_sync_dataset())
[Inline review comments on the line above]
Member: this is the thing i was trying to work around but I agree it's not worth fixing
Collaborator (author): SG!

                simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
                sliced_token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
            token_datasets = sliced_token_datasets

        mixture = MixtureDataset(
            datasets=token_datasets,
            weights=self.train_weights,
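As a worked sanity check of the budget arithmetic above (the numbers here are illustrative, not from this PR): with an experiment budget of 1B tokens simulating a 100B-token target, each mixture component keeps 1% of its sequences.

# Illustrative numbers only; variable names mirror the diff above.
experiment_budget = 1_000_000_000       # tokens available to the small run
target_budget = 100_000_000_000         # tokens the full-scale run would see

simulated_data_ratio = experiment_budget / target_budget  # 0.01

true_length_of_dataset = 5_000_000      # sequences in one mixture component
simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
assert simulated_length_of_dataset == 50_000  # the slice keeps 1% of the component

# Each component is then truncated via ds.slice_dataset(end_index=simulated_length_of_dataset),
# so the small run revisits (epochs over) its data at the same rate the full run would.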
30 changes: 30 additions & 0 deletions tests/test_mixture.py
@@ -73,6 +73,36 @@ async def test_mixture_dataset_stop_strategy_restart():
        await mixture_ds.async_len()


@pytest.mark.asyncio
async def test_mixture_dataset_simulated_data_size():
    weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
    mixture_ds = MixtureDataset(
        {name: dataset.slice_dataset(end_index=1) for name, dataset in datasets().items()},
        weights,
        block_size=10,
        key=key(),
        randomize_blocks=False,
        stop_strategy=StopStrategy.RESTART_STRATEGY,
    )
    for _ in range(10):
        batch = await mixture_ds.get_batch([0, 1, 2])
        assert len(batch) == 3
        assert all(item in [1, 10, 100] for item in batch)

    mixture_ds = MixtureDataset(
        {name: dataset.slice_dataset(end_index=2) for name, dataset in datasets().items()},
        weights,
        block_size=10,
        key=key(),
        randomize_blocks=False,
        stop_strategy=StopStrategy.RESTART_STRATEGY,
    )
    for _ in range(10):
        batch = await mixture_ds.get_batch([0, 1, 2])
        assert len(batch) == 3
        assert all(item in [1, 2, 10, 20, 100, 200] for item in batch)


@pytest.mark.asyncio
async def test_mixture_dataset_normalized_weights():
    weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}
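Outside of pytest, the same behavior can be reproduced interactively. A sketch under a few assumptions: MixtureDataset and StopStrategy are imported from levanter.data.mixture (matching what this test file appears to use), jax.random.PRNGKey(0) stands in for the test's key() helper, and ListBackedDataset is the hypothetical toy dataset sketched earlier.

import asyncio

import jax.random

# Assumed import path; the test file's own imports are not shown in this diff.
from levanter.data.mixture import MixtureDataset, StopStrategy


components = {
    "ds1": ListBackedDataset([1, 2, 3]),      # hypothetical helper from the earlier sketch
    "ds2": ListBackedDataset([10, 20, 30]),
    "ds3": ListBackedDataset([100, 200, 300]),
}

mixture = MixtureDataset(
    {name: ds.slice_dataset(end_index=1) for name, ds in components.items()},
    weights={"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3},
    block_size=10,
    key=jax.random.PRNGKey(0),
    randomize_blocks=False,
    stop_strategy=StopStrategy.RESTART_STRATEGY,
)

# Each component was cut to a single item, so RESTART_STRATEGY keeps
# re-serving the same three values no matter how many batches we draw.
batch = asyncio.run(mixture.get_batch([0, 1, 2]))
assert all(item in (1, 10, 100) for item in batch)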