Adds Ability to Sub-Sample Data for Data Constrained Scaling Law Experiments #872

Merged: 4 commits into main from will/constrain on Feb 1, 2025

Conversation

@Helw150 (Collaborator) commented Jan 30, 2025

[figure: data-constrained scaling law experiment results]

Allows mixture datasets to specify a target budget and an experiment budget. This then computes what percentage of the data to sample overall, in order to enable data-constrained experiments like the one in the figure above.
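
Concretely, the subsampling ratio is just the quotient of the two budgets. A minimal sketch of the arithmetic (the variable names follow the diff quoted later in this thread; the example numbers are illustrative):

    # With a target budget of 100B tokens and an experiment budget of 10B tokens,
    # each dataset in the mixture is subsampled to 10% of its true size.
    target_budget = 100_000_000_000       # tokens the full-scale run would consume
    experiment_budget = 10_000_000_000    # tokens the scaled-down experiment consumes
    simulated_data_ratio = experiment_budget / target_budget  # -> 0.1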

@Helw150 requested review from @dlwh and @ahmeda14960 on January 30, 2025 at 05:58
@dlwh (Member) left a comment


What do you think about making a custom AsyncDataset that is basically just a slice, and leaving this logic out of the mixture dataset?

@Helw150 (Collaborator, Author) commented Jan 30, 2025

> What do you think about making a custom AsyncDataset that is basically just a slice, and leaving this logic out of the mixture dataset?

With the idea of having it wrap the mixture dataset in the simulation case? Or would each of the sub-datasets of a mixture be of this type?

I think the former makes a lot of sense! I just need to wrap my head around an implementation that would lead to consistent slices of each sub-dataset as well.

The latter seems like it would require more configuration wiring than it's worth.

@Helw150 (Collaborator, Author) commented Jan 30, 2025

Looking at it more, I don't know if there's actually a clean way to make the former work.

The latter feels fine as well, though, and likely has fewer footguns with respect to the effective dataset size differing from what the various utilities report, so I'll get that in.

@dlwh (Member) commented Jan 30, 2025

I was thinking the latter, yeah. Basically just "if subsample: datasets = map(datasets, slice_dataset)". But I guess your point is that we don't know how many samples we're going to go through until we're in the mixture dataset.

@Helw150 (Collaborator, Author) commented Jan 30, 2025

I've got a rework I'll push after NLP lunch that I feel like is probably cleaner.

@dlwh (Member) commented Jan 30, 2025

Actually, no. I don't understand why it isn't just:

    class SimulatedDataRatioDataset:
        async def wait_until_len_at_least(self, length):
            target_for_len = length * self.ratio_target
            actual = await self.ds.wait_until_len_at_least(int(target_for_len))

            return int(actual * self.ratio_target)

@Helw150 (Collaborator, Author) commented Jan 30, 2025

OK, how does this look?

Realistically though, the test is probably in the wrong spot?

@Helw150 (Collaborator, Author) commented Jan 30, 2025

> Actually, no. I don't understand why it isn't just:
>
>     class SimulatedDataRatioDataset:
>         async def wait_until_len_at_least(self, length):
>             target_for_len = length * self.ratio_target
>             actual = await self.ds.wait_until_len_at_least(int(target_for_len))
>             return int(actual * self.ratio_target)

Just saw this! My one concern with this vs. what I have in the current one is that it makes more assumptions about the implementation of the underlying get_batch code.

(Edit: for example, you couldn't wrap a mixture in this class, right?)

@dlwh (Member) commented Jan 31, 2025

I don't see why not?

@Helw150 (Collaborator, Author) commented Jan 31, 2025

> Actually, no. I don't understand why it isn't just:
>
>     class SimulatedDataRatioDataset:
>         async def wait_until_len_at_least(self, length):
>             target_for_len = length * self.ratio_target
>             actual = await self.ds.wait_until_len_at_least(int(target_for_len))
>             return int(actual * self.ratio_target)

Maybe I'm missing something, but this assumes the underlying dataset's get_batch logic depends on its own wait_until_len_at_least to determine that the dataset is exhausted.

For a mixture dataset, this isn't quite true: it instead uses the wait_until_len_at_least of each of the sub-datasets. Since overriding the parent's call doesn't affect these, it wouldn't change the epoching behavior, right?

@Helw150 (Collaborator, Author) commented Jan 31, 2025

Oops, just saw that the slice version I have on my branch isn't pushed here. I think that will mix concerns less than the current PR.

@Helw150 (Collaborator, Author) commented Jan 31, 2025

OK, pushed. I realized I hadn't pushed before because this currently has a pre-commit type failure, which I'll fix assuming the core logic makes sense to you!

The core difference between this and your proposal is overriding get_batch rather than overriding the length-waiting function.

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        # Shift the requested indices into the sliced window of the underlying dataset.
        shifted_indices = [(index + self.start_index) for index in indices]
        max_index = max(shifted_indices)

        if self.end_index is not None and max_index > self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")

        return await self.dataset.get_batch(shifted_indices)

I prefer this because the effects of the slice are (to me) clear from the above code without checking anything else, whereas understanding the effects of the wait_until_len_at_least override would require knowing how the get_batch of the dataset you are wrapping is implemented.

LMK what you think!
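
For reference, a self-contained sketch of what such a slice wrapper might look like in full. Only the get_batch method above is quoted from the PR; the class name, constructor, and wait_until_len_at_least below are assumptions for illustration:

    from typing import Optional, Sequence, TypeVar

    T_co = TypeVar("T_co", covariant=True)


    class SlicedAsyncDataset:
        """Presents the window [start_index, end_index] (inclusive) of an async dataset."""

        def __init__(self, dataset, start_index: int = 0, end_index: Optional[int] = None):
            self.dataset = dataset
            self.start_index = start_index
            self.end_index = end_index

        async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
            # Shift the requested indices into the underlying dataset's index space.
            shifted_indices = [index + self.start_index for index in indices]
            max_index = max(shifted_indices)

            if self.end_index is not None and max_index > self.end_index:
                raise ValueError("Requested indices beyond the end of the dataset")

            return await self.dataset.get_batch(shifted_indices)

        async def wait_until_len_at_least(self, length: int) -> int:
            # Clamp the request to the slice window, translate it into the underlying
            # dataset's index space, and report the clipped length back.
            if self.end_index is not None:
                length = min(length, self.end_index - self.start_index + 1)
            actual = await self.dataset.wait_until_len_at_least(length + self.start_index)
            available = actual - self.start_index
            if self.end_index is not None:
                available = min(available, self.end_index - self.start_index + 1)
            return available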

    simulated_data_ratio = self.experiment_budget / self.target_budget
    for name, ds in token_datasets.items():
        # Note(Will): This blocks on datasets being fully processed even for small
        # simulated runs, making data-size simulation slightly latency-inducing,
        # but I think that's OK.
        true_length_of_dataset = len(ds.as_sync_dataset())
@dlwh (Member) commented on the diff:

this is the thing I was trying to work around, but I agree it's not worth fixing

@Helw150 (Collaborator, Author) replied:

SG!
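
A hypothetical continuation of the quoted loop, assuming a slice wrapper along the lines sketched above (the wrapper name and its constructor are assumptions, not the PR's actual code):

        # Subsample each dataset down to the simulated ratio of its true length.
        simulated_length = int(true_length_of_dataset * simulated_data_ratio)
        # end_index is inclusive in the sketch above, hence the - 1.
        token_datasets[name] = SlicedAsyncDataset(ds, start_index=0, end_index=simulated_length - 1)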

@dlwh (Member) commented Jan 31, 2025

Can you fix the mypy errors? Then good to merge.

@Helw150 (Collaborator, Author) commented Feb 1, 2025

> Can you fix the mypy errors? Then good to merge.

Fixed!

@dlwh merged commit 1d216d1 into main on Feb 1, 2025 (7 of 8 checks passed).
@dlwh deleted the will/constrain branch on February 1, 2025 at 08:45.