Skip to content

Commit

Permalink
first pass at train val split implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kothasuhas committed Feb 7, 2025
1 parent 1927f93 commit e59c631
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 10 deletions.
56 changes: 56 additions & 0 deletions src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs)
def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
return SlicedAsyncDataset(self, start_index, end_index)

def slice_proportionally(self, start_fraction: float, end_fraction: float):
return LazySlicedAsyncDataset(self, start_fraction, end_fraction)

def shuffle(self, key: PRNGKey):
import levanter.data.permutation as permutation

Expand Down Expand Up @@ -429,6 +432,59 @@ async def current_len(self) -> Optional[int]:
return underlying_length


class LazySlicedAsyncDataset(AsyncDataset[T_co]):
"""A dataset that lazily slices another dataset based on fractions of its length.
The actual slicing is deferred until the data is accessed."""

def __init__(self, dataset: AsyncDataset[T_co], start_fraction: float, end_fraction: float):
super().__init__()
self.dataset = dataset
self.start_fraction = start_fraction
self.end_fraction = end_fraction
# self._slice_indices: tuple[int, int] | None = None
# self.sliced_dataset: AsyncDataset[T_co] | None = None

async def _initialize_slice_indices(self):
if self._slice_indices is not None:
if self.sliced_dataset is None:
raise ValueError("Failed to slice dataset")
if self._slice_indices is None:
raise ValueError("Failed to slice dataset")
return

underlying_length = await self.dataset.async_len()
self.start_index = int(underlying_length * self.start_fraction)
self.end_index = int(underlying_length * self.end_fraction)
self._slice_indices = (self.start_index, self.end_index)

if self.start_index == self.end_index:
raise ValueError("Dataset is empty")

if self.start_index > self.end_index:
raise ValueError("Start index is greater than end index")

self.sliced_dataset: AsyncDataset[T_co] = self.dataset.slice_dataset(self.start_index, self.end_index)

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
await self._initialize_slice_indices()
return await self.sliced_dataset.get_batch(indices)

async def async_len(self) -> int:
await self._initialize_slice_indices()
return self._slice_indices[1] - self._slice_indices[0]

async def final_length_is_known(self) -> bool:
await self._initialize_slice_indices()
return await self.sliced_dataset.final_length_is_known()

def is_finite(self) -> bool:
return self.sliced_dataset.is_finite()

async def current_len(self) -> Optional[int]:
await self._initialize_slice_indices()
return await self.sliced_dataset.current_len()


class BatchMappedAsyncDataset(AsyncDataset[U]):
"""
A dataset that applies a function to each batch of items in the dataset.
Expand Down
48 changes: 38 additions & 10 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,9 +1011,27 @@ def mk_chat_sft_dataset(

@dataclass
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls.
Optionally supports splitting training data into train/validation sets."""

cache_dir: Optional[str] = "cache/"
split_fraction: Optional[float] = None
"""If set, fraction of training data to use for training. Must be between 0 and 1.
The remainder will be used for validation. This overrides any existing validation set."""
split_key: PRNGKeyArray = jax.random.PRNGKey(0)

def __post_init__(self):
if self.split_fraction is not None:
if not 0 < self.split_fraction < 1:
raise ValueError(f"split_fraction must be between 0 and 1, got {self.split_fraction}")

if self.split_key is None:
raise ValueError("split_key must be provided when split_fraction is set")

if self._has_validation_set:
logger.warning(
"Dataset has an existing validation set - this will be ignored in favor of the split train set"
)

def train_set(
self,
Expand All @@ -1023,28 +1041,38 @@ def train_set(
key: Optional[PRNGKeyArray] = None,
epochs: Optional[int] = None,
) -> AsyncDataset[np.ndarray]:

ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)

# add epoch flag here.
# Get the dataset and handle None case upfront
ds = self.token_seq_dataset("train", seq_len, monitors)
if ds is None:
raise ValueError("No training set!")

