Multi-Dataset Validation (LM-Loss/Perplexity) #178

Merged: 22 commits, merged Mar 28, 2025
Changes from 9 commits
10 changes: 6 additions & 4 deletions docs/quick-start.md
Original file line number Diff line number Diff line change
@@ -493,8 +493,9 @@ Save the following as `fast-llm-tutorial/train-config.yaml`:
logs:
interval: 10
validation:
iterations: 25
interval: 100
Validation:
iterations: 25
interval: 100
export: # (2)!
format: llama
interval: 100
@@ -550,8 +551,9 @@ Save the following as `fast-llm-tutorial/train-config.yaml`:
logs:
interval: 10
validation:
iterations: 25
interval: 1000
Validation:
iterations: 25
interval: 1000
checkpoint:
interval: 1000
keep: 5
10 changes: 6 additions & 4 deletions docs/recipes/continue-training.md
@@ -34,8 +34,9 @@ This is not much different from a pretraining config. We will:
logs:
interval: 10
validation:
iterations: 25
interval: 1000
Validation:
iterations: 25
interval: 1000
checkpoint:
interval: 1000
keep: 5
@@ -84,8 +85,9 @@ This is not much different from a pretraining config. We will:
logs:
interval: 10
validation:
iterations: 25
interval: 1000
Validation:
iterations: 25
interval: 1000
checkpoint:
interval: 1000
keep: 5
39 changes: 38 additions & 1 deletion docs/recipes/data-configuration.md
@@ -25,6 +25,16 @@ We already saw an example dataset configuration in the [quick-start guide](../qu

In this section we are interested in generalizing step 3. For more details on steps 1 and 2, please refer to the quick-start guide or [this example](data-configuration.md).

The `data.datasets` section holds the descriptions of the datasets used for training, validation, and testing.

The training and testing phases must use the predetermined dataset names `Training` and `Testing`, respectively, and each of these phases can have only one dataset.

For validation datasets, the rules are different: there can be as many validation datasets as needed, and their names are arbitrary. In the example above, the name `Validation` is chosen for simplicity. The dataset names used for validation, along with their application details, are specified in the `training.validation` section of the training config.
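Concretely, these naming rules can be sketched as follows (a minimal illustration only; the `file` type and all paths are placeholders, not actual files):

```yaml
data:
  datasets:
    # Exactly one training dataset, under the fixed name `Training`.
    Training:
      type: file
      path: path/to/training_dataset.yaml    # placeholder path
    # Exactly one testing dataset, under the fixed name `Testing`.
    Testing:
      type: file
      path: path/to/testing_dataset.yaml     # placeholder path
    # Any number of validation datasets, with arbitrary names.
    Validation:
      type: file
      path: path/to/validation_dataset.yaml  # placeholder path
```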

Adding multiple validation datasets gives you more flexibility in tracking the quality of your trained model. One possible scenario is using a separate validation dataset for each blended training dataset, allowing you to track training progress on each subset separately and observe in real time how the model performs on different parts of your training data.

Below are examples of how to configure various aspects of training and validation datasets.

## Example 1: Blending multiple datasets

In this example, we have three datasets and want to sample from each of them during training with probabilities 0.70, 0.25 and 0.05. For this, we use the `blended` type which takes other datasets as arguments:
@@ -118,7 +128,34 @@ data:
!!! note "Default seed"
In the absence of explicit seed, Fast-LLM uses a default seed (`data.sampling`'s default) instead, and uses seed shifts to ensure different seeds for each phase and for the various blended datasets.
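As an illustration of the seed-shift idea described in the note (a sketch under assumed conventions, not Fast-LLM's actual implementation), distinct reproducible seeds can be derived from a single base seed like this:

```python
# Illustrative sketch only: deriving distinct, reproducible seeds for each
# phase and each blended dataset from one base seed via fixed shifts.
BASE_SEED = 784569  # hypothetical default for `data.sampling`'s seed


def shifted_seed(base_seed: int, phase_shift: int, dataset_index: int) -> int:
    # A fixed offset per phase plus an offset per dataset keeps the seeds
    # distinct while remaining a pure function of the base seed.
    return base_seed + phase_shift + dataset_index


# Three blended datasets, two phases (hypothetical phase shifts 0 and 1000).
training_seeds = [shifted_seed(BASE_SEED, 0, i) for i in range(3)]
validation_seeds = [shifted_seed(BASE_SEED, 1000, i) for i in range(3)]
# All six seeds are distinct and fully determined by BASE_SEED.
assert len(set(training_seeds + validation_seeds)) == 6
```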

## Example 5: Advanced scenario

## Example 5: Specifying Multiple Validation Datasets

This example shows how to specify multiple validation datasets and how to configure, in the `training.validation` section, how often and for how many iterations each one is applied.

Note that the dataset names in `training.validation` must match those in `data.datasets`. If a validation dataset is specified in the `datasets` section but not in `training.validation`, it will not be used for validation.

```yaml
training:
  validation:
    the_stack:
      iterations: 25
      interval: 50
    fineweb:
      iterations: 25
      interval: 100
data:
  datasets:
    the_stack:
      type: file
      path: path/to/validation_the_stack_dataset.yaml
    fineweb:
      type: file
      path: path/to/validation_fineweb_dataset.yaml
```

## Example 6: Advanced scenario

In this example, we combine everything we learned so far to create a complex scenario, where:

10 changes: 6 additions & 4 deletions docs/recipes/train.md
@@ -20,8 +20,9 @@ Let's start from the following training configuration:
logs:
interval: 10
validation:
iterations: 25
interval: 1000
Validation:
iterations: 25
interval: 1000
checkpoint:
interval: 1000
keep: 5
@@ -64,8 +65,9 @@ Let's start from the following training configuration:
logs:
interval: 10
validation:
iterations: 25
interval: 1000
Validation:
iterations: 25
interval: 1000
checkpoint:
interval: 1000
keep: 5
3 changes: 2 additions & 1 deletion examples/mistral.yaml
@@ -4,7 +4,8 @@ training:
logs:
interval: 10
validation:
iterations: null
Validation:
iterations: null
test_iters: 0
batch:
sequence_length: 4096
8 changes: 4 additions & 4 deletions fast_llm/data/data/abstract.py
@@ -13,7 +13,7 @@

class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
_distributed: "Distributed"
_samples_per_phase: dict[PhaseType, int]
_samples_per_dataset: dict[str, int]
_cache_directory: pathlib.Path | None

def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
@@ -24,12 +24,12 @@ def __init__(self, config: DataConfig, distributed_config: DistributedConfig) ->
def setup(
self,
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
samples_per_dataset: dict[str, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._samples_per_phase = samples_per_phase
self._samples_per_dataset = samples_per_dataset
self._cache_directory = cache_directory

@property
@@ -40,7 +40,7 @@ def distributed(self):
def get_iterator(
self,
batch_config: BatchConfig,
phase: PhaseType,
dataset_name: str,
*,
consumed_samples: int,
num_workers: int,
4 changes: 2 additions & 2 deletions fast_llm/data/data/gpt/config.py
@@ -39,7 +39,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
hint=FieldHint.feature,
)
# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[PhaseType, GPTSampledDatasetConfig] = Field(
datasets: dict[str, GPTSampledDatasetConfig] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
@@ -63,7 +63,7 @@ def _validate(self) -> None:
"Using the legacy dataset definition format." " Specify it through `data.datasets` instead."
)
self.datasets = {
phase: GPTLegacyDatasetConfig.from_dict(self, strict=False)
phase.value: GPTLegacyDatasetConfig.from_dict(self, strict=False)
for phase in (PhaseType.training, PhaseType.validation, PhaseType.test)
}
super()._validate()
26 changes: 13 additions & 13 deletions fast_llm/data/data/gpt/data.py
@@ -80,15 +80,15 @@ def __init__(
def setup(
self,
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
samples_per_dataset: dict[str, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
This may take a while and a significant amount of cpu memory.
"""
super().setup(distributed, samples_per_phase, cache_directory)
super().setup(distributed, samples_per_dataset, cache_directory)
log_main_rank(f"Preparing dataset. This may take several minutes.")
self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

@@ -97,23 +97,23 @@ def setup(
warnings.warn(f"Using the dataset directory for the index cache.")

self._datasets = {}
for phase, num_samples in samples_per_phase.items():
for dataset_name, num_samples in samples_per_dataset.items():
if num_samples > 0:
# TODO: Do the check earlier.
assert phase in self._config.datasets
assert dataset_name in self._config.datasets
sampling = GPTSamplingData(
num_samples=samples_per_phase[phase],
num_samples=samples_per_dataset[dataset_name],
config=self._config.sampling,
cache_directory=self._cache_directory,
distributed=distributed,
phase=phase,
dataset_name=dataset_name,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
cross_document_attention=self._cross_document_attention,
)
dataset = self._config.datasets[phase].build_and_sample(sampling)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True
@@ -126,21 +126,21 @@ def tokenizer(self) -> Tokenizer:
def get_iterator(
self,
batch_config: BatchConfig,
phase: PhaseType,
dataset_name: str,
*,
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
) -> typing.Iterator[typing.Any]:
assert self._is_setup
Assert.incl(phase, self._datasets)
Assert.incl(dataset_name, self._datasets)
Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
log_main_rank(f"Initializing {phase} data iterator from sample {consumed_samples}...")
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
return iter(
torch.utils.data.DataLoader(
self._datasets[phase], # noqa
self._datasets[dataset_name], # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._datasets[phase]),
total_samples=len(self._datasets[dataset_name]),
begin_index=consumed_samples,
micro_batch_size=batch_config.micro_batch_size,
data_rank=self._distributed.config.batch_data_rank,
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/config.py
@@ -40,7 +40,7 @@ class SamplingData:
cache_directory: pathlib.Path | None
# TODO: This prevents the sampling config from being pickled in multiprocessing.
distributed: "Distributed"
phase: PhaseType
dataset_name: str
# Using a mutable rather than an int so it's shared with all copies made with `update`.
_rank_counter: typing.Iterator[int] = itertools.count

9 changes: 5 additions & 4 deletions fast_llm/data/dataset/gpt/config.py
@@ -472,11 +472,12 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
raise NotImplementedError(self.format)

phase_splits = padded_cumsum(normalize_probabilities(self.split))

phase_index = {
PhaseType.training: 0,
PhaseType.validation: 1,
PhaseType.test: 2,
}[sampling.phase]
PhaseType.training.value: 0,
PhaseType.validation.value: 1,
PhaseType.test.value: 2,
}[sampling.dataset_name]

dataset_configs = [
{
6 changes: 3 additions & 3 deletions fast_llm/engine/training/config.py
@@ -267,9 +267,9 @@ class ShutdownConfig(IntervalConfig):

@config_class()
class TrainingConfig(Config):
validation: ValidationConfig = Field(
default_factory=ValidationConfig,
desc="Configuration for the validation phase",
validation: dict[str, ValidationConfig] = Field(
default_factory=dict,
desc="A dictionary of validation dataset names and their configurations for the validation phase.",
hint=FieldHint.core,
)
logs: MetricsLogsConfig = Field(