From e59c63145cf7f6b3721495a13a6f13ba682b7ab6 Mon Sep 17 00:00:00 2001
From: Suhas Kotha
Date: Thu, 6 Feb 2025 20:16:23 -0800
Subject: [PATCH] first pass at train/val split implementation

---
 src/levanter/data/dataset.py  |  56 ++++++++++
 src/levanter/data/text.py     |  48 +++++++--
 tests/test_train_val_split.py | 187 ++++++++++++++++++++++++++++++++++
 3 files changed, 281 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_train_val_split.py

diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py
index 86e8c78d6..954d122f5 100644
--- a/src/levanter/data/dataset.py
+++ b/src/levanter/data/dataset.py
@@ -125,6 +125,9 @@ def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs)
     def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
         return SlicedAsyncDataset(self, start_index, end_index)
 
+    def slice_proportionally(self, start_fraction: float, end_fraction: float):
+        return LazySlicedAsyncDataset(self, start_fraction, end_fraction)
+
     def shuffle(self, key: PRNGKey):
         import levanter.data.permutation as permutation
 
@@ -429,6 +432,59 @@ async def current_len(self) -> Optional[int]:
         return underlying_length
 
 
+class LazySlicedAsyncDataset(AsyncDataset[T_co]):
+    """A dataset that lazily slices another dataset based on fractions of its length.
+    The actual slicing is deferred until the data is first accessed."""
+
+    def __init__(self, dataset: AsyncDataset[T_co], start_fraction: float, end_fraction: float):
+        super().__init__()
+        self.dataset = dataset
+        self.start_fraction = start_fraction
+        self.end_fraction = end_fraction
+        self._slice_indices: Optional[tuple[int, int]] = None
+        self.sliced_dataset: Optional[AsyncDataset[T_co]] = None
+
+    async def _initialize_slice_indices(self):
+        if self._slice_indices is not None:
+            if self.sliced_dataset is None:
+                raise ValueError("Failed to slice dataset")
+            return
+
+        underlying_length = await self.dataset.async_len()
+        self.start_index = int(underlying_length * self.start_fraction)
+        self.end_index = int(underlying_length * self.end_fraction)
+        self._slice_indices = (self.start_index, self.end_index)
+
+        if self.start_index == self.end_index:
+            raise ValueError("Dataset is empty")
+
+        if self.start_index > self.end_index:
+            raise ValueError("Start index is greater than end index")
+
+        self.sliced_dataset = self.dataset.slice_dataset(self.start_index, self.end_index)
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
+        await self._initialize_slice_indices()
+        return await self.sliced_dataset.get_batch(indices)
+
+    async def async_len(self) -> int:
+        await self._initialize_slice_indices()
+        return self._slice_indices[1] - self._slice_indices[0]
+
+    async def final_length_is_known(self) -> bool:
+        await self._initialize_slice_indices()
+        return await self.sliced_dataset.final_length_is_known()
+
+    def is_finite(self) -> bool:
+        # The slice may not be materialized yet; a fractional slice is only
+        # meaningful over a finite dataset, so defer to the source until then.
+        if self.sliced_dataset is not None:
+            return self.sliced_dataset.is_finite()
+        return self.dataset.is_finite()
+
+    async def current_len(self) -> Optional[int]:
+        await self._initialize_slice_indices()
+        return await self.sliced_dataset.current_len()
+
+
 class BatchMappedAsyncDataset(AsyncDataset[U]):
     """
     A dataset that applies a function to each batch of items in the dataset.
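For a sense of how the new API behaves, a minimal self-contained sketch follows. It assumes the ListAsyncDataset helper that the tests below also use; the 60/40 fractions are arbitrary.

    import asyncio

    from levanter.data import ListAsyncDataset

    async def demo():
        ds = ListAsyncDataset(list(range(10)), is_complete=True)
        # Nothing is computed here; the slice indices are resolved lazily,
        # on the first call that needs the underlying length.
        head = ds.slice_proportionally(start_fraction=0.0, end_fraction=0.6)
        tail = ds.slice_proportionally(start_fraction=0.6, end_fraction=1.0)
        assert await head.async_len() == 6
        assert await tail.async_len() == 4

    asyncio.run(demo())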
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 6446ad45f..d204b5ff3 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -1011,9 +1011,27 @@ def mk_chat_sft_dataset(
 @dataclass
 class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
-    """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
+    """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls.
+    Optionally supports splitting the training data into train/validation sets."""
 
     cache_dir: Optional[str] = "cache/"
+    split_fraction: Optional[float] = None
+    """If set, the fraction of the training data to use for training; must be strictly between 0 and 1.
+    The remainder is used for validation. This overrides any existing validation set."""
+    split_key: PRNGKeyArray = jax.random.PRNGKey(0)
+
+    def __post_init__(self):
+        if self.split_fraction is not None:
+            if not 0 < self.split_fraction < 1:
+                raise ValueError(f"split_fraction must be between 0 and 1, got {self.split_fraction}")
+
+            if self.split_key is None:
+                raise ValueError("split_key must be provided when split_fraction is set")
+
+            if self._has_validation_set:
+                logger.warning(
+                    "Dataset has an existing validation set - it will be ignored in favor of the split train set"
+                )
 
     def train_set(
         self,
@@ -1023,28 +1041,38 @@ def train_set(
         seq_len: int,
         monitors: Union[bool, List[MetricsMonitor]] = True,
         key: Optional[PRNGKeyArray] = None,
         epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
-
-        ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)
-
-        # add epoch flag here.
+        # Get the dataset and handle the None case up front
+        ds = self.token_seq_dataset("train", seq_len, monitors)
         if ds is None:
             raise ValueError("No training set!")
 
+        if self.split_fraction is not None:
+            ds = ds.shuffle(self.split_key)  # type: ignore
+            ds = ds.slice_proportionally(start_fraction=0, end_fraction=self.split_fraction)  # type: ignore
+
         if epochs:
             logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds, max_epochs=epochs)
+            ds = EpochDataset(ds, max_epochs=epochs)  # type: ignore
 
         if self.shuffle is True:
-            ds = ds.shuffle(key)
+            ds = ds.shuffle(key)  # type: ignore
         elif isinstance(self.shuffle, int) and self.shuffle > 0:
-            ds = ds.era_shuffle(self.shuffle, key=key)
+            ds = ds.era_shuffle(self.shuffle, key=key)  # type: ignore
 
         return ds  # type: ignore
 
     def validation_set(
         self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Optional[TokenSeqDataset]:
-        return self.token_seq_dataset("validation", seq_len, monitors)
+    ) -> Optional[AsyncDataset[np.ndarray]]:
+        if self.split_fraction is not None:
+            ds: Optional[TokenSeqDataset] = self.token_seq_dataset("train", seq_len, monitors)
+            if ds is None:
+                return None
+            ds = ds.shuffle(self.split_key)  # use the same key as the train set so the split is consistent
+            ds = ds.slice_proportionally(start_fraction=self.split_fraction, end_fraction=1.0)  # type: ignore
+            return ds  # type: ignore
+        else:
+            return self.token_seq_dataset("validation", seq_len, monitors)
 
     def validation_sets(
         self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
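In use, the two new config fields are all a caller needs to set. A hedged sketch, assuming the data-source fields are configured as usual; the fraction, key, and seq_len values are illustrative only:

    import jax

    from levanter.data.text import LMDatasetConfig

    config = LMDatasetConfig(
        # ... data source fields as usual ...
        split_fraction=0.9,                # 90% of "train" for training, 10% held out
        split_key=jax.random.PRNGKey(42),  # drives the shuffle shared by both splits
    )
    train_ds = config.train_set(seq_len=2048, key=jax.random.PRNGKey(0))
    val_ds = config.validation_set(seq_len=2048)

Because validation_set reshuffles the train data with the same split_key before taking the complementary slice, the two datasets partition the shuffled data exactly.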
diff --git a/tests/test_train_val_split.py b/tests/test_train_val_split.py
new file mode 100644
index 000000000..9d6fdb5e9
--- /dev/null
+++ b/tests/test_train_val_split.py
@@ -0,0 +1,187 @@
+import jax
+import pytest
+
+from levanter.data import ListAsyncDataset
+from levanter.data.text import LMDatasetConfig
+
+
+@pytest.mark.asyncio
+async def test_basic_split():
+    """Test basic 80-20 split functionality"""
+    # Create a simple dataset
+    data = list(range(100))
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    config = LMDatasetConfig(
+        split_fraction=0.8,
+        split_key=jax.random.PRNGKey(0),
+    )
+
+    # Mock the token_seq_dataset method to return our test dataset
+    config.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    # Get train and validation sets
+    train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    val_ds = config.validation_set(seq_len=1)
+
+    # Check lengths
+    train_indices = list(range(await train_ds.async_len()))
+    val_indices = list(range(await val_ds.async_len()))
+
+    train_len = len(await train_ds.get_batch(train_indices))
+    val_len = len(await val_ds.get_batch(val_indices))
+
+    assert train_len == 80
+    assert val_len == 20
+    assert train_len + val_len == len(data)
+
+
+@pytest.mark.asyncio
+async def test_disjoint_split():
+    """Test that the train and validation sets are disjoint"""
+    data = list(range(100))
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    config = LMDatasetConfig(
+        split_fraction=0.8,
+        split_key=jax.random.PRNGKey(0),
+    )
+
+    config.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    val_ds = config.validation_set(seq_len=1)
+
+    train_items = set(await train_ds.get_batch(list(range(await train_ds.async_len()))))
+    val_items = set(await val_ds.get_batch(list(range(await val_ds.async_len()))))
+
+    # Check the sets are disjoint
+    assert len(train_items.intersection(val_items)) == 0
+    # Check their union covers all the data
+    assert train_items.union(val_items) == set(data)
+
+
+@pytest.mark.asyncio
+async def test_deterministic_split():
+    """Test that splits are deterministic given the same key"""
+    data = list(range(100))
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    key = jax.random.PRNGKey(0)
+
+    # Create two configs with the same key
+    config1 = LMDatasetConfig(split_fraction=0.8, split_key=key)
+    config2 = LMDatasetConfig(split_fraction=0.8, split_key=key)
+
+    config1.token_seq_dataset = lambda split, seq_len, monitors: ds
+    config2.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    # Get train sets from both configs
+    train_ds1 = config1.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    train_ds2 = config2.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+
+    train_items1 = await train_ds1.get_batch(list(range(await train_ds1.async_len())))
+    train_items2 = await train_ds2.get_batch(list(range(await train_ds2.async_len())))
+
+    assert train_items1 == train_items2
+
+
+@pytest.mark.asyncio
+async def test_different_keys_different_splits():
+    """Test that different keys produce different splits"""
+    data = list(range(100))
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    config1 = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(0))
+    config2 = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(1))
+
+    config1.token_seq_dataset = lambda split, seq_len, monitors: ds
+    config2.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    train_ds1 = config1.train_set(seq_len=1, key=jax.random.PRNGKey(2))
+    train_ds2 = config2.train_set(seq_len=1, key=jax.random.PRNGKey(2))
+
+    train_items1 = await train_ds1.get_batch(list(range(await train_ds1.async_len())))
+    train_items2 = await train_ds2.get_batch(list(range(await train_ds2.async_len())))
+
+    assert train_items1 != train_items2
+
+
+@pytest.mark.asyncio
+async def test_edge_case_splits():
+    """Test edge cases for split fractions"""
+    data = list(range(100))
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    # Test with a very small split
+    config = LMDatasetConfig(split_fraction=0.01, split_key=jax.random.PRNGKey(0))
+    config.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    val_ds = config.validation_set(seq_len=1)
+
+    train_len = len(await train_ds.get_batch(list(range(await train_ds.async_len()))))
+    val_len = len(await val_ds.get_batch(list(range(await val_ds.async_len()))))
+
+    assert train_len == 1
+    assert val_len == 99
+
+    # Test with a very large split
+    config = LMDatasetConfig(split_fraction=0.99, split_key=jax.random.PRNGKey(0))
+    config.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    val_ds = config.validation_set(seq_len=1)
+
+    train_len = len(await train_ds.get_batch(list(range(await train_ds.async_len()))))
+    val_len = len(await val_ds.get_batch(list(range(await val_ds.async_len()))))
+
+    assert train_len == 99
+    assert val_len == 1
+
+
+def test_invalid_split_fractions():
+    """Test that invalid split fractions raise appropriate errors"""
+    # Test split fraction = 0
+    with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
+        LMDatasetConfig(split_fraction=0, split_key=jax.random.PRNGKey(0))
+
+    # Test split fraction = 1
+    with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
+        LMDatasetConfig(split_fraction=1, split_key=jax.random.PRNGKey(0))
+
+    # Test negative split fraction
+    with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
+        LMDatasetConfig(split_fraction=-0.1, split_key=jax.random.PRNGKey(0))
+
+    # Test split fraction > 1
+    with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
+        LMDatasetConfig(split_fraction=1.1, split_key=jax.random.PRNGKey(0))
+
+
+def test_missing_split_key():
+    """Test that a missing split key raises an appropriate error"""
+    with pytest.raises(ValueError, match="split_key must be provided when split_fraction is set"):
+        LMDatasetConfig(split_fraction=0.8, split_key=None)
+
+
+@pytest.mark.asyncio
+async def test_empty_dataset():
+    """Test splitting an empty dataset"""
+    data = []
+    ds = ListAsyncDataset(data, is_complete=True)
+
+    config = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(0))
+    config.token_seq_dataset = lambda split, seq_len, monitors: ds
+
+    train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
+    val_ds = config.validation_set(seq_len=1)
+
+    # Any access forces the lazy slice of an empty dataset, which should raise
+    with pytest.raises(ValueError, match="Dataset is empty"):
+        await train_ds.get_batch([])
+    with pytest.raises(ValueError, match="Dataset is empty"):
+        await val_ds.get_batch([])
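The determinism and disjointness tests above rest on a single property: the shuffle is a pure function of its PRNG key. A standalone sketch of that property, with jax.random.permutation standing in for the dataset shuffle and an arbitrary 80/20 split:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    perm1 = jax.random.permutation(key, 100)
    perm2 = jax.random.permutation(key, 100)
    # Identical keys yield identical permutations...
    assert jnp.array_equal(perm1, perm2)
    # ...so taking the first 80% for train and the rest for validation
    # partitions the data exactly, with no element in both splits.
    train, val = perm1[:80], perm1[80:]
    assert set(train.tolist()).isdisjoint(set(val.tolist()))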