if self.split_fraction is not None:
ds = ds.shuffle(self.split_key) # type: ignore
ds = ds.slice_proportionally(start_fraction=0, end_fraction=self.split_fraction) # type: ignore

if epochs:
logger.info("Wrapping dataset in epoch dataset")
ds = EpochDataset(ds, max_epochs=epochs)
ds = EpochDataset(ds, max_epochs=epochs) # type: ignore

if self.shuffle is True:
ds = ds.shuffle(key)
ds = ds.shuffle(key) # type: ignore
elif isinstance(self.shuffle, int) and self.shuffle > 0:
ds = ds.era_shuffle(self.shuffle, key=key)
ds = ds.era_shuffle(self.shuffle, key=key) # type: ignore

return ds # type: ignore

def validation_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
) -> Optional[TokenSeqDataset]:
return self.token_seq_dataset("validation", seq_len, monitors)
) -> Optional[AsyncDataset[np.ndarray]]:
if self.split_fraction is not None:
ds: Optional[TokenSeqDataset] = self.token_seq_dataset("train", seq_len, monitors)
if ds is None:
return None
ds = ds.shuffle(self.split_key) # Use same key as train set for consistent split
ds = ds.slice_proportionally(start_fraction=self.split_fraction, end_fraction=1.0) # type: ignore
return ds # type: ignore
else:
return self.token_seq_dataset("validation", seq_len, monitors)

def validation_sets(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
Expand Down
187 changes: 187 additions & 0 deletions tests/test_train_val_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import jax
import pytest

from levanter.data import ListAsyncDataset
from levanter.data.text import LMDatasetConfig


@pytest.mark.asyncio
async def test_basic_split():
"""Test basic 80-20 split functionality"""
# Create a simple dataset
data = list(range(100))
ds = ListAsyncDataset(data, is_complete=True)

config = LMDatasetConfig(
split_fraction=0.8,
split_key=jax.random.PRNGKey(0),
)

# Mock the token_seq_dataset method to return our test dataset
config.token_seq_dataset = lambda split, seq_len, monitors: ds

# Get train and validation sets
train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
val_ds = config.validation_set(seq_len=1)

# Check lengths
train_indices = list(range(await train_ds.async_len()))
val_indices = list(range(await val_ds.async_len()))

train_len = len(await train_ds.get_batch(train_indices))
val_len = len(await val_ds.get_batch(val_indices))

assert train_len == 80
assert val_len == 20
assert train_len + val_len == len(data)


@pytest.mark.asyncio
async def test_disjoint_split():
"""Test that train and validation sets are disjoint"""
data = list(range(100))
ds = ListAsyncDataset(data, is_complete=True)

config = LMDatasetConfig(
split_fraction=0.8,
split_key=jax.random.PRNGKey(0),
)

config.token_seq_dataset = lambda split, seq_len, monitors: ds

train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
val_ds = config.validation_set(seq_len=1)

train_items = set(await train_ds.get_batch(list(range(await train_ds.async_len()))))
val_items = set(await val_ds.get_batch(list(range(await val_ds.async_len()))))

print(train_items)
print(val_items)

# Check sets are disjoint
assert len(train_items.intersection(val_items)) == 0
# Check union covers all data
assert train_items.union(val_items) == set(data)


@pytest.mark.asyncio
async def test_deterministic_split():
"""Test that splits are deterministic with same key"""
data = list(range(100))
ds = ListAsyncDataset(data, is_complete=True)

key = jax.random.PRNGKey(0)

# Create two configs with same key
config1 = LMDatasetConfig(split_fraction=0.8, split_key=key)
config2 = LMDatasetConfig(split_fraction=0.8, split_key=key)

config1.token_seq_dataset = lambda split, seq_len, monitors: ds
config2.token_seq_dataset = lambda split, seq_len, monitors: ds

# Get train sets from both configs
train_ds1 = config1.train_set(seq_len=1, key=jax.random.PRNGKey(1))
train_ds2 = config2.train_set(seq_len=1, key=jax.random.PRNGKey(1))

train_items1 = await train_ds1.get_batch(list(range(await train_ds1.async_len())))
train_items2 = await train_ds2.get_batch(list(range(await train_ds2.async_len())))

assert train_items1 == train_items2


@pytest.mark.asyncio
async def test_different_keys_different_splits():
"""Test that different keys produce different splits"""
data = list(range(100))
ds = ListAsyncDataset(data, is_complete=True)

config1 = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(0))
config2 = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(1))

