Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Ability to Sub-Sample Data for Data Constrained Scaling Law Experiments #872

Merged
merged 4 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
randomize_blocks: bool = True,
key: PRNGKeyArray | int,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
simulated_data_ratio: float = 1,
):
super().__init__()
if isinstance(weights, dict):
Expand Down Expand Up @@ -99,6 +100,9 @@ def __init__(
raise NotImplementedError("Only restart strategy is supported for now.")

self.stop_strategy = stop_strategy
if simulated_data_ratio > 1:
raise ValueError(f"Simulated data ratio must be at most 1, got {simulated_data_ratio}")
self.simulated_data_ratio = simulated_data_ratio

# Initialize stage-related counts and IDs
(
Expand Down Expand Up @@ -275,7 +279,13 @@ async def _remap_indices(self, ds, indices_into_ds):
if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
if ds.is_finite():
max_elem = max(indices_into_ds)
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
# Remap Indices Earlier when simulating epoching for a larger budget
if self.simulated_data_ratio < 1:
# Note(Will): This blocks until datasets are fully processed, even for small simulated runs, so simulating data size adds some latency — but I think that's ok
true_length_of_dataset = await ds.async_len()
length_of_dataset = int(true_length_of_dataset * self.simulated_data_ratio)
else:
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]

return indices_into_ds
Expand Down
11 changes: 11 additions & 0 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig):
""" Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """

stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)

# Configuration for Simulated Epoching
target_budget: Optional[int] = None
experiment_budget: Optional[int] = None

mixture_block_size: int = 2048
""" Block size for deterministic mixing """

Expand Down Expand Up @@ -1226,12 +1231,18 @@ def shuffle_ds(ds, key):
out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
token_datasets = out_token_datasets

if self.experiment_budget is not None and self.target_budget is not None:
simulated_data_ratio = self.experiment_budget / self.target_budget
else:
simulated_data_ratio = 1

mixture = MixtureDataset(
datasets=token_datasets,
weights=self.train_weights,
stop_strategy=self.stop_strategy,
key=mix_key,
block_size=self.mixture_block_size,
simulated_data_ratio=simulated_data_ratio,
)

return mixture
Expand Down
32 changes: 32 additions & 0 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,38 @@ async def test_mixture_dataset_stop_strategy_restart():
await mixture_ds.async_len()


@pytest.mark.asyncio
async def test_mixture_dataset_simulated_data_size():
    """With simulated_data_ratio < 1 the mixture should only ever draw
    items from the leading fraction of each underlying dataset.

    Each (ratio, allowed) pair below lists the values reachable from the
    truncated prefix of the three test datasets for that ratio.
    """
    weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
    cases = [
        (0.2, [1, 10, 100]),
        (0.4, [1, 2, 10, 20, 100, 200]),
    ]
    for ratio, allowed in cases:
        mixture_ds = MixtureDataset(
            datasets(),
            weights,
            block_size=10,
            key=key(),
            randomize_blocks=False,
            stop_strategy=StopStrategy.RESTART_STRATEGY,
            simulated_data_ratio=ratio,
        )
        # Sample repeatedly; every drawn item must come from the simulated prefix.
        for _ in range(10):
            batch = await mixture_ds.get_batch([0, 1, 2])
            assert len(batch) == 3
            assert all(item in allowed for item in batch)


@pytest.mark.asyncio
async def test_mixture_dataset_normalized_weights():
weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}
Expand Down
Loading