config1.token_seq_dataset = lambda split, seq_len, monitors: ds
config2.token_seq_dataset = lambda split, seq_len, monitors: ds

train_ds1 = config1.train_set(seq_len=1, key=jax.random.PRNGKey(2))
train_ds2 = config2.train_set(seq_len=1, key=jax.random.PRNGKey(2))

train_items1 = await train_ds1.get_batch(list(range(await train_ds1.async_len())))
train_items2 = await train_ds2.get_batch(list(range(await train_ds2.async_len())))

assert train_items1 != train_items2


@pytest.mark.asyncio
async def test_edge_case_splits():
"""Test edge cases for split fractions"""
data = list(range(100))
ds = ListAsyncDataset(data, is_complete=True)

# Test with very small split
config = LMDatasetConfig(split_fraction=0.01, split_key=jax.random.PRNGKey(0))
config.token_seq_dataset = lambda split, seq_len, monitors: ds

train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
val_ds = config.validation_set(seq_len=1)

train_len = len(await train_ds.get_batch(list(range(await train_ds.async_len()))))
val_len = len(await val_ds.get_batch(list(range(await val_ds.async_len()))))

assert train_len == 1
assert val_len == 99

# Test with very large split
config = LMDatasetConfig(split_fraction=0.99, split_key=jax.random.PRNGKey(0))
config.token_seq_dataset = lambda split, seq_len, monitors: ds

train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
val_ds = config.validation_set(seq_len=1)

train_len = len(await train_ds.get_batch(list(range(await train_ds.async_len()))))
val_len = len(await val_ds.get_batch(list(range(await val_ds.async_len()))))

assert train_len == 99
assert val_len == 1


def test_invalid_split_fractions():
"""Test that invalid split fractions raise appropriate errors"""
# Test split fraction = 0
with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
LMDatasetConfig(split_fraction=0, split_key=jax.random.PRNGKey(0))

# Test split fraction = 1
with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
LMDatasetConfig(split_fraction=1, split_key=jax.random.PRNGKey(0))

# Test negative split fraction
with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
LMDatasetConfig(split_fraction=-0.1, split_key=jax.random.PRNGKey(0))

# Test split fraction > 1
with pytest.raises(ValueError, match="split_fraction must be between 0 and 1"):
LMDatasetConfig(split_fraction=1.1, split_key=jax.random.PRNGKey(0))


def test_missing_split_key():
"""Test that missing split key raises appropriate error"""
with pytest.raises(ValueError, match="split_key must be provided when split_fraction is set"):
LMDatasetConfig(split_fraction=0.8, split_key=None)


@pytest.mark.asyncio
async def test_empty_dataset():
"""Test splitting an empty dataset"""
data = []
ds = ListAsyncDataset(data, is_complete=True)

config = LMDatasetConfig(split_fraction=0.8, split_key=jax.random.PRNGKey(0))
config.token_seq_dataset = lambda split, seq_len, monitors: ds

train_ds = config.train_set(seq_len=1, key=jax.random.PRNGKey(1))
val_ds = config.validation_set(seq_len=1)

# Empty batch should raise ValueError
with pytest.raises(ValueError, match="Dataset is empty"):
await train_ds.get_batch([])
with pytest.raises(ValueError, match="Dataset is empty"):
await val_ds.get_batch([])

0 comments on commit e59c631

Please sign in to comment.