From 901723f58ba5cbd8201a3fbbcc21550857159ff9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Jun 2025 12:01:04 -0700 Subject: [PATCH 01/48] rfc --- torchtune/datasets/rfc_iterable_dataset.md | 404 +++++++++++++++++++++ 1 file changed, 404 insertions(+) create mode 100644 torchtune/datasets/rfc_iterable_dataset.md diff --git a/torchtune/datasets/rfc_iterable_dataset.md b/torchtune/datasets/rfc_iterable_dataset.md new file mode 100644 index 0000000000..a8efd553e6 --- /dev/null +++ b/torchtune/datasets/rfc_iterable_dataset.md @@ -0,0 +1,404 @@ +### Core issues: + 1) No support for iterative dataset: + - Dataset has to be fully loaded in memory; + - With map-style, no control over multi-sample operations, e.g. packing or skipping + - map-style is slower + - no support for streaming + 2) No support for weighted dataset: + - We have it in a single newly added dev recipe/config, but API needs polishing; + - We also support ConcatDataset, but its map style and there is no weighting; + 3) No support for on-the-fly data packing: It's done before training, taking a long time for large datasets; + +### UX issues: + 4) Unclear boundaries between HF and torchtune args + + ```python + def alpaca_dataset( + # --- torchtune specific args --- + tokenizer: ModelTokenizer, + train_on_input: bool = True, + packed: bool = False, + # --- HF loading args --- + source: str = "tatsu-lab/alpaca", + column_map: Optional[Dict[str, str]] = None, + split: str = "train", + **load_dataset_kwargs: Dict[str, Any], + # --- HF dataset method --- + filter_fn: Optional[Callable] = None, + ) -> Union[SFTDataset, PackedDataset]: + ``` + + 5) Lack of dataloader args: args are scattered in the config. Important args are not exposed, e.g. num_workers, pin_memory, etc. + ```yaml + dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + seed: null + batch_size: 8 + shuffle: True + collate_fn: torchtune.data.padded_collate_tiled_images_and_mask + ``` + + 6) Different datasets have different arguments, because their message transforms are different. + +### Principles: + - Common API signatures for all datasets + - Offload what we can to hf datasets methods directly + - Less protagonism from our functions. E.g. config manipulations, instantiation, etc. 
(not the focus of this diff) + +### Proposal: + +# config.yaml +```yaml + +########### +# tokenizer +########### +tokenizer: + _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform + path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model + image_size: 560 + max_seq_len: 8192 + +########## +# dataloader +# consolidate all dataloader args here, which are currently scattered +########## +dataloader: + _component_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 4 + num_workers: 4 + pin_memory: true + collate_fn: torchtune.data.padded_collate + + +#-------------------------------- +######### +# dataset if the class is used directly, as in our current SFTDataset +######### +dataset: + - _component_: torchtune.datasets.HfIterableDataset + load_args: + path: "tatsu-lab/alpaca" + split: "train" + message_transform: + _component_: torchtune.datasets.alpaca_message_transform + masking_strategy: "output_only" + column_map: + input: "prompt" + output: "response" + system_prompt: "foo" + filter_args: + function: torchtune.datasets.filter_fn_even_indices + with_indices: True + weight: 0.8 + - _component_: torchtune.datasets.HfIterableDataset + load_args: + path: "tatsu-lab/gsm8k" + split: "train" + message_transform: + _component_: torchtune.datasets.gsm8k_message_transform + masking_strategy: "output_only" + column_map: + input: "prompt" + output: "response" + system_prompt: "bar" + weight: 0.2 + +######### +# OR with builders +# TODO: test indexing "`tune run config – dataset[0].load_arg.split=train`" +######### +dataset: + - _component_: torchtune.datasets.build_alpaca_dataset + load_args: + split: "valid" + weight: 0.8 + - _component_: torchtune.datasets.build_gsm8k_dataset + message_transform: + system_prompt: "bar" + weight: 0.2 + +######### +# OR for a single dataset +######### +dataset: + _component_: torchtune.datasets.build_alpaca_dataset +#-------------------------------- + +######### +# Place for args common for all datasets that will be passed to the dataset constructor +# useful for multidataset. Used as cfg = dataset_defaults.update(dataset_cfg) +######### +dataset_defaults: + shuffle_buffer_size: 1000 + num_shards_per_worker: 16 + seed: ${seed} + tokenizer: ${tokenizer} + recipe_transform: + _component_: torchtune.datasets.SFTTransform + +######### +# args used in the dataset setup. This is not dataset specific. 
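+# (e.g. the on-the-fly packing wrapper and the multidataset stopping strategy configured below)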
+######### +dataset_setup: + packing: + _component_: torchtune.datasets.packing.SFTPacking + max_seq_len: ${tokenizer.max_seq_len} + multidataset_stopping_strategy: "first_exhausted" # "all_exhausted" +``` + +# Builder example: torchtune/datasets/alpaca_dataset.py + +```python +def alpaca_dataset( + *, + load_args: optional[dict], + message_transform: optional[callable|dict], + tokenizer: ModelTokenizer, + recipe_transform: callable, + *args, **kwargs + ): + _load_args = { + source: "tatsu-lab/alpaca", + split: str = "train" + } + _message_transform_args = { + "train_on_input":False, + "column_map"={"input": "prompt", "output": "response"} + } + + # unify args + if load_args: + _load_args.update(**load_args) + + # unify args + if not message_transform or isinstance(message_transform, dict): + # remove component key, since we are using alpaca_message_transform as default + message_transform.pop("_component_", None) + + # instantiate the message transform + _message_transform_args.update(message_transform) + message_transform = alpaca_message_transform(**_message_transform_args) + + return HfIterableDataset(load_args, message_transform, tokenizer, recipe_transform, *args, **kwargs) +``` + +# Iterable dataset: Shared for all datasets and recipes (SFT, DPO, etc). Differences are in the transforms. +Location: torchtune/datasets/hf_iterable_dataset.py + +```python +class HfIterableDataset(IterableDataset, Stateful): + def __init__( + self, + *. + load_args: Dict, + message_transform: Callable, + tokenizer: Callable, + recipe_transform: Callable, + shuffle_buffer_size:Optional[int] = 1000, + seed:Optional[int] = 42 + num_shards_per_worker: int = 16, + weight:float = 1.0, + filter_args: Optional[Dict] = None, + *args, **kwargs + ): + """Initialize a single dataset with its specific transformations.""" + self.weight = weight + + world_size = 1 + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + + #TODO: Maybe # shards should be based on dataset size, if we know it + num_shards = world_size * num_shards_per_worker + ds = load_dataset(**load_args) + ds = ds.to_iterable_dataset(num_shards) + + if filter_args: + function = filter_args.get("function", None) + if function and not isinstance(function, Callable): + raise ValueError(f"filter_args['function'] must be a callable. Found {type(function)}") + # https://huggingface.co/docs/datasets/v3.6.0/en/stream#filter + ds = ds.filter(**filter_args) + + def _apply_transforms(sample): + sample = message_transform(sample) + sample = tokenizer(sample) + return recipe_transform(sample) + + ds = ds.map(_apply_transforms) #lazy + + if shuffle_buffer_size and shuffle_buffer_size > 0: + ds = ds.shuffle(shuffle_buffer_size, seed) + + # distribute + if world_size>1: + ds = split_dataset_by_node( + ds, + rank=torch.distributed.get_rank(), + world_size=world_size, + ) + + self.ds = ds + + def __iter__(self): + # Expose the for loop so extra logic can be added here, e.g. drop if no trainable tokens + # TODO: should we add try/except to handle/logerrors? 
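+        # For example (hypothetical sketch, assuming labels use CROSS_ENTROPY_IGNORE_IDX
+        # for non-trainable positions), the loop below could skip empty samples with:
+        #     if all(label == CROSS_ENTROPY_IGNORE_IDX for label in sample["labels"]):
+        #         continue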
+ for sample in self.ds: + yield sample + + def state_dict(self): + state_dict = self.ds.state_dict() + state_dict["weight"] = self.weight + return state_dict + + def load_state_dict(self, state_dict): + self.weight = state_dict.pop("weight") + self.ds.load_state_dict(state_dict) +``` + +# Setup Data +Method in recipes/full_distributed.py +OR utility used in the recipe + +```python +from datasets import interleave_datasets, split_dataset_by_node +from torchtune.models.tokenizers import ModelTokenizer +import torch + +#NOTE: I have mixed feelings about passing multiple configDict to setup_data. This feels hard for the user to know what they should contain. On the other hand, i) setup_data doesnt need to make assumptions about the configs ii) we already do it currently. Alternative: use dataclassses + +def setup_data( + dataset_cfg: ConfigDict, + dataset_defaults: ConfigDict, + data_setup_cfg: ConfigDict, + dataloader_cfg: ConfigDict, + seed: int, + pad_idx: int, + ignore_idx: int, + pad_to_multiple_of: int, + ) -> "IterableDataset": + """ + Equivalent to setup_data in the recipe + """ + iterable_datasets = [] + weights = [] + dataset_defaults = {} if dataset_defaults is None else dataset_defaults + + # add to a list just for processing + if not isinstance(dataset_cfg, list): + dataset_cfg = [dataset_cfg] + + for base_cfg in dataset_cfg: + weight = base_cfg.get("weight", 1.0) + weights.append(weight) + + base_cfg = OmegaConf.merge(dataset_defaults, base_cfg) + ds = instantiate(base_cfg) + iterable_datasets.append(ds) + + + # interleave for multidataset + if len(iterable_datasets) > 1: + weights = normalize_weights(weights) # sum to 1 + ds = interleave_datasets( + iterable_datasets, + probabilities=weights, + seed=seed, + # strategies: https://huggingface.co/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy + stopping_strategy=data_setup_cfg.multidataset_stopping_strategy, + ) + else: + ds = iterable_datasets[0] + + # FIXME: remove from config + if setup_cfg.packing: + # Subclass of IterableDataset, takes any iterator as input + ds = instantiate(data_setup_cfg.packing, + dataset=ds, + padding_idx=pad_id, #TODO: in the future, move padding to collate_fn + ) + + # Instantiate collate_fn + collate_fn = dataloader_cfg.pop("collate_fn", None) + #TODO: in the future, unify those two + if collate_fn is None: + collate_fn = "torchtune.data.padded_collate_packed" if packing else "torchtune.data.padded_collate_sft" + + collate_fn = _get_component_from_path(collate_fn) + collate_fn = partial(collate_fn, + padding_idx=pad_idx, + ignore_idx=ignore_id, + pad_to_multiple_of=pad_to_multiple_of + ) + + # dropping last avoids shape issues with compile + flex attention + if "drop_last" not in dataloader_cfg: + dataloader_cfg["drop_last"] = True + + dataloader = instantiate(dataloader_cfg, dataset=ds, collate_fn=collate_fn) + + return dataloader +``` + +# Recipe train loop +```python +for epoch in range(n_epochs): + my_iterable_dataset.set_epoch(epoch) + for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch` + pass +``` + +### Backward compatibility + +Options: +1. Make setup_data an utility, and have two utilities supporting the old and new config formats. After deprecation period, old utility is removed. + +Pros: modularize it and remove from the recipe. Future changes will be easier to implement. +Cons: Big change in how we handle recipe utilities. + +2. 
Create an adapter migrate_old_to_new_config: +Pros: Recipes still have method _setup_data exposing the logic +Cons: Hard to debug the migrated configs, edge cases not covered by the adapter, ConcatDataset is handled differently. + +3. No migration. Old config with old recipe will break. Users need to update +their configs. No idea how this affects llamastack / startups / others. + +**Implementation of option 1 (Make setup_data an utility)** + +# torchtune/training/data_utils.py or similar location +```python +@deprecated +def is_legacy_data_config(cfg: DictConfig) -> bool: + """ + Detect if config follows legacy format vs new iterable dataset format. + """ + # Check for new format indicators first + has_dataloader_section = "dataloader" in cfg + has_dataset_defaults = "dataset_defaults" in cfg + has_dataset_setup = "dataset_setup" in cfg + + return not (has_dataloader_section or has_dataset_defaults or has_dataset_setup) + +@deprecated +def setup_data_legacy( + ... +) -> StatefulDataLoader: + """ + Legacy data setup function to maintain backward compatibility. + This replicates the current behavior in full_finetune_distributed.py + """ + # same as current setup_data in the recipe.... + + return dataloader +``` + +In the recipe: +```python +def _setup(...): + ... + if is_legacy_data_config(cfg): + dataloader = setup_data_legacy(...) + else: + dataloader = setup_data(...) +``` From ab02d75ad65e8096787a8ec544fb7c5c90b2b61b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Jun 2025 12:03:01 -0700 Subject: [PATCH 02/48] Revert "rfc" This reverts commit 901723f58ba5cbd8201a3fbbcc21550857159ff9. --- torchtune/datasets/rfc_iterable_dataset.md | 404 --------------------- 1 file changed, 404 deletions(-) delete mode 100644 torchtune/datasets/rfc_iterable_dataset.md diff --git a/torchtune/datasets/rfc_iterable_dataset.md b/torchtune/datasets/rfc_iterable_dataset.md deleted file mode 100644 index a8efd553e6..0000000000 --- a/torchtune/datasets/rfc_iterable_dataset.md +++ /dev/null @@ -1,404 +0,0 @@ -### Core issues: - 1) No support for iterative dataset: - - Dataset has to be fully loaded in memory; - - With map-style, no control over multi-sample operations, e.g. packing or skipping - - map-style is slower - - no support for streaming - 2) No support for weighted dataset: - - We have it in a single newly added dev recipe/config, but API needs polishing; - - We also support ConcatDataset, but its map style and there is no weighting; - 3) No support for on-the-fly data packing: It's done before training, taking a long time for large datasets; - -### UX issues: - 4) Unclear boundaries between HF and torchtune args - - ```python - def alpaca_dataset( - # --- torchtune specific args --- - tokenizer: ModelTokenizer, - train_on_input: bool = True, - packed: bool = False, - # --- HF loading args --- - source: str = "tatsu-lab/alpaca", - column_map: Optional[Dict[str, str]] = None, - split: str = "train", - **load_dataset_kwargs: Dict[str, Any], - # --- HF dataset method --- - filter_fn: Optional[Callable] = None, - ) -> Union[SFTDataset, PackedDataset]: - ``` - - 5) Lack of dataloader args: args are scattered in the config. Important args are not exposed, e.g. num_workers, pin_memory, etc. - ```yaml - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset - seed: null - batch_size: 8 - shuffle: True - collate_fn: torchtune.data.padded_collate_tiled_images_and_mask - ``` - - 6) Different datasets have different arguments, because their message transforms are different. 
- -### Principles: - - Common API signatures for all datasets - - Offload what we can to hf datasets methods directly - - Less protagonism from our functions. E.g. config manipulations, instantiation, etc. (not the focus of this diff) - -### Proposal: - -# config.yaml -```yaml - -########### -# tokenizer -########### -tokenizer: - _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform - path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model - image_size: 560 - max_seq_len: 8192 - -########## -# dataloader -# consolidate all dataloader args here, which are currently scattered -########## -dataloader: - _component_: torchdata.stateful_dataloader.StatefulDataLoader - batch_size: 4 - num_workers: 4 - pin_memory: true - collate_fn: torchtune.data.padded_collate - - -#-------------------------------- -######### -# dataset if the class is used directly, as in our current SFTDataset -######### -dataset: - - _component_: torchtune.datasets.HfIterableDataset - load_args: - path: "tatsu-lab/alpaca" - split: "train" - message_transform: - _component_: torchtune.datasets.alpaca_message_transform - masking_strategy: "output_only" - column_map: - input: "prompt" - output: "response" - system_prompt: "foo" - filter_args: - function: torchtune.datasets.filter_fn_even_indices - with_indices: True - weight: 0.8 - - _component_: torchtune.datasets.HfIterableDataset - load_args: - path: "tatsu-lab/gsm8k" - split: "train" - message_transform: - _component_: torchtune.datasets.gsm8k_message_transform - masking_strategy: "output_only" - column_map: - input: "prompt" - output: "response" - system_prompt: "bar" - weight: 0.2 - -######### -# OR with builders -# TODO: test indexing "`tune run config – dataset[0].load_arg.split=train`" -######### -dataset: - - _component_: torchtune.datasets.build_alpaca_dataset - load_args: - split: "valid" - weight: 0.8 - - _component_: torchtune.datasets.build_gsm8k_dataset - message_transform: - system_prompt: "bar" - weight: 0.2 - -######### -# OR for a single dataset -######### -dataset: - _component_: torchtune.datasets.build_alpaca_dataset -#-------------------------------- - -######### -# Place for args common for all datasets that will be passed to the dataset constructor -# useful for multidataset. Used as cfg = dataset_defaults.update(dataset_cfg) -######### -dataset_defaults: - shuffle_buffer_size: 1000 - num_shards_per_worker: 16 - seed: ${seed} - tokenizer: ${tokenizer} - recipe_transform: - _component_: torchtune.datasets.SFTTransform - -######### -# args used in the dataset setup. This is not dataset specific. 
-######### -dataset_setup: - packing: - _component_: torchtune.datasets.packing.SFTPacking - max_seq_len: ${tokenizer.max_seq_len} - multidataset_stopping_strategy: "first_exhausted" # "all_exhausted" -``` - -# Builder example: torchtune/datasets/alpaca_dataset.py - -```python -def alpaca_dataset( - *, - load_args: optional[dict], - message_transform: optional[callable|dict], - tokenizer: ModelTokenizer, - recipe_transform: callable, - *args, **kwargs - ): - _load_args = { - source: "tatsu-lab/alpaca", - split: str = "train" - } - _message_transform_args = { - "train_on_input":False, - "column_map"={"input": "prompt", "output": "response"} - } - - # unify args - if load_args: - _load_args.update(**load_args) - - # unify args - if not message_transform or isinstance(message_transform, dict): - # remove component key, since we are using alpaca_message_transform as default - message_transform.pop("_component_", None) - - # instantiate the message transform - _message_transform_args.update(message_transform) - message_transform = alpaca_message_transform(**_message_transform_args) - - return HfIterableDataset(load_args, message_transform, tokenizer, recipe_transform, *args, **kwargs) -``` - -# Iterable dataset: Shared for all datasets and recipes (SFT, DPO, etc). Differences are in the transforms. -Location: torchtune/datasets/hf_iterable_dataset.py - -```python -class HfIterableDataset(IterableDataset, Stateful): - def __init__( - self, - *. - load_args: Dict, - message_transform: Callable, - tokenizer: Callable, - recipe_transform: Callable, - shuffle_buffer_size:Optional[int] = 1000, - seed:Optional[int] = 42 - num_shards_per_worker: int = 16, - weight:float = 1.0, - filter_args: Optional[Dict] = None, - *args, **kwargs - ): - """Initialize a single dataset with its specific transformations.""" - self.weight = weight - - world_size = 1 - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - - #TODO: Maybe # shards should be based on dataset size, if we know it - num_shards = world_size * num_shards_per_worker - ds = load_dataset(**load_args) - ds = ds.to_iterable_dataset(num_shards) - - if filter_args: - function = filter_args.get("function", None) - if function and not isinstance(function, Callable): - raise ValueError(f"filter_args['function'] must be a callable. Found {type(function)}") - # https://huggingface.co/docs/datasets/v3.6.0/en/stream#filter - ds = ds.filter(**filter_args) - - def _apply_transforms(sample): - sample = message_transform(sample) - sample = tokenizer(sample) - return recipe_transform(sample) - - ds = ds.map(_apply_transforms) #lazy - - if shuffle_buffer_size and shuffle_buffer_size > 0: - ds = ds.shuffle(shuffle_buffer_size, seed) - - # distribute - if world_size>1: - ds = split_dataset_by_node( - ds, - rank=torch.distributed.get_rank(), - world_size=world_size, - ) - - self.ds = ds - - def __iter__(self): - # Expose the for loop so extra logic can be added here, e.g. drop if no trainable tokens - # TODO: should we add try/except to handle/logerrors? 
- for sample in self.ds: - yield sample - - def state_dict(self): - state_dict = self.ds.state_dict() - state_dict["weight"] = self.weight - return state_dict - - def load_state_dict(self, state_dict): - self.weight = state_dict.pop("weight") - self.ds.load_state_dict(state_dict) -``` - -# Setup Data -Method in recipes/full_distributed.py -OR utility used in the recipe - -```python -from datasets import interleave_datasets, split_dataset_by_node -from torchtune.models.tokenizers import ModelTokenizer -import torch - -#NOTE: I have mixed feelings about passing multiple configDict to setup_data. This feels hard for the user to know what they should contain. On the other hand, i) setup_data doesnt need to make assumptions about the configs ii) we already do it currently. Alternative: use dataclassses - -def setup_data( - dataset_cfg: ConfigDict, - dataset_defaults: ConfigDict, - data_setup_cfg: ConfigDict, - dataloader_cfg: ConfigDict, - seed: int, - pad_idx: int, - ignore_idx: int, - pad_to_multiple_of: int, - ) -> "IterableDataset": - """ - Equivalent to setup_data in the recipe - """ - iterable_datasets = [] - weights = [] - dataset_defaults = {} if dataset_defaults is None else dataset_defaults - - # add to a list just for processing - if not isinstance(dataset_cfg, list): - dataset_cfg = [dataset_cfg] - - for base_cfg in dataset_cfg: - weight = base_cfg.get("weight", 1.0) - weights.append(weight) - - base_cfg = OmegaConf.merge(dataset_defaults, base_cfg) - ds = instantiate(base_cfg) - iterable_datasets.append(ds) - - - # interleave for multidataset - if len(iterable_datasets) > 1: - weights = normalize_weights(weights) # sum to 1 - ds = interleave_datasets( - iterable_datasets, - probabilities=weights, - seed=seed, - # strategies: https://huggingface.co/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy - stopping_strategy=data_setup_cfg.multidataset_stopping_strategy, - ) - else: - ds = iterable_datasets[0] - - # FIXME: remove from config - if setup_cfg.packing: - # Subclass of IterableDataset, takes any iterator as input - ds = instantiate(data_setup_cfg.packing, - dataset=ds, - padding_idx=pad_id, #TODO: in the future, move padding to collate_fn - ) - - # Instantiate collate_fn - collate_fn = dataloader_cfg.pop("collate_fn", None) - #TODO: in the future, unify those two - if collate_fn is None: - collate_fn = "torchtune.data.padded_collate_packed" if packing else "torchtune.data.padded_collate_sft" - - collate_fn = _get_component_from_path(collate_fn) - collate_fn = partial(collate_fn, - padding_idx=pad_idx, - ignore_idx=ignore_id, - pad_to_multiple_of=pad_to_multiple_of - ) - - # dropping last avoids shape issues with compile + flex attention - if "drop_last" not in dataloader_cfg: - dataloader_cfg["drop_last"] = True - - dataloader = instantiate(dataloader_cfg, dataset=ds, collate_fn=collate_fn) - - return dataloader -``` - -# Recipe train loop -```python -for epoch in range(n_epochs): - my_iterable_dataset.set_epoch(epoch) - for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch` - pass -``` - -### Backward compatibility - -Options: -1. Make setup_data an utility, and have two utilities supporting the old and new config formats. After deprecation period, old utility is removed. - -Pros: modularize it and remove from the recipe. Future changes will be easier to implement. -Cons: Big change in how we handle recipe utilities. - -2. 
Create an adapter migrate_old_to_new_config: -Pros: Recipes still have method _setup_data exposing the logic -Cons: Hard to debug the migrated configs, edge cases not covered by the adapter, ConcatDataset is handled differently. - -3. No migration. Old config with old recipe will break. Users need to update -their configs. No idea how this affects llamastack / startups / others. - -**Implementation of option 1 (Make setup_data an utility)** - -# torchtune/training/data_utils.py or similar location -```python -@deprecated -def is_legacy_data_config(cfg: DictConfig) -> bool: - """ - Detect if config follows legacy format vs new iterable dataset format. - """ - # Check for new format indicators first - has_dataloader_section = "dataloader" in cfg - has_dataset_defaults = "dataset_defaults" in cfg - has_dataset_setup = "dataset_setup" in cfg - - return not (has_dataloader_section or has_dataset_defaults or has_dataset_setup) - -@deprecated -def setup_data_legacy( - ... -) -> StatefulDataLoader: - """ - Legacy data setup function to maintain backward compatibility. - This replicates the current behavior in full_finetune_distributed.py - """ - # same as current setup_data in the recipe.... - - return dataloader -``` - -In the recipe: -```python -def _setup(...): - ... - if is_legacy_data_config(cfg): - dataloader = setup_data_legacy(...) - else: - dataloader = setup_data(...) -``` From 2a2efa298ff4b91308e918d08cc2b2f9a958bef9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 14:23:00 -0400 Subject: [PATCH 03/48] add packed functions --- torchtune/data/_collate.py | 28 +- torchtune/datasets/_iterable_packed.py | 440 +++++++++++++++++++++++++ 2 files changed, 467 insertions(+), 1 deletion(-) create mode 100644 torchtune/datasets/_iterable_packed.py diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 33aec33dde..015eb2761c 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional, Union +from typing import Any, Optional, Union, List, Dict import torch import torch.nn.functional as F @@ -12,6 +12,32 @@ from torchtune.modules.attention_utils import packed_block_causal_mask +def collate_packed( + batch: list[dict[str, torch.Tensor]], mask_fn: callable, device: str +) -> dict[str, torch.Tensor]: + """ + Generic collate function for packed samples from an IterablePackedDataset. + + This function handles tensor stacking and delegates attention mask creation + to a provided `mask_fn`. + """ + if not batch: + return {} + + # Assumes all samples in the batch have the same keys, which are all tensors. + keys_to_stack = batch[0].keys() + collated = { + key: torch.stack([sample[key] for sample in batch], dim=0) + for key in keys_to_stack + } + + # Delegate mask creation to the provided specialized function + # TODO: investigate the need for device here. Currently we hardcode it in utilities to cuda. + # shouldnt we just send to device later? 
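+    # In this PR, `mask_fn` is expected to be a strategy's `create_block_mask` (e.g.
+    # `TextPackingStrategy.create_block_mask`, bound together with `device` via
+    # `functools.partial` in the recipe), so the call below yields a FlexAttention
+    # BlockMask built from the packed `document_ids`.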
+ collated["mask"] = mask_fn(collated["document_ids"], device=device) + + return collated + def left_pad_sequence( sequences: list[torch.Tensor], batch_first: bool = False, diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py new file mode 100644 index 0000000000..ba535ee6ba --- /dev/null +++ b/torchtune/datasets/_iterable_packed.py @@ -0,0 +1,440 @@ + +import logging +from abc import ABC, abstractmethod +from collections import deque +from functools import partial +from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, TypeVar + +import torch +from torch.nn.attention.flex_attention import ( + create_block_mask as create_block_mask_flex, +) +from torch.utils.data import IterableDataset, Stateful +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION + + +logger = logging.getLogger(__name__) + +SampleType = TypeVar("SampleType") +PackType = Dict[str, torch.Tensor] + + +class PackingStrategy(ABC, Generic[SampleType]): + def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): + """ + Initializes a strategy to be used in IterablePackedDataset. All strategies are meant to be used + with FlexAttention by leveraging 'mask_mod' to create the attention mask. + """ + if not _SUPPORTS_FLEX_ATTENTION: + raise RuntimeError( + "The IterablePackedDataset and its strategies require Flex Attention support, " + "which is not available in the current environment." + ) + self.padding_idx = padding_idx + self.ignore_idx = ignore_idx + + @abstractmethod + def create_empty_pack(self) -> Dict[str, List[Any]]: + """ + Creates an empty pack. + Example: + self.create_empty_pack() + >>> {"tokens": [], "labels": [], "document_ids": [], "input_pos": []} + """ + pass + + @abstractmethod + def get_sample_size(self, sample: SampleType) -> int: + """ + Returns the size of a sample. + Example: + # for a sample with 100 tokens + self.get_sample_size(sample) + >>> 100 + """ + pass + + @abstractmethod + def add_sample_to_pack( + self, pack: Dict[str, List[Any]], sample: SampleType, next_doc_id: int + ) -> int: + """ + Adds a sample to the pack dictionary in-place. + + Args: + pack (Dict[str, List[Any]]): The dictionary representing the pack, to be modified in-place. + sample (SampleType): The sample to add. + next_doc_id (int): The starting document ID to use for this sample. + + Returns: + int: The number of new documents that were added to the pack. + + Example: + pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} + sample = {"tokens": [5, 6], "labels": [7, 8], "document_ids": [1, 1], "input_pos": [0, 1]} + added_docs = self.add_sample_to_pack(pack, sample, next_doc_id=1) + print(pack) + >>> {"tokens": [1, 2, 5, 6], + "labels": [3, 4, 7, 8], + "document_ids": [0, 0, 1, 1], + "input_pos": [0, 1, 0, 1]} + print(added_docs) + >>> 1 + """ + pass + + @abstractmethod + def finalize_pack( + self, pack: Dict[str, List[Any]], target_tokens_per_pack: int, next_doc_id: int + ) -> PackType: + """ + Finalizes a pack, primarily by padding it to the target length. + + Args: + pack (Dict[str, List[Any]]): The pack data. + target_tokens_per_pack (int): The target length to pad to. + next_doc_id (int): The document ID to use for the padding tokens. + + Returns: + PackType: The finalized pack. 
+ + Example: + pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} + target_tokens_per_pack = 4 + next_doc_id = 1 + self.padding_idx = 999 + self.ignore_idx = -100 + + self.finalize_pack(pack, target_tokens_per_pack, next_doc_id) + >>> {"tokens": [1, 2, 999, 999], + "labels": [3, 4, -100, -100], + "document_ids": [0, 0, 1, 1], + "input_pos": [0, 1, 0, 1]} + """ + pass + + @abstractmethod + def _mask_mod( + self, + b: int, + h: int, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + doc_ids: torch.Tensor, + ) -> torch.Tensor: + """ + The core logic for the block attention mask, to be passed to + `torch.nn.attention.flex_attention.create_block_mask`. + + This method is implemented by each strategy to define the specific + attention pattern (e.g., standard causal, DPO, etc.). + + Args: + b (int): Batch index. + h (int): Head index. + q_idx (Tensor): Query indices. + kv_idx (Tensor): Key/value indices. + doc_ids (Tensor): The complete document ID tensor for the batch, + of shape (batch_size, seq_len). + + Returns: + A boolean tensor indicating which query/key pairs are allowed to attend. + """ + pass + + def create_block_mask(self, batch_document_ids, device): + """ + Creates a block-causal attention mask using FlexAttention. + """ + batch_size, seq_len = batch_document_ids.shape + doc_ids = batch_document_ids.to(device) + + # This wrapper is needed so we can unit-test the core `mask_mod` logic + # while still conforming to the function signature required by `create_block_mask_flex`. + def _mask_mod_for_flex(b, h, q_idx, kv_idx): + return self._mask_mod(b, h, q_idx, kv_idx, doc_ids) + + return create_block_mask_flex( + _mask_mod_for_flex, batch_size, None, seq_len, seq_len, device=device + ) + + +class IterablePackedDataset(IterableDataset[PackType], Stateful, Generic[SampleType]): + def __init__( + self, + dataset: IterableDataset[SampleType], + strategy: PackingStrategy[SampleType], + target_tokens_per_pack: int, + buffer_size: int = 50, + ): + """ + IterablePackedDataset takes any IterableDataset and a PackingStrategy, packs documents until + the 'target_tokens_per_pack' is reached and yields a dictionary of tensors. + + Args: + dataset (IterableDataset[SampleType]): The IterableDataset to pack. + strategy (PackingStrategy[SampleType]): The PackingStrategy to use for packing. + target_tokens_per_pack (int): The target number of tokens per pack. + buffer_size (int): The size of the buffer to use for packing. 
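+
+        Example:
+            # Illustrative usage; `my_iterable_ds` is a placeholder for any IterableDataset
+            # yielding dicts the strategy understands, e.g. {"tokens": [...], "labels": [...]}.
+            strategy = TextPackingStrategy(padding_idx=0)
+            packed_ds = IterablePackedDataset(
+                dataset=my_iterable_ds,
+                strategy=strategy,
+                target_tokens_per_pack=2048,
+            )
+            pack = next(iter(packed_ds))  # dict of 1-D tensors, each of length 2048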
+ """ + self.dataset = dataset + self.strategy = strategy + self.target_tokens_per_pack = target_tokens_per_pack + self.buffer_size = buffer_size + + self._reset_packer_state() + + def _reset_packer_state(self) -> None: + """Resets the packer's internal state for a new or resumed iteration.""" + # buffer: deque of (sample, size) tuples that have not been added to a pack yet + if not hasattr(self, "_buffer"): + self._buffer: deque[tuple[SampleType, int]] = deque() + else: + self._buffer.clear() + + # current_pack: the current pack being built + self._current_pack: Optional[dict[str, list]] = None + + # current_pack_size: the number of tokens in the current pack + self._current_pack_size: int = 0 + + # iterator: the iterator over the dataset + self._iterator: Optional[Iterator[SampleType]] = None + + # current_doc_id_in_pack: the document ID to use for the next sample + self._current_doc_id_in_pack: int = 0 + + # exhausted: whether the dataset is exhausted + self._exhausted: bool = False + + # resuming: whether the packer is resuming from a checkpoint + self._resuming: bool = False + + def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: + """ + Fills the buffer with samples from the dataset. + The buffer is a deque of (sample, size) tuples that have not been added to a pack yet. + """ + # Fill buffer until it's full or the dataset is exhausted + while len(self._buffer) < self.buffer_size and not self._exhausted: + try: + sample = next(iterator) + sample_size = self.strategy.get_sample_size(sample) + + # Drop samples that are too large + if sample_size > self.target_tokens_per_pack: + logger.warning( + f"Dropping sample with size {sample_size} > target_tokens_per_pack {self.target_tokens_per_pack}." + ) + else: + self._buffer.append((sample, sample_size)) + except StopIteration: + self._exhausted = True + + def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]: + """ + Find the first sample in the buffer that fits in the remaining space. + + Args: + remaining_size (int): The remaining space in the current pack. + + Returns: + Optional[int]: The index of the sample in the buffer, or None if no sample fits. + + Example: + self._buffer = deque([(sample1, 200), (sample2, 100), (sample3, 48), (sample4, 200)]) + + # First iteration: + selected_sample_idx = self._find_next_fitting_sample(remaining_size=150) # returns 1 + del self._buffer[selected_sample_idx] + + # Second iteration: + selected_sample_idx = self._find_next_fitting_sample(remaining_size=50) # returns 1 + del self._buffer[selected_sample_idx] + + # Third iteration: + selected_sample_idx = self._find_next_fitting_sample(remaining_size=2) # returns None + """ + # Find the first sample in the buffer that fits in the remaining space + for i, (_, sample_size) in enumerate(self._buffer): + if sample_size <= remaining_size: + return i + return None + + def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[PackType]: + """ + Builds a pack of samples from the buffer. + + Args: + iterator (Iterator[SampleType]): The iterator over the dataset. + + Returns: + Optional[PackType]: The pack of samples, or None if the dataset is exhausted. 
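+
+        Example:
+            # Illustrative walk-through with target_tokens_per_pack=100 and buffered
+            # sample sizes [60, 50, 30]: take 60 (40 left), skip 50, take 30 (10 left),
+            # nothing else fits, so finalize_pack pads the remaining 10 positions.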
+ """ + # Start a new pack if necessary + if self._current_pack is None: + self._current_pack = self.strategy.create_empty_pack() + self._current_pack_size = 0 + self._current_doc_id_in_pack = 0 + + # Fill the current pack until it's full + while self._current_pack_size < self.target_tokens_per_pack: + self._fill_buffer(iterator) + remaining_size = self.target_tokens_per_pack - self._current_pack_size + selected_sample_idx = self._find_next_fitting_sample(remaining_size) + + # If a fitting sample is found, del from buffer and add to the pack + if selected_sample_idx is not None: + sample, sample_size = self._buffer[selected_sample_idx] + del self._buffer[selected_sample_idx] + docs_consumed = self.strategy.add_sample_to_pack( + self._current_pack, sample, self._current_doc_id_in_pack + ) + self._current_doc_id_in_pack += docs_consumed + self._current_pack_size += sample_size + else: + # No fitting sample found, so break to finalize the pack + break + + # If the pack has any content, finalize and return it + if self._current_pack_size > 0: + final_pack = self.strategy.finalize_pack( + self._current_pack, + self.target_tokens_per_pack, + self._current_doc_id_in_pack, + ) + self._current_pack = None + self._current_pack_size = 0 + return final_pack + + if self._exhausted and not self._buffer: + return None + + return None + + def __iter__(self) -> Iterator[PackType]: + if not isinstance(self.dataset, Iterable): + raise TypeError("Dataset is not iterable.") + + if not self._resuming: + self._reset_packer_state() + self._iterator = iter(self.dataset) + + # If resuming, the iterator must be recreated from the loaded state + if self._iterator is None: + self._iterator = iter(self.dataset) + + self._resuming = False # Consume the resume flag + + # Main packing loop + while True: + + # Stop if the source is exhausted and there's no data left to pack + if self._exhausted and not self._buffer and self._current_pack_size == 0: + break + + pack = self._build_one_pack(self._iterator) + if pack: + yield pack + + # If build_one_pack returns None but we are not done, continue loop + # to attempt building another pack (e.g. after buffer is refilled). + elif self._exhausted and not self._buffer: + break + + def state_dict(self) -> Dict[str, Any]: + """ + Get the state of the packer. It relies on the input dataset to save the progress of iteration. + It does NOT save the internal buffer or any partially built pack. + """ + state = {} + if isinstance(self.dataset, Stateful): + state["dataset_state"] = self.dataset.state_dict() + else: + raise ValueError("Dataset is not stateful.") + + return state + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the packer. This restores the state of the underlying dataset. + The buffer and any partially-built pack are discarded. + """ + if isinstance(self.dataset, Stateful) and "dataset_state" in state_dict: + self.dataset.load_state_dict(state_dict["dataset_state"]) + else: + raise ValueError("Dataset is not stateful.") + + self._reset_packer_state() + self._resuming = True + +class TextPackingStrategy(PackingStrategy[Dict[str, List[int]]]): + """ + Strategy for packing standard text samples for causal language modeling. It is designed + to be used with the IterablePackedDataset. + - Each sample is treated as a separate document. + - `input_pos` restarts from 0 for each sample. + - `document_ids` assigns a unique ID to each sample for masking. 
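+
+    Example:
+        # Illustrative pack of two samples (lengths 2 and 3) with target_tokens_per_pack=6
+        # and padding_idx=0; padding gets the next unused document ID (2 here).
+        # tokens:       [t0, t1, s0, s1, s2, 0]
+        # labels:       [l0, l1, m0, m1, m2, ignore_idx]
+        # document_ids: [ 0,  0,  1,  1,  1,  2]
+        # input_pos:    [ 0,  1,  0,  1,  2,  0]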
+ """ + + def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): + super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + + def create_empty_pack(self) -> Dict[str, List[int]]: + return { + "tokens": [], + "labels": [], + "document_ids": [], + "input_pos": [], + } + + def get_sample_size(self, sample: Dict[str, List[int]]) -> int: + return len(sample["tokens"]) + + def add_sample_to_pack(self, pack: Dict[str, List[int]], sample: Dict[str, List[int]], next_doc_id: int) -> int: + seq_len = len(sample["tokens"]) + + # Append sample data to the pack + pack["tokens"].extend(sample["tokens"]) + pack["labels"].extend(sample["labels"]) + pack["document_ids"].extend([next_doc_id] * seq_len) + pack["input_pos"].extend(range(seq_len)) # input_pos restarts for each doc + + # Increment doc ID for the next sample + return 1 + + def finalize_pack( + self, pack: Dict[str, List[int]], target_tokens_per_pack: int, next_doc_id: int + ) -> PackType: + current_size = len(pack["tokens"]) + num_padding = target_tokens_per_pack - current_size + + if num_padding > 0: + pack["tokens"].extend([self.padding_idx] * num_padding) + pack["labels"].extend([self.ignore_idx] * num_padding) + pack["input_pos"].extend([0] * num_padding) + pack["document_ids"].extend([next_doc_id] * num_padding) + + return { + "tokens": torch.tensor(pack["tokens"], dtype=torch.long), + "labels": torch.tensor(pack["labels"], dtype=torch.long), + "document_ids": torch.tensor(pack["document_ids"], dtype=torch.long), + "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), + } + + def _mask_mod( + self, + b: int, + h: int, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + doc_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Standard block-causal mask logic. Tokens can only attend to other + tokens within the same document, respecting causality. 
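+
+        Example:
+            # Illustrative: with doc_ids[b] = [0, 0, 1, 1], query position 3 may attend to
+            # kv positions 2 and 3 (same document, causal), but not to positions 0 or 1,
+            # which belong to a different document.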
+ """ + causal_mask = q_idx >= kv_idx + document_mask = doc_ids[b, q_idx] == doc_ids[b, kv_idx] + return causal_mask & document_mask \ No newline at end of file From 6e14b0622ac4d46a8cf6bf3146ab44ecbf385dc9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 14:23:24 -0400 Subject: [PATCH 04/48] enable on full recipe --- recipes/configs/llama3_2/3B_full.yaml | 8 ++- recipes/full_finetune_distributed.py | 94 ++++++++++++++++++++------- 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index bb765f1917..c5ca1f708b 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -24,17 +24,21 @@ output_dir: /tmp/torchtune/llama3_2_3B/full # /tmp may be deleted by your system tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model - max_seq_len: null + max_seq_len: 4096 # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False # True increases speed split: train[:95%] seed: null shuffle: True batch_size: 4 +# On-the-fly packing strategy +# Set packing_strategy: null to disable packing +packing_strategy: + _component_: torchtune.datasets.TextPackingStrategy + # Validation run_val_every_n_steps: null # Change to an integer to enable validation every N steps dataset_val: diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index d33d2dad31..76d6fee1a4 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -25,8 +25,8 @@ from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import padded_collate_packed -from torchtune.datasets import ConcatDataset +from torchtune.data import collate_packed, padded_collate_packed +from torchtune.datasets import ConcatDataset, IterablePackedDataset, SFTDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.recipe_interfaces import FTRecipeInterface @@ -45,6 +45,7 @@ convert_to_float8_training, is_fp8_tensorwise_scaling, ) +from torch.utils.data import IterableDataset from tqdm import tqdm @@ -428,12 +429,19 @@ def setup(self, cfg: DictConfig) -> None: # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized - collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + collate_name = cfg.get("collate_fn", None) + if collate_name is None: + if cfg.get("packing_strategy") is not None: + collate_name = "torchtune.data.collate_packed" + else: + collate_name = "torchtune.data.padded_collate_sft" + self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, collate_fn=collate_name, + cfg_packing_strategy=cfg.get("packing_strategy"), ) # Setup validation dataloader if validation dataset is provided @@ -444,6 +452,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_dataset=cfg.dataset_val, batch_size=batch_size_val, collate_fn=collate_name, + cfg_packing_strategy=cfg.get("packing_strategy"), shuffle=False, ) @@ -778,6 +787,7 @@ def _setup_data( shuffle: bool, batch_size: int, collate_fn: str, + cfg_packing_strategy: Optional[DictConfig] = None, dataloader_state_dict: Optional[dict[str, Any]] 
= None, ) -> StatefulDataLoader: """ @@ -785,40 +795,74 @@ def _setup_data( map-style datasets. If a state_dict is provided (meaning we are resuming a training run), it is loaded into the dataloader. """ + # 1. Instantiate the base map-style dataset if isinstance(cfg_dataset, ListConfig): datasets = [ config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) - - # Instantiate collate_fn - if "left_pad_sequence" in collate_fn: - raise RuntimeError("left_pad_sequence collator is only for inference.") - collate_fn = _get_component_from_path(collate_fn) sampler = StatefulDistributedSampler( ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 ) + + # 2. Set up dataset, sampler and collate function based on packing strategy + # NOTE: This is a temporary hack to make it work with the new packing strategy + if cfg_packing_strategy: + if self._is_rank_zero: + self._logger.info("Using IterablePackedDataset for on-the-fly packing.") + + packing_strategy = config.instantiate( + cfg_packing_strategy, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + + # Wrapper to make map-style dataset compatible with IterablePackedDataset + class _SamplerWrapper(IterableDataset): + def __init__(self, data, sampler): + self._data = data + self._sampler = sampler + + def __iter__(self): + for i in self._sampler: + yield self._data[i] + + iterable_ds = _SamplerWrapper(ds, sampler) + + final_ds = IterablePackedDataset( + dataset=iterable_ds, + strategy=packing_strategy, + target_tokens_per_pack=self._tokenizer.max_seq_len, + ) + + # Sampler must be None for iterable datasets + sampler = None + + collate_callable = partial( + _get_component_from_path(collate_fn), + mask_fn=packing_strategy.create_block_mask, + device=self._device, + ) + else: # Fallback for non-packed datasets + + final_ds = ds + + collate_callable = partial( + _get_component_from_path(collate_fn), + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor, + ) + dataloader = StatefulDataLoader( - dataset=ds, + dataset=final_ds, batch_size=batch_size, sampler=sampler, - collate_fn=( - partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor, - ) - if not packed - else padded_collate_packed - ), - # dropping last avoids shape issues with compile + flex attention + collate_fn=collate_callable, drop_last=True, ) @@ -913,7 +957,11 @@ def train(self) -> None: # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) - self._dataloader.sampler.set_epoch(curr_epoch) + # NOTE: Temporary hack to make it work with the new packing strategy + if self._dataloader.sampler is not None: + self._dataloader.sampler.set_epoch(curr_epoch) + else: + self._dataloader.dataset.dataset.sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): # Start tracking CUDA memory for active steps for just the first epoch if ( From f9db469f5e4f9dc66d246ee42b574f4ddfeaeb63 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 13:43:14 -0700 Subject: [PATCH 
05/48] fix imports + formatting --- recipes/full_finetune_distributed.py | 9 ++- torchtune/data/__init__.py | 2 + torchtune/data/_collate.py | 3 +- torchtune/datasets/__init__.py | 6 ++ torchtune/datasets/_iterable_packed.py | 92 ++++++++++++++++---------- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 76d6fee1a4..2e2e9be0d7 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -20,13 +20,13 @@ from torch.distributed.tensor import DTensor from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer +from torch.utils.data import IterableDataset from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import collate_packed, padded_collate_packed -from torchtune.datasets import ConcatDataset, IterablePackedDataset, SFTDataset +from torchtune.datasets import ConcatDataset, IterablePackedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.recipe_interfaces import FTRecipeInterface @@ -45,7 +45,6 @@ convert_to_float8_training, is_fp8_tensorwise_scaling, ) -from torch.utils.data import IterableDataset from tqdm import tqdm @@ -808,7 +807,7 @@ def _setup_data( sampler = StatefulDistributedSampler( ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 ) - + # 2. Set up dataset, sampler and collate function based on packing strategy # NOTE: This is a temporary hack to make it work with the new packing strategy if cfg_packing_strategy: @@ -850,7 +849,7 @@ def __iter__(self): else: # Fallback for non-packed datasets final_ds = ds - + collate_callable = partial( _get_component_from_path(collate_fn), padding_idx=self._tokenizer.pad_id, diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index a75e16780a..c093fe08b5 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from torchtune.data._collate import ( + collate_packed, left_pad_sequence, padded_collate, padded_collate_dpo, @@ -59,5 +60,6 @@ "padded_collate", "padded_collate_tiled_images_and_mask", "padded_collate_packed", + "collate_packed", "load_image", ] diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 015eb2761c..645bbac45d 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Optional, Union, List, Dict +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -38,6 +38,7 @@ def collate_packed( return collated + def left_pad_sequence( sequences: list[torch.Tensor], batch_first: bool = False, diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index b0c7c11738..1d29a0949a 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -12,6 +12,10 @@ from torchtune.datasets._grammar import grammar_dataset from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset +from torchtune.datasets._iterable_packed import ( + IterablePackedDataset, + TextPackingStrategy, +) from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset @@ -44,4 +48,6 @@ "SFTDataset", "hh_rlhf_helpful_dataset", "multimodal", + "IterablePackedDataset", + "TextPackingStrategy", ] diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index ba535ee6ba..5e83161b67 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -1,15 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. import logging from abc import ABC, abstractmethod from collections import deque -from functools import partial -from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, TypeVar +from typing import Any, Generic, Iterable, Iterator, Optional, TypeVar import torch from torch.nn.attention.flex_attention import ( create_block_mask as create_block_mask_flex, ) -from torch.utils.data import IterableDataset, Stateful +from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import Stateful from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION @@ -17,15 +22,15 @@ logger = logging.getLogger(__name__) SampleType = TypeVar("SampleType") -PackType = Dict[str, torch.Tensor] +PackType = dict[str, torch.Tensor] class PackingStrategy(ABC, Generic[SampleType]): + """ + Strategy to be used in IterablePackedDataset and with FlexAttention. + """ + def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): - """ - Initializes a strategy to be used in IterablePackedDataset. All strategies are meant to be used - with FlexAttention by leveraging 'mask_mod' to create the attention mask. - """ if not _SUPPORTS_FLEX_ATTENTION: raise RuntimeError( "The IterablePackedDataset and its strategies require Flex Attention support, " @@ -35,9 +40,13 @@ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX) self.ignore_idx = ignore_idx @abstractmethod - def create_empty_pack(self) -> Dict[str, List[Any]]: + def create_empty_pack(self) -> dict[str, list[Any]]: """ Creates an empty pack. + + Returns: + dict[str, list[Any]]: An empty dictionary with lists as values. + Example: self.create_empty_pack() >>> {"tokens": [], "labels": [], "document_ids": [], "input_pos": []} @@ -48,6 +57,13 @@ def create_empty_pack(self) -> Dict[str, List[Any]]: def get_sample_size(self, sample: SampleType) -> int: """ Returns the size of a sample. 
+ + Args: + sample (SampleType): The sample to get the size of. + + Returns: + int: The size of the sample. + Example: # for a sample with 100 tokens self.get_sample_size(sample) @@ -57,13 +73,13 @@ def get_sample_size(self, sample: SampleType) -> int: @abstractmethod def add_sample_to_pack( - self, pack: Dict[str, List[Any]], sample: SampleType, next_doc_id: int + self, pack: dict[str, list[Any]], sample: SampleType, next_doc_id: int ) -> int: """ Adds a sample to the pack dictionary in-place. Args: - pack (Dict[str, List[Any]]): The dictionary representing the pack, to be modified in-place. + pack (dict[str, list[Any]]): The dictionary representing the pack, to be modified in-place. sample (SampleType): The sample to add. next_doc_id (int): The starting document ID to use for this sample. @@ -86,13 +102,13 @@ def add_sample_to_pack( @abstractmethod def finalize_pack( - self, pack: Dict[str, List[Any]], target_tokens_per_pack: int, next_doc_id: int + self, pack: dict[str, list[Any]], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: """ Finalizes a pack, primarily by padding it to the target length. Args: - pack (Dict[str, List[Any]]): The pack data. + pack (dict[str, list[Any]]): The pack data. target_tokens_per_pack (int): The target length to pad to. next_doc_id (int): The document ID to use for the padding tokens. @@ -133,13 +149,13 @@ def _mask_mod( Args: b (int): Batch index. h (int): Head index. - q_idx (Tensor): Query indices. - kv_idx (Tensor): Key/value indices. - doc_ids (Tensor): The complete document ID tensor for the batch, + q_idx (torch.Tensor): Query indices. + kv_idx (torch.Tensor): Key/value indices. + doc_ids (torch.Tensor): The complete document ID tensor for the batch, of shape (batch_size, seq_len). Returns: - A boolean tensor indicating which query/key pairs are allowed to attend. + torch.Tensor: A boolean tensor indicating which query/key pairs are allowed to attend. """ pass @@ -161,6 +177,17 @@ def _mask_mod_for_flex(b, h, q_idx, kv_idx): class IterablePackedDataset(IterableDataset[PackType], Stateful, Generic[SampleType]): + """ + IterablePackedDataset takes any IterableDataset and a PackingStrategy, packs documents until + the 'target_tokens_per_pack' is reached and yields a dictionary of tensors. + + Args: + dataset (IterableDataset[SampleType]): The IterableDataset to pack. + strategy (PackingStrategy[SampleType]): The PackingStrategy to use for packing. + target_tokens_per_pack (int): The target number of tokens per pack. + buffer_size (int): The size of the buffer to use for packing. + """ + def __init__( self, dataset: IterableDataset[SampleType], @@ -168,16 +195,6 @@ def __init__( target_tokens_per_pack: int, buffer_size: int = 50, ): - """ - IterablePackedDataset takes any IterableDataset and a PackingStrategy, packs documents until - the 'target_tokens_per_pack' is reached and yields a dictionary of tensors. - - Args: - dataset (IterableDataset[SampleType]): The IterableDataset to pack. - strategy (PackingStrategy[SampleType]): The PackingStrategy to use for packing. - target_tokens_per_pack (int): The target number of tokens per pack. - buffer_size (int): The size of the buffer to use for packing. 
- """ self.dataset = dataset self.strategy = strategy self.target_tokens_per_pack = target_tokens_per_pack @@ -244,7 +261,7 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]: Example: self._buffer = deque([(sample1, 200), (sample2, 100), (sample3, 48), (sample4, 200)]) - + # First iteration: selected_sample_idx = self._find_next_fitting_sample(remaining_size=150) # returns 1 del self._buffer[selected_sample_idx] @@ -343,7 +360,7 @@ def __iter__(self) -> Iterator[PackType]: elif self._exhausted and not self._buffer: break - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """ Get the state of the packer. It relies on the input dataset to save the progress of iteration. It does NOT save the internal buffer or any partially built pack. @@ -356,7 +373,7 @@ def state_dict(self) -> Dict[str, Any]: return state - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """ Load the state of the packer. This restores the state of the underlying dataset. The buffer and any partially-built pack are discarded. @@ -369,7 +386,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._reset_packer_state() self._resuming = True -class TextPackingStrategy(PackingStrategy[Dict[str, List[int]]]): + +class TextPackingStrategy(PackingStrategy[dict[str, list[int]]]): """ Strategy for packing standard text samples for causal language modeling. It is designed to be used with the IterablePackedDataset. @@ -381,7 +399,7 @@ class TextPackingStrategy(PackingStrategy[Dict[str, List[int]]]): def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) - def create_empty_pack(self) -> Dict[str, List[int]]: + def create_empty_pack(self) -> dict[str, list[int]]: return { "tokens": [], "labels": [], @@ -389,10 +407,12 @@ def create_empty_pack(self) -> Dict[str, List[int]]: "input_pos": [], } - def get_sample_size(self, sample: Dict[str, List[int]]) -> int: + def get_sample_size(self, sample: dict[str, list[int]]) -> int: return len(sample["tokens"]) - def add_sample_to_pack(self, pack: Dict[str, List[int]], sample: Dict[str, List[int]], next_doc_id: int) -> int: + def add_sample_to_pack( + self, pack: dict[str, list[int]], sample: dict[str, list[int]], next_doc_id: int + ) -> int: seq_len = len(sample["tokens"]) # Append sample data to the pack @@ -405,7 +425,7 @@ def add_sample_to_pack(self, pack: Dict[str, List[int]], sample: Dict[str, List[ return 1 def finalize_pack( - self, pack: Dict[str, List[int]], target_tokens_per_pack: int, next_doc_id: int + self, pack: dict[str, list[int]], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: current_size = len(pack["tokens"]) num_padding = target_tokens_per_pack - current_size @@ -437,4 +457,4 @@ def _mask_mod( """ causal_mask = q_idx >= kv_idx document_mask = doc_ids[b, q_idx] == doc_ids[b, kv_idx] - return causal_mask & document_mask \ No newline at end of file + return causal_mask & document_mask From ff6fdbed8b0761d455bbf8983066577bbc744179 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 17:28:39 -0400 Subject: [PATCH 06/48] add max_steps_per_epoch requirement --- recipes/full_finetune_distributed.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 2e2e9be0d7..be75b2a5d2 100644 --- 
a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -462,14 +462,23 @@ def setup(self, cfg: DictConfig) -> None: # by the dataloader, the max_steps_per_epoch param set by the user and the # gradient_accumulation_steps param. This value is used for logging and tracking # training state. The computation should happen after the dataloader has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): + + # NOTE: Hack to get it running. needs to be properly addressesd. + if isinstance(self._dataloader.dataset, IterableDataset): + if self.max_steps_per_epoch is None: + raise ValueError( + "max_steps_per_epoch must be specified for iterable datasets." + ) self._steps_per_epoch = self.max_steps_per_epoch + else: + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch # Setup lr scheduler From 5e447ab601d1c8ebe9368360cbd8bd7eb8043762 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 14:53:34 -0700 Subject: [PATCH 07/48] address blockers --- recipes/full_finetune_distributed.py | 5 +---- torchtune/modules/transformer.py | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index be75b2a5d2..bdfcc3d553 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -966,10 +966,7 @@ def train(self) -> None: for curr_epoch in range(self.epochs_run, self.total_epochs): pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) # NOTE: Temporary hack to make it work with the new packing strategy - if self._dataloader.sampler is not None: - self._dataloader.sampler.set_epoch(curr_epoch) - else: - self._dataloader.dataset.dataset.sampler.set_epoch(curr_epoch) + self._dataloader.dataset.dataset._sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): # Start tracking CUDA memory for active steps for just the first epoch if ( diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 724138b14e..68a62d9d31 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -580,6 +580,7 @@ def forward( encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, input_embeds: Optional[torch.Tensor] = None, + **kwargs: dict, ) -> Union[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -611,6 +612,7 @@ def forward( This parameter is required during inference if caches have been setup. Default is None. input_embeds (Optional[torch.Tensor]): Pass these instead of tokens to short-circuit token embeddings and skip straight to the transformer layers. Shape ``[b x s x d]``. Default: None + **kwargs (dict): Keyword arguments to pass to the transformer layers. 
Returns: Union[torch.Tensor, list[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_layer=False` From 13cda28e7f2af72873ff118b70925f69d30f0b7f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:10:16 -0400 Subject: [PATCH 08/48] small fixes --- torchtune/data/_collate.py | 11 +- torchtune/datasets/_iterable_packed.py | 139 ++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 645bbac45d..6faa88774f 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -26,10 +26,13 @@ def collate_packed( # Assumes all samples in the batch have the same keys, which are all tensors. keys_to_stack = batch[0].keys() - collated = { - key: torch.stack([sample[key] for sample in batch], dim=0) - for key in keys_to_stack - } + collated = {} + for key in keys_to_stack: + if isinstance(batch[0][key], torch.Tensor): + collated[key] = torch.stack([sample[key] for sample in batch], dim=0) + else: + # TODO: Remove? i dont see a situation where it would not be a tensor. + collated[key] = [sample[key] for sample in batch] # Delegate mask creation to the provided specialized function # TODO: investigate the need for device here. Currently we hardcode it in utilities to cuda. diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 5e83161b67..d367e0ccb4 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -88,7 +88,7 @@ def add_sample_to_pack( Example: pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} - sample = {"tokens": [5, 6], "labels": [7, 8], "document_ids": [1, 1], "input_pos": [0, 1]} + sample = {"tokens": [5, 6], "labels": [7, 8]} added_docs = self.add_sample_to_pack(pack, sample, next_doc_id=1) print(pack) >>> {"tokens": [1, 2, 5, 6], @@ -458,3 +458,140 @@ def _mask_mod( causal_mask = q_idx >= kv_idx document_mask = doc_ids[b, q_idx] == doc_ids[b, kv_idx] return causal_mask & document_mask + + + +# NOTE: For demonstration purposes only. + +class DPOPackingStrategy(PackingStrategy[dict[str, list[int]]]): + """ + Strategy for packing DPO samples with a shared prompt. It packs a DPO + sample as three logical documents: a shared prompt, a chosen response, + and a rejected response. This structure is encoded in the `document_ids` + metadata, allowing the strategy to build the correct attention pattern + (e.g., both responses can attend to the prompt, but not to each other). + + ASSUMPTION: The input DPO sample dict contains pre-tokenized: + - "prompt_ids" + - "chosen_response_only_ids" + - "chosen_response_only_labels" + - "rejected_response_only_ids" + - "rejected_response_only_labels" + """ + + def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): + super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + + def create_empty_pack(self) -> dict[str, list[int]]: + return { + "tokens": [], + "labels": [], + "document_ids": [], + "input_pos": [], + "chosen_response_mask": [], + "rejected_response_mask": [], + } + + def get_sample_size(self, sample: dict[str, list[int]]) -> int: + # The total size of one DPO sample is the shared prompt + both responses. 
+ return ( + len(sample["prompt_ids"]) + + len(sample["chosen_response_only_ids"]) + + len(sample["rejected_response_only_ids"]) + ) + + def add_sample_to_pack( + self, pack: dict[str, list[int]], sample: dict[str, list[int]], next_doc_id: int + ) -> int: + # Assign a unique doc ID triplet for (prompt, chosen, rejected) + prompt_doc_id = next_doc_id + chosen_doc_id = next_doc_id + 1 + rejected_doc_id = next_doc_id + 2 + + prompt_ids = sample["prompt_ids"] + chosen_ids = sample["chosen_response_only_ids"] + rejected_ids = sample["rejected_response_only_ids"] + + # Input positions restart from 0 for each new DPO sample in the pack + total_len = len(prompt_ids) + len(chosen_ids) + len(rejected_ids) + pack["input_pos"].extend(range(total_len)) + + # 1. Add Shared Prompt data + pack["tokens"].extend(prompt_ids) + pack["labels"].extend([self.ignore_idx] * len(prompt_ids)) + pack["document_ids"].extend([prompt_doc_id] * len(prompt_ids)) + pack["chosen_response_mask"].extend([False] * len(prompt_ids)) + pack["rejected_response_mask"].extend([False] * len(prompt_ids)) + + # 2. Add Chosen Response data + pack["tokens"].extend(chosen_ids) + pack["labels"].extend(sample["chosen_response_only_labels"]) + pack["document_ids"].extend([chosen_doc_id] * len(chosen_ids)) + pack["chosen_response_mask"].extend([True] * len(chosen_ids)) + pack["rejected_response_mask"].extend([False] * len(chosen_ids)) + + # 3. Add Rejected Response data + pack["tokens"].extend(rejected_ids) + pack["labels"].extend(sample["rejected_response_only_labels"]) + pack["document_ids"].extend([rejected_doc_id] * len(rejected_ids)) + pack["chosen_response_mask"].extend([False] * len(rejected_ids)) + pack["rejected_response_mask"].extend([True] * len(rejected_ids)) + + # Advance the document ID counter by 3 for the next DPO sample. + return 3 + + def finalize_pack( + self, pack: dict[str, list[int]], target_tokens_per_pack: int, next_doc_id: int + ) -> dict[str, torch.Tensor]: + current_size = len(pack["tokens"]) + num_padding = target_tokens_per_pack - current_size + + if num_padding > 0: + pack["tokens"].extend([self.padding_idx] * num_padding) + pack["labels"].extend([self.ignore_idx] * num_padding) + pack["input_pos"].extend([0] * num_padding) + pack["chosen_response_mask"].extend([False] * num_padding) + pack["rejected_response_mask"].extend([False] * num_padding) + pack["document_ids"].extend([next_doc_id] * num_padding) + + return { + "tokens": torch.tensor(pack["tokens"], dtype=torch.long), + "labels": torch.tensor(pack["labels"], dtype=torch.long), + "document_ids": torch.tensor(pack["document_ids"], dtype=torch.long), + "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), + "chosen_response_mask": torch.tensor( + pack["chosen_response_mask"], dtype=torch.bool + ), + "rejected_response_mask": torch.tensor( + pack["rejected_response_mask"], dtype=torch.bool + ), + } + + def _mask_mod( + self, + b: int, + h: int, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + doc_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Mask logic for DPO. + - Causal self-attention within the same document. + - Cross-attention from response tokens (chosen/rejected) to their + corresponding prompt tokens. + """ + q_doc = doc_ids[b, q_idx] + kv_doc = doc_ids[b, kv_idx] + + # 1. Document-level Causal self-attention + is_same_doc = q_doc == kv_doc + self_attention_mask = is_same_doc & (q_idx >= kv_idx) + + # 2. 
Cross-attention from response to prompt + q_prompt_doc_id = (q_doc // 3) * 3 + kv_is_part_of_q_prompt = kv_doc == q_prompt_doc_id + q_is_response = (q_doc % 3) > 0 + cross_attention_mask = q_is_response & kv_is_part_of_q_prompt + + return self_attention_mask | cross_attention_mask From d26769c5a353b16a6722ca6bd02f82f66912f523 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:10:21 -0400 Subject: [PATCH 09/48] add md doc --- planning/ontheflypacking.md | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 planning/ontheflypacking.md diff --git a/planning/ontheflypacking.md b/planning/ontheflypacking.md new file mode 100644 index 0000000000..cb1c0008dc --- /dev/null +++ b/planning/ontheflypacking.md @@ -0,0 +1,51 @@ +### What: +Packing is the process of putting together samples until a certain target size is reached. This is done to reduce the number of padding tokens in a batch. To avoid contamination between samples, we use a document-level causal mask. To make it faster, we use flex attention to handle the special mask. + +Example: +```python +# The current pack with one sample +pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} + +# The next sample to be added +sample = {"tokens": [5, 6], "labels": [7, 8]} + +# After adding the sample +added_docs = add_sample_to_pack(pack, sample, next_doc_id=1) +print(pack) +>>> {"tokens": [1, 2, 5, 6], + "labels": [3, 4, 7, 8], + "document_ids": [0, 0, 1, 1], + "input_pos": [0, 1, 0, 1]} + +create_block_causal_mask(document_ids) +>>> [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ] +``` + +### Goal: +0) Make packing a first-class citizen in TorchTune, available for all sorts of models and recipes. + +### Context: +1) We currently have map-style packing. We pre-process the entire dataset before training starts, which is not scalable. +2) Packing is only present for SFT + text data. There is no contract for how to extend it to multimodal, DPO, etc. +3) Collate function has to be aware of packing logic. This is currently hardcoded in the recipe with if/else. + +### Solution: +4) Implement a new on-the-fly packing that takes any iterable dataset as input; +5) Packing contract consists of + i) a `PackingStrategy` that defines how a) to pack and b) the **_mask_mod** used for flex attention; + ii) a `IterablePackedDataset` that takes any a) `PackingStrategy`, b) **iterable dataset** as input and yields packed samples; + iii) a `packed_collate_fn` that takes the batch of packed samples and a **mask_fn** (e.g. `strategy.create_block_mask`) to generate the attention mask on the fly. + To define a new packing strategy, the user only needs to implement the `PackingStrategy` class. + +### Implementation: +6) Updated `full_finetune_distributed.py` to use `IterablePackedDataset` when packing is enabled. There are challenges related to iterable datasets and this will be tackled in a separate iterable dataset PR. Changes made were to enable it to run for this RFC. + +### Not in this PR: +7) **Logging**: Since we cannot do len(iterable_dataset), we need to add proper logging/metadata to assist users in understanding how far along they are on each dataset and metrics regarding the samples (avg num tokens, avg num samples / pack, etc.) +8) **Packing-aware Loss**: For SFT, the same loss works for map-style and packing. This is not the case for DPO/GRPO, which would need different masking. 
Future work will have to handle how to associate packing with a loss that supports it. +9) **Packing-aware metrics**: Advanced metrics, such as logprob per sample, would require to be aware of packing; \ No newline at end of file From 59b8cab26191b5eba7d311011c56109c5cff7877 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:18:44 -0400 Subject: [PATCH 10/48] update comments --- recipes/full_finetune_distributed.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 11e7da7300..e55fd034a0 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -444,7 +444,7 @@ def setup(self, cfg: DictConfig) -> None: shuffle=cfg.shuffle, batch_size=cfg.batch_size, collate_fn=collate_name, - cfg_packing_strategy=cfg.get("packing_strategy"), + cfg_packing_strategy=cfg.get("packing_strategy", None), ) # Setup validation dataloader if validation dataset is provided @@ -455,7 +455,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_dataset=cfg.dataset_val, batch_size=batch_size_val, collate_fn=collate_name, - cfg_packing_strategy=cfg.get("packing_strategy"), + cfg_packing_strategy=cfg.get("packing_strategy", None), shuffle=False, ) @@ -467,7 +467,7 @@ def setup(self, cfg: DictConfig) -> None: # gradient_accumulation_steps param. This value is used for logging and tracking # training state. The computation should happen after the dataloader has been setup - # NOTE: Hack to get it running. needs to be properly addressesd. + # NOTE: Hack to get it running. needs to be properly addressed. if isinstance(self._dataloader.dataset, IterableDataset): if self.max_steps_per_epoch is None: raise ValueError( @@ -800,7 +800,7 @@ def _setup_data( map-style datasets. If a state_dict is provided (meaning we are resuming a training run), it is loaded into the dataloader. """ - # 1. Instantiate the base map-style dataset + # 1. Instantiate the base map-style dataset (to be replaced with IterableDataset) if isinstance(cfg_dataset, ListConfig): datasets = [ config.instantiate(single_cfg_dataset, self._tokenizer) @@ -814,8 +814,7 @@ def _setup_data( ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 ) - # 2. Set up dataset, sampler and collate function based on packing strategy - # NOTE: This is a temporary hack to make it work with the new packing strategy + # 2. 
Set up packing if cfg_packing_strategy: if self._is_rank_zero: self._logger.info("Using IterablePackedDataset for on-the-fly packing.") @@ -826,7 +825,9 @@ def _setup_data( ignore_idx=self._loss_fn.ignore_index, ) - # Wrapper to make map-style dataset compatible with IterablePackedDataset + # NOTE: This is a temporary hack to make map-style dataset + # compatible with IterablePackedDataset + # ------------------------------------------------------------ class _SamplerWrapper(IterableDataset): def __init__(self, data, sampler): self._data = data @@ -838,15 +839,16 @@ def __iter__(self): iterable_ds = _SamplerWrapper(ds, sampler) + # Sampler must be None for iterable datasets + sampler = None + # ------------------------------------------------------------ + final_ds = IterablePackedDataset( dataset=iterable_ds, strategy=packing_strategy, target_tokens_per_pack=self._tokenizer.max_seq_len, ) - # Sampler must be None for iterable datasets - sampler = None - collate_callable = partial( _get_component_from_path(collate_fn), mask_fn=packing_strategy.create_block_mask, From 5d7d4964dac6fec68b04243b48842661dd308343 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:21:13 -0400 Subject: [PATCH 11/48] update comments --- planning/ontheflypacking.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/planning/ontheflypacking.md b/planning/ontheflypacking.md index cb1c0008dc..540beff014 100644 --- a/planning/ontheflypacking.md +++ b/planning/ontheflypacking.md @@ -30,14 +30,14 @@ create_block_causal_mask(document_ids) 0) Make packing a first-class citizen in TorchTune, available for all sorts of models and recipes. ### Context: -1) We currently have map-style packing. We pre-process the entire dataset before training starts, which is not scalable. +1) We currently have map-style packing. We pre-process the dataset before training, which is not scalable. 2) Packing is only present for SFT + text data. There is no contract for how to extend it to multimodal, DPO, etc. 3) Collate function has to be aware of packing logic. This is currently hardcoded in the recipe with if/else. ### Solution: 4) Implement a new on-the-fly packing that takes any iterable dataset as input; 5) Packing contract consists of - i) a `PackingStrategy` that defines how a) to pack and b) the **_mask_mod** used for flex attention; + i) a `PackingStrategy` that defines a) how to pack and b) the **_mask_mod** used for flex attention; ii) a `IterablePackedDataset` that takes any a) `PackingStrategy`, b) **iterable dataset** as input and yields packed samples; iii) a `packed_collate_fn` that takes the batch of packed samples and a **mask_fn** (e.g. `strategy.create_block_mask`) to generate the attention mask on the fly. To define a new packing strategy, the user only needs to implement the `PackingStrategy` class. 
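For reference, a minimal sketch of how the three pieces of the packing contract above compose, independent of the recipe changes in these patches. It assumes `collate_packed` is importable from `torchtune.data` (the patch only shows it defined in `torchtune/data/_collate.py`), that flex attention is available (the strategy raises a `RuntimeError` otherwise), and uses a hypothetical `TinyTokenDataset`, `PAD_ID`, and `MAX_SEQ_LEN` as stand-ins; the real recipe also forwards padding/ignore indices to the collate fn, so the exact call may differ from what lands.

```python
from functools import partial

from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtune.data import collate_packed  # assumed public path for the packed collate fn
from torchtune.datasets import IterablePackedDataset, TextPackingStrategy

PAD_ID = 0          # placeholder pad token id
MAX_SEQ_LEN = 128   # placeholder target pack length


class TinyTokenDataset(IterableDataset):
    """Stand-in for any iterable SFT dataset yielding pre-tokenized samples."""

    def __iter__(self):
        for i in range(64):
            length = (i % 5) + 1
            tokens = list(range(1, length + 1))
            yield {"tokens": tokens, "labels": tokens}


# 1) The strategy defines how samples are appended to a pack and the
#    flex-attention mask_mod used to build the document-level causal mask.
strategy = TextPackingStrategy(padding_idx=PAD_ID)

# 2) IterablePackedDataset packs samples on the fly until target_tokens_per_pack
#    is reached, buffering up to `buffer_size` samples to reduce padding.
packed_ds = IterablePackedDataset(
    dataset=TinyTokenDataset(),
    strategy=strategy,
    target_tokens_per_pack=MAX_SEQ_LEN,
    buffer_size=50,
)

# 3) The collate fn stacks the packed tensors and delegates mask creation to the
#    strategy's mask builder, so the attention mask is generated per batch.
dataloader = StatefulDataLoader(
    dataset=packed_ds,
    batch_size=2,
    collate_fn=partial(collate_packed, mask_fn=strategy.create_block_mask),
)
```

Defining a new packing strategy (e.g. for DPO or multimodal data) only requires subclassing `PackingStrategy` and implementing `create_empty_pack`, `get_sample_size`, `add_sample_to_pack`, `finalize_pack`, and `_mask_mod`; the dataset and collate wiring above stays unchanged.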
From e19392682f70d72337b3c3f20dc659471b55ab0e Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:39:17 -0400 Subject: [PATCH 12/48] update comment --- planning/ontheflypacking.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/planning/ontheflypacking.md b/planning/ontheflypacking.md index 540beff014..18279bac47 100644 --- a/planning/ontheflypacking.md +++ b/planning/ontheflypacking.md @@ -48,4 +48,5 @@ create_block_causal_mask(document_ids) ### Not in this PR: 7) **Logging**: Since we cannot do len(iterable_dataset), we need to add proper logging/metadata to assist users in understanding how far along they are on each dataset and metrics regarding the samples (avg num tokens, avg num samples / pack, etc.) 8) **Packing-aware Loss**: For SFT, the same loss works for map-style and packing. This is not the case for DPO/GRPO, which would need different masking. Future work will have to handle how to associate packing with a loss that supports it. -9) **Packing-aware metrics**: Advanced metrics, such as logprob per sample, would require to be aware of packing; \ No newline at end of file +9) **Packing-aware metrics**: Advanced metrics, such as logprob per sample, would require to be aware of packing; +10) **tokenization**: For advanced packing, e.g. shared prompts in GRPO/DPO, we will need extra metadata from upstream datasets, e.g. prompt len. \ No newline at end of file From 40d79f450b42b11a613dcf7fdb0d3bf7f9dc9253 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 11 Jun 2025 20:47:35 -0400 Subject: [PATCH 13/48] update comment --- recipes/full_finetune_distributed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index e55fd034a0..2cdc39495c 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -467,7 +467,8 @@ def setup(self, cfg: DictConfig) -> None: # gradient_accumulation_steps param. This value is used for logging and tracking # training state. The computation should happen after the dataloader has been setup - # NOTE: Hack to get it running. needs to be properly addressed. + # NOTE: Hack to get it running. 
Iterable doesnt allow len(dataloader) + # ------------------------------------------------------------ if isinstance(self._dataloader.dataset, IterableDataset): if self.max_steps_per_epoch is None: raise ValueError( From 3cab5334eda3642aff2c8fbf5a19482b1d6ece52 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:41:30 -0400 Subject: [PATCH 14/48] first commit --- recipes/configs/llama3_2/3B_full.yaml | 40 +- recipes/full_finetune_distributed.py | 500 ++++++++---------- .../torchtune/data/test_metrics_aggregator.py | 149 ++++++ .../torchtune/data/test_metrics_transform.py | 54 ++ tests/torchtune/datasets/test_hf_iterable.py | 339 ++++++++++++ tests/torchtune/datasets/test_interleaved.py | 162 ++++++ torchtune/data/__init__.py | 12 + torchtune/data/_aggregator.py | 342 ++++++++++++ torchtune/data/_metrics.py | 95 ++++ torchtune/datasets/__init__.py | 43 +- torchtune/datasets/_alpaca.py | 65 ++- torchtune/datasets/_hf_iterable.py | 271 ++++++++++ torchtune/datasets/_interleaved.py | 115 ++++ torchtune/datasets/_iterable_base.py | 37 ++ torchtune/datasets/_sft.py | 97 +++- torchtune/datasets/_slimorca.py | 69 ++- .../checkpointing/_checkpoint_client.py | 6 + 17 files changed, 2088 insertions(+), 308 deletions(-) create mode 100644 tests/torchtune/data/test_metrics_aggregator.py create mode 100644 tests/torchtune/data/test_metrics_transform.py create mode 100644 tests/torchtune/datasets/test_hf_iterable.py create mode 100644 tests/torchtune/datasets/test_interleaved.py create mode 100644 torchtune/data/_aggregator.py create mode 100644 torchtune/data/_metrics.py create mode 100644 torchtune/datasets/_hf_iterable.py create mode 100644 torchtune/datasets/_interleaved.py create mode 100644 torchtune/datasets/_iterable_base.py diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index bb765f1917..5534b305ac 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -26,21 +26,28 @@ tokenizer: path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model max_seq_len: null -# Dataset and Sampler +# Dataloader +dataloader: + batch_size: 4 + # num_workers and pin_memory can be added here if needed + +# Dataset - now a list to support multiple weighted sources dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False # True increases speed - split: train[:95%] + - _component_: torchtune.datasets.slimorca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.8 + - _component_: torchtune.datasets.alpaca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.2 + +# Packing (TBD by follow up PR) +# packing: +# _component_: torchtune.datasets.packing.SFTPacking +# max_seq_len: 8192 + seed: null -shuffle: True -batch_size: 4 -# Validation -run_val_every_n_steps: null # Change to an integer to enable validation every N steps -dataset_val: - _component_: torchtune.datasets.alpaca_cleaned_dataset - split: train[95%:] -batch_size_val: ${batch_size} +# Validation not supported yet with iterable datasets # Model Arguments model: @@ -65,10 +72,11 @@ optimizer: loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 8 # Use to increase effective batch size +# Training - now step-based +num_training_steps: 100 # Total number of training steps to run +save_every_n_steps: 200 # Save a checkpoint every N steps. Using 200 to avoid ckpt. 
+gradient_accumulation_steps: 1 +dataset_metrics_log_freq: 10 # Log dataset-specific metrics every N steps # Environment device: cuda diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 10e0aaeb24..f34ccc6a7e 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -9,7 +9,7 @@ import time from functools import partial -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union from warnings import warn import torch @@ -25,8 +25,8 @@ from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import padded_collate_packed -from torchtune.datasets import ConcatDataset +from torchtune.data import padded_collate_packed, MetricsAggregator +from torchtune.datasets import ConcatDataset, InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils @@ -207,7 +207,7 @@ def __init__(self, cfg: DictConfig) -> None: self._checkpoint_client = CheckpointClient(cfg) self._enable_fp8_training = cfg.get("enable_fp8_training", False) self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) - self.save_every_n_steps = cfg.get("save_every_n_steps") + self.save_every_n_steps = cfg.get("save_every_n_steps", None) self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) if self._run_val_every_n_steps is not None: @@ -273,18 +273,20 @@ def __init__(self, cfg: DictConfig) -> None: self.seed = training.set_seed( seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) ) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - + + # Step-based training support + self.num_training_steps = cfg.num_training_steps + self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) + self._metrics_aggregator = None # Will be initialized in setup + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] - self.global_step = ckpt_dict[training.STEPS_KEY] + # The new format stores steps directly + self.global_step = ckpt_dict["steps_run"] # on mismatch, warn the user and prevent the override if self.seed != ckpt_dict[training.SEED_KEY]: @@ -295,23 +297,6 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: ) ) self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: - warn( - message=( - "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" - ) - ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] - - # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: - warn( - message=( - "Config value for total_epochs does not match the checkpoint value, " - f"using the config value: {self.total_epochs}" - ) - ) except KeyError as e: raise KeyError( @@ -324,6 +309,9 @@ def setup(self, cfg: DictConfig) -> None: Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. 
""" + if cfg.get("dataset_val") is not None: + raise NotImplementedError("Validation is not supported yet with iterable datasets.") + if self.fsdp_cpu_offload: # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU @@ -434,12 +422,14 @@ def setup(self, cfg: DictConfig) -> None: utils.log_rank_zero(self._logger, "Loss is initialized.") + # Initialize metrics aggregator for dataset metrics tracking + self._metrics_aggregator = MetricsAggregator() + # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, batch_size=cfg.batch_size, collate_fn=collate_name, dataloader_state_dict=( @@ -457,7 +447,6 @@ def setup(self, cfg: DictConfig) -> None: cfg_dataset=cfg.dataset_val, batch_size=batch_size_val, collate_fn=collate_name, - shuffle=False, dataloader_state_dict=( state_dict[training.VAL_DATALOADER_KEY] if training.VAL_DATALOADER_KEY in state_dict @@ -465,38 +454,13 @@ def setup(self, cfg: DictConfig) -> None: ), ) - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - # - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader, the max_steps_per_epoch param set by the user and the - # gradient_accumulation_steps param. This value is used for logging and tracking - # training state. The computation should happen after the dataloader has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - - if self.save_every_n_steps is None: - self.save_every_n_steps = self._steps_per_epoch - self.checkpoint_dir_prefix = "epoch" - else: - self.checkpoint_dir_prefix = "step" - - if ( - self._resume_from_checkpoint - and self.global_step % self._steps_per_epoch == 0 - ): - list(self._dataloader) + # Set checkpoint dir prefix to step-based + self.checkpoint_dir_prefix = "step" # Setup lr scheduler self._lr_scheduler = self._setup_lr_scheduler( cfg_lr_scheduler=cfg.get("lr_scheduler", None), - num_training_steps=self.total_epochs * self._steps_per_epoch, + num_training_steps=self.num_training_steps, last_epoch=self.global_step - 1, ) @@ -799,53 +763,69 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: DictConfig, - shuffle: bool, + cfg_dataset: Union[DictConfig, ListConfig], batch_size: int, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, ) -> StatefulDataLoader: """ - All data related setup happens here. This recipe currently supports only - map-style datasets. If a state_dict is provided (meaning we are resuming a training run), - it is loaded into the dataloader. + Set up the dataloader for iterable datasets. 
""" - if isinstance(cfg_dataset, ListConfig): - datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) - for single_cfg_dataset in cfg_dataset - ] - ds = ConcatDataset(datasets=datasets) - packed = getattr(ds, "packed", False) - else: - ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) - # Instantiate collate_fn - if "left_pad_sequence" in collate_fn: - raise RuntimeError("left_pad_sequence collator is only for inference.") - collate_fn = _get_component_from_path(collate_fn) - - sampler = StatefulDistributedSampler( - ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 + # 1. Create all datasets + iterable_datasets = [] + weights = [] + cfg_dataset_list = cfg_dataset + if not isinstance(cfg_dataset_list, ListConfig): + cfg_dataset_list = [cfg_dataset_list] + + for ds_cfg in cfg_dataset_list: + ds = config.instantiate(ds_cfg, model_transform=self._tokenizer) + iterable_datasets.append(ds) + weights.append(ds_cfg.get("weight", 1.0)) + + # 2. Interleave datasets if any + if len(iterable_datasets) > 1: + ds = InterleavedDataset( + datasets=iterable_datasets, + weights=weights, + seed=self.seed, + ) + else: + ds = iterable_datasets[0] + + # 3. Apply packing + # TODO: follow up PR + packed = False + + # 4. Define a collate function wrapper to handle metrics + base_collate_fn = ( + padded_collate_packed + if packed + else _get_component_from_path(collate_fn) ) + + def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + # TODO: handling of metrics should prob be done in collate_fn. + # putting this here for now to avoid making more changes to this PR. + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + collated_batch = base_collate_fn(clean_batch) + collated_batch["metrics"] = all_metrics + return collated_batch + + # 5. Create DataLoader dataloader = StatefulDataLoader( dataset=ds, batch_size=batch_size, - sampler=sampler, - collate_fn=( - partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor, - ) - if not packed - else padded_collate_packed - ), - # dropping last avoids shape issues with compile + flex attention - drop_last=True, + collate_fn=_collate_with_metrics_wrapper, ) + if dataloader_state_dict is not None: dataloader.load_state_dict(dataloader_state_dict) @@ -917,32 +897,29 @@ def validate(self) -> dict[str, float]: self._model.train() return log_dict - def save_checkpoint(self, *, epoch: int, full_tensors: bool): - if self.global_step % self._steps_per_epoch == 0: - epoch += 1 - + def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): + """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, - optimizer=( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - training_progress=TrainingProgress( + optimizer=(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), + training_progress = TrainingProgress( seed=self.seed, - epochs_run=epoch, - total_epochs=self.total_epochs, - max_steps_per_epoch=self.max_steps_per_epoch, + epochs_run=0, # TODO: not needed. To be deprecated. + total_epochs=1, # TODO: not needed. To be deprecated. + max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. 
steps_run=self.global_step, - total_training_steps=self.total_epochs * self._steps_per_epoch, + total_training_steps=self.num_training_steps, dataloader_state_dict=self._dataloader.state_dict(), val_dataloader_state_dict=( self._val_dataloader.state_dict() if self._val_dataloader is not None else {} ), + #FIXME: add to load_ckpt and TrainingProgress too + metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), - epoch=epoch, + epoch=epoch, # TODO: not needed. To be deprecated. + step=step, single_device=False, full_tensors=full_tensors, dir_prefix=self.checkpoint_dir_prefix, @@ -968,180 +945,155 @@ def train(self) -> None: num_tokens = 0 self._profiler.start() - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - inner_step_count = self.global_step % self._steps_per_epoch - pbar = tqdm( - initial=inner_step_count, - total=self._steps_per_epoch, - desc=f"{self.epochs_run}|{self.global_step}", - ) - - # Get iterator for the dataloader - self._dataloader.sampler.set_epoch(curr_epoch) - dataloader_iter = iter(self._dataloader) - batch_count = 0 - - # Continue looping until we reach max steps or exhaust the dataset - while inner_step_count < self._steps_per_epoch: - # Try to get the next batch, break if we've reached the end of the dataset - try: - batch = next(dataloader_iter) - except StopIteration: - break - - # Start tracking CUDA memory for active steps for just the first epoch - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps + self.profiler_warmup_steps - and self._device.type == "cuda" - ): - torch.cuda.memory._record_memory_history() - - utils.batch_to_device(batch, self._device) - - # Calculate the number of unmasked tokens in the current batch - # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() - num_tokens += current_num_tokens - - with self.train_context( - self.context_parallel_manager(list(batch.values())) - ): - # Loss is normalized by default so we multiply by the number of tokens - # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_step(batch) * current_num_tokens - running_loss += current_loss - # For optimizer in backward, we need to normalize before calling backward - # This case and gradient accumulation are mutually exclusive - if self._optimizer_in_bwd: - torch.distributed.all_reduce(num_tokens) - torch.distributed.all_reduce(running_loss) - current_loss = current_loss * (self.dp_degree / num_tokens) - current_loss.backward() - - # Optimizer step (if not fused in backward call) - if (batch_count + 1) % self._gradient_accumulation_steps == 0: - if not self._optimizer_in_bwd: - # Get total number of tokens across all ranks to normalize gradients - torch.distributed.all_reduce(num_tokens) - # This will ensure that the logged loss matches what we're optimizing - torch.distributed.all_reduce(running_loss) - - # Manually scale the gradients from unnormalized loss by total # of tokens - self._grad_scaler( - list(self._model.parameters()), - self.world_size / num_tokens, - False if self.parallel_dims.tp_enabled else None, - ) - - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) - # If sharded, collect the DTensor here - if 
isinstance(grad_norm, DTensor): - grad_norm = grad_norm.full_tensor() - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) - - # Step the learning rate scheduler - if self._lr_scheduler is not None: - self._lr_scheduler.step() - - self.global_step += 1 - inner_step_count += 1 - - # If float8 training is enabled, perform a single all-reduce to compute the - # scale for all float8 parameters efficiently instead of doing many small - # all-reduces for each parameter - if ( - self._enable_fp8_training - and is_fp8_tensorwise_scaling(self._fp8_recipe_name) - and self.dp_degree > 1 - ): - precompute_float8_dynamic_scale_for_fsdp(self._model) - - loss_to_log = running_loss.detach().item() / num_tokens - pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + + pbar = tqdm(initial=self.global_step, total=self.num_training_steps, desc="Training") + + dataloader_iter = iter(self._dataloader) + batch_count = 0 + + while self.global_step < self.num_training_steps: + try: + batch = next(dataloader_iter) + except StopIteration: + self._logger.warning("Dataloader iterator exhausted unexpectedly. Ending training.") + break + + if "metrics" in batch: + self._metrics_aggregator.update(batch.pop("metrics")) + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count + == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + with self.train_context( + self.context_parallel_manager(list(batch.values())) + ): + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_step(batch) * current_num_tokens + running_loss += current_loss + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss * (self.dp_degree / num_tokens) + current_loss.backward() + + # Optimizer step (if not fused in backward call) + if (batch_count + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + + # Manually scale the gradients from unnormalized loss by total # of tokens + self._grad_scaler( + list(self._model.parameters()), + self.world_size / num_tokens, + False if self.parallel_dims.tp_enabled else None, ) - # Log per-step metrics - if ( - self.global_step % self._log_every_n_steps == 0 - and self._is_rank_zero - ): - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": get_lr( - ( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - ), - "tokens_per_second_per_gpu": ( - num_tokens / self.parallel_dims.non_data_parallel_size - ) 
- / (time_per_step * self.world_size), - } - if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict( - log_dict, - step=self.global_step, + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), ) - - # Save checkpoint if specified by user - if self.global_step % self.save_every_n_steps == 0: - self.save_checkpoint(epoch=curr_epoch, full_tensors=False) - - # Reset running stats for the next step - running_loss = 0 - num_tokens = 0 - t0 = time.perf_counter() - - # Stop tracking CUDA memory now that active steps are complete + # If sharded, collect the DTensor here + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Step the learning rate scheduler + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + self.global_step += 1 + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - and self._device.type == "cuda" + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 ): - torch.cuda.memory._record_memory_history(enabled=None) - - self._profiler.step() - batch_count += 1 - - # Run validation after gradient update - if ( - self._run_val_every_n_steps is not None - and self.global_step % self._run_val_every_n_steps == 0 - ): - pbar.refresh() - self.validate() - - self.epochs_run += 1 + precompute_float8_dynamic_scale_for_fsdp(self._model) + + loss_to_log = running_loss.detach().item() / num_tokens + pbar.update(1) + pbar.set_description(f"Step: {self.global_step}|Loss: {loss_to_log:.4f}") + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0 and self._is_rank_zero: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), + "tokens_per_second_per_gpu": (num_tokens / self.parallel_dims.non_data_parallel_size) / (time_per_step * self.world_size), + } + if self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict(log_dict, step=self.global_step) + + # Log dataset metrics + # #TODO: it requires all_gather. Should we keep a separate log_freq for this? 
+ if self.global_step % self._dataset_metrics_log_freq == 0 and self._is_rank_zero: + dataset_metrics = self._metrics_aggregator.get_metrics_for_logging(prefix="train") + self._metric_logger.log_dict(dataset_metrics, step=self.global_step) + + # Save checkpoint if specified by user + if self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0: + self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=False) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + self._profiler.step() + batch_count += 1 + + # Run validation after gradient update + if ( + self._run_val_every_n_steps is not None + and self.global_step % self._run_val_every_n_steps == 0 + ): + pbar.refresh() + self.validate() self._profiler.stop() - self.save_checkpoint(epoch=curr_epoch, full_tensors=True) + self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=True) def cleanup(self) -> None: if self._is_rank_zero: diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py new file mode 100644 index 0000000000..382d968704 --- /dev/null +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import collections +import pytest +from unittest.mock import patch + +from torchtune.data import AggregationType, Metric, MetricsAggregator + + +class TestMetricsAggregator: + """Focused tests for MetricsAggregator functionality.""" + + @pytest.mark.parametrize( + "agg_type,test_values,expected", + [ + (AggregationType.SUM, [1, 2, 3, 4], 10), + (AggregationType.MEAN, [10, 20, 30, 40], 25.0), + (AggregationType.MAX, [-5, 10, 3, 15], 15), + (AggregationType.MIN, [5, -2, 8, 1], -2), + ( + AggregationType.CATEGORICAL_COUNT, + ["A", "B", "A", "C", "A"], + {"A": 3, "B": 1, "C": 1}, + ), + ], + ) + def test_aggregation_types(self, agg_type, test_values, expected): + """Tests each `AggregationType` to ensure it computes the correct value.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric(dataset_name="test", name="metric", value=val, agg_type=agg_type) + for val in test_values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging() + + if agg_type == AggregationType.CATEGORICAL_COUNT: + for category, count in expected.items(): + assert result[f"test/metric_{category}_count"] == count + else: + assert result["test/metric"] == expected + + def test_distribution_metrics(self): + """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" + aggregator = MetricsAggregator() + values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + metrics = [ + Metric("test", "dist_metric", val, AggregationType.DISTRIBUTION) + for val in values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify distribution statistics + assert result["train/test/dist_metric_mean"] == 5.5 + assert result["train/test/dist_metric_min"] == 1 + assert result["train/test/dist_metric_max"] == 10 + assert result["train/test/dist_metric_p50"] == 5 # Median of 1-10 is 5 (index 4, value 5) + + def test_state_management(self): + """Test aggregator checkpointing and restoration.""" + # Create aggregator with some state + aggregator1 = MetricsAggregator() + initial_metrics = [ + Metric("ds1", "counter", 10, AggregationType.SUM), + Metric("ds1", "average", 5.0, AggregationType.MEAN), + Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT), + ] + aggregator1.update(initial_metrics) + + # Save state + state = aggregator1.state_dict() + + # Create new aggregator and restore state + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state) + + # Both should have identical metrics + metrics1 = aggregator1.get_metrics_for_logging() + metrics2 = aggregator2.get_metrics_for_logging() + assert metrics1 == metrics2 + + # Continue updating both - should remain identical + additional_metrics = [ + Metric("ds1", "counter", 5, AggregationType.SUM), + Metric("ds1", "average", 15.0, AggregationType.MEAN), + ] + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + final_metrics1 = aggregator1.get_metrics_for_logging() + final_metrics2 = aggregator2.get_metrics_for_logging() + assert final_metrics1 == final_metrics2 + + # Verify expected values + assert final_metrics1["ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["ds1/average"] == 10.0 # (5 + 15) / 2 + + def test_multiple_datasets(self): + """Test that metrics from multiple datasets are correctly namespaced.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric("dataset1", "samples", 100, AggregationType.SUM), + Metric("dataset2", "samples", 200, AggregationType.SUM), + Metric("dataset1", "tokens", 1000, 
AggregationType.SUM), + Metric("dataset2", "tokens", 2000, AggregationType.SUM), + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + assert result["train/dataset1/samples"] == 100 + assert result["train/dataset2/samples"] == 200 + assert result["train/dataset1/tokens"] == 1000 + assert result["train/dataset2/tokens"] == 2000 + + def test_empty_aggregator(self): + """Test that empty aggregator returns empty metrics.""" + aggregator = MetricsAggregator() + result = aggregator.get_metrics_for_logging() + assert result == {} + + def test_prefix_handling(self): + """Test that prefix is correctly applied to metric keys.""" + aggregator = MetricsAggregator() + metrics = [ + Metric("test_ds", "metric1", 42, AggregationType.SUM), + Metric("test_ds", "metric2", 84, AggregationType.SUM), + ] + aggregator.update(metrics) + + # Test with prefix + result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") + assert result_with_prefix["validation/test_ds/metric1"] == 42 + assert result_with_prefix["validation/test_ds/metric2"] == 84 + + # Test without prefix + result_no_prefix = aggregator.get_metrics_for_logging() + assert result_no_prefix["test_ds/metric1"] == 42 + assert result_no_prefix["test_ds/metric2"] == 84 \ No newline at end of file diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py new file mode 100644 index 0000000000..1eed534e42 --- /dev/null +++ b/tests/torchtune/data/test_metrics_transform.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +from torchtune.data import AggregationType, Metric, StandardMetricTransform + + +class TestStandardMetricTransform: + """Tests for StandardMetricTransform functionality.""" + + def test_dataset_name_not_set_raises_error(self): + """Test that using transform without setting dataset name raises error.""" + transform = StandardMetricTransform() + sample = {"tokens": [1, 2, 3]} + + with pytest.raises(RuntimeError, match="set_dataset_name"): + transform(sample) + + def test_basic_metrics_generation(self): + """Test that transform generates expected metrics for a sample.""" + transform = StandardMetricTransform() + transform.set_dataset_name("test_dataset") + + sample = {"tokens": [1, 2, 3, 4, 5]} + result = transform(sample) + + # Should preserve original sample data + assert result["tokens"] == [1, 2, 3, 4, 5] + + # Should add metrics + assert "metrics" in result + metrics = result["metrics"] + assert len(metrics) == 3 + + # Check each metric + for metric in metrics: + if metric.name == "samples_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 1 + assert metric.agg_type == AggregationType.SUM + + elif metric.name == "tokens_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.SUM + + elif metric.name == "seq_len": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.DISTRIBUTION \ No newline at end of file diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py new file mode 100644 index 0000000000..4cf303c6fd --- /dev/null +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, 
Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import tempfile +from pathlib import Path +from itertools import islice +from typing import Any, Callable, Dict, List, Optional +from unittest.mock import Mock, patch + +import pytest +import torch +from torch.nn.utils.rnn import pad_sequence +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform, padded_collate_sft +from torchtune.datasets import HfIterableDataset + + +# Test Constants - Avoid perfect divisions +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +SEED = 42 +BATCH_SIZE = 5 +DEFAULT_SHUFFLE_BUFFER_SIZE = 8 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. + + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + ) + + +def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" + # Extract metrics first + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Use torchtune's padded_collate_sft as base collator + collated_batch = padded_collate_sft(clean_batch) + collated_batch["metrics"] = all_metrics + return collated_batch + + +def generate_ckpt( + dataloader: StatefulDataLoader, + aggregator: MetricsAggregator, + steps_before_checkpoint: int, + steps_after_checkpoint: int, + resume_dataloader: Optional[StatefulDataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, +) -> Dict[str, Any]: + """ + Generates a checkpoint by running through data and saving checkpoint mid-stream. + Optionally, a second dataloader and aggregator can be given to resume from ckpt + and run steps_after_checkpoint to match the first one. + + Args: + dataloader: The dataloader to test + aggregator: The metrics aggregator to use + steps_before_checkpoint: Number of steps to run before saving checkpoint + steps_after_checkpoint: Number of steps to run after checkpoint + resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. + resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. + + Returns dict with batches/metrics from both pre and post checkpoint runs. 
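+
+    Example (illustrative; mirrors how the tests below call this helper — loader1/loader2
+    and aggregator1/aggregator2 are assumed to be StatefulDataLoader and MetricsAggregator instances):
+        >>> result = generate_ckpt(
+        ...     loader1, aggregator1, steps_before_checkpoint=2, steps_after_checkpoint=2,
+        ...     resume_dataloader=loader2, resume_aggregator=aggregator2,
+        ... )
+        >>> post = [b["id"].tolist() for b in result["post_checkpoint_batches"]]
+        >>> resumed = [b["id"].tolist() for b in result["resumed_batches"]]
+        >>> assert post == resumed  # a deterministic pipeline resumes exactly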
+ """ + iterator = iter(dataloader) + + # Collect batches before and after checkpoint + batches = [] + checkpoint_state = None + metrics_at_checkpoint = {} + + total_steps = steps_before_checkpoint + steps_after_checkpoint + + for idx, batch in enumerate(iterator): + batches.append(batch) + + # Process metrics + if "metrics" in batch: + aggregator.update(batch.pop("metrics")) + + # Save checkpoint state after steps_before_checkpoint + if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based + checkpoint_state = { + "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), + } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") + + # Stop after total steps + if idx == total_steps - 1: + break + + # Split batches + pre_checkpoint_batches = batches[:steps_before_checkpoint] + post_checkpoint_batches = batches[steps_before_checkpoint:] + + # Resume with new instances if provided + resumed_batches = [] + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances + resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) + resume_iterator = iter(resume_dataloader) + + # Collect only the post-checkpoint batches when resuming + for idx, batch in enumerate(resume_iterator): + resumed_batches.append(batch) + + # Process metrics + if "metrics" in batch: + resume_aggregator.update(batch.pop("metrics")) + + # Stop after steps_after_checkpoint + if idx == steps_after_checkpoint - 1: + break + + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + + return { + # Original run + "pre_checkpoint_batches": pre_checkpoint_batches, + "post_checkpoint_batches": post_checkpoint_batches, + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + # Resumed run + "resumed_batches": resumed_batches, + "resumed_metrics": resumed_metrics, + # Internal state for loading - only if someone needs to manually load + "_checkpoint_state": checkpoint_state, + } + + +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + **kwargs + ) + return _create_dataset + + +class TestHfIterableDataset: + """Tests for HfIterableDataset basic functionality.""" + + def test_default_dataset_name(self, small_dataset_file): + """Test that dataset name is auto-generated from path when not provided.""" + # Create dataset without specifying name + dataset = HfIterableDataset( + path="json", + data_files=small_dataset_file, + split="train", + # dataset_name not provided - should auto-generate + seed=SEED, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=4, + ) + + 
# Should generate name from path and split
+        assert dataset.dataset_name == "json_train"
+
+        # Test giving a name explicitly
+        dataset2 = HfIterableDataset(
+            path="json",
+            data_files=small_dataset_file,
+            split="train",
+            dataset_name="my_dataset",
+            seed=SEED,
+            metric_transform=StandardMetricTransform(),
+            num_shards_per_rank=4,
+        )
+
+        # Should use the explicitly provided name
+        assert dataset2.dataset_name == "my_dataset"
+
+    @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5])
+    def test_epoch_boundaries_and_checkpointing(
+        self, num_epochs, dataset_factory, small_dataset_file
+    ):
+        """
+        Tests that for N epochs, each sample appears exactly N times (rounded down),
+        the epoch metric is correct, and checkpointing works as expected.
+        """
+
+        # 1. Setup Dataloaders and Aggregators for original and resumed runs
+        def create_loader_and_aggregator():
+            dataset = dataset_factory(small_dataset_file, shuffle=False)
+            loader = StatefulDataLoader(
+                dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics
+            )
+            aggregator = MetricsAggregator()
+            return loader, aggregator
+
+        loader1, aggregator1 = create_loader_and_aggregator()
+        loader2, aggregator2 = create_loader_and_aggregator()
+
+        # 2. Calculate steps for the test run
+        total_samples = int(SMALL_DATASET_SIZE * num_epochs)
+        total_steps = total_samples // BATCH_SIZE
+
+        steps_before_checkpoint = max(1, total_steps // 2)
+        steps_after_checkpoint = total_steps - steps_before_checkpoint
+
+        # 3. Generate checkpoint and resume
+        result = generate_ckpt(
+            loader1,
+            aggregator1,
+            steps_before_checkpoint=steps_before_checkpoint,
+            steps_after_checkpoint=steps_after_checkpoint,
+            resume_dataloader=loader2,
+            resume_aggregator=aggregator2,
+        )
+
+        # 4. Verify checkpointing and resumption
+        orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]]
+        resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]]
+        assert (
+            orig_post_ids == resumed_ids
+        ), "Resumed batches should be identical for deterministic run"
+        assert (
+            result["final_metrics"] == result["resumed_metrics"]
+        ), "Final metrics should match"
+
+    def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
+        """Tests that shuffling changes data order between epochs but preserves the set of samples."""
+        # Test unshuffled dataset
+        unshuffled_ds = dataset_factory(
+            small_dataset_file, dataset_name="unshuffled", shuffle=False
+        )
+
+        # Get the sample IDs from two passes through the dataset
+        epoch_samples = list(islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2))
+        first_epoch_ids = [sample["id"] for sample in epoch_samples[:SMALL_DATASET_SIZE]]
+        second_epoch_ids = [sample["id"] for sample in epoch_samples[SMALL_DATASET_SIZE:]]
+
+        # Unshuffled should have same order in both epochs
+        assert first_epoch_ids == list(range(SMALL_DATASET_SIZE))
+        assert second_epoch_ids == list(range(SMALL_DATASET_SIZE))
+
+        # Test shuffled dataset
+        shuffled_ds = dataset_factory(
+            small_dataset_file, dataset_name="shuffled", shuffle=True
+        )
+
+        # Collect full epochs to compare
+        epoch_samples = list(islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2))
+        first_epoch_ids = [sample["id"] for sample in epoch_samples[:SMALL_DATASET_SIZE]]
+        second_epoch_ids = [sample["id"] for sample in epoch_samples[SMALL_DATASET_SIZE:]]
+
+        # Shuffled epochs should have different order
+        assert first_epoch_ids != list(
+            range(SMALL_DATASET_SIZE)
+        ), f"Shuffled should not be sorted, got {first_epoch_ids}"
+        assert (
+            first_epoch_ids != second_epoch_ids
+        ), f"Shuffled epochs should be shuffled differently, got {first_epoch_ids} and {second_epoch_ids}"
+
+        # But should contain the same set of IDs
+        assert set(first_epoch_ids) == set(
+            range(SMALL_DATASET_SIZE)
+        ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
+        assert set(second_epoch_ids) == set(
+            range(SMALL_DATASET_SIZE)
+        ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
+
+    def test_epoch_tracking(self, dataset_factory, small_dataset_file):
+        """Test that the epoch number is correctly tracked across dataset restarts."""
+        dataset = dataset_factory(small_dataset_file, shuffle=False)
+
+        # Collect two epochs worth of samples
+        epoch_samples = list(islice(iter(dataset), SMALL_DATASET_SIZE * 2))
+        first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE]
+        second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:]
+
+        # All first-epoch samples should carry a num_epochs metric of 0
+        epoch_values = [
+            metric.value
+            for sample in first_epoch_samples
+            for metric in sample["metrics"]
+            if metric.name == "num_epochs"
+        ]
+        assert all(
+            epoch_value == 0 for epoch_value in epoch_values
+        ), f"Epoch values should be 0, got {epoch_values}"
+
+        # All second-epoch samples should carry a num_epochs metric of 1
+        epoch_values = [
+            metric.value
+            for sample in second_epoch_samples
+            for metric in sample["metrics"]
+            if metric.name == "num_epochs"
+        ]
+        assert all(
+            epoch_value == 1 for epoch_value in epoch_values
+        ), f"Epoch values should be 1, got {epoch_values}"
diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py
new file mode 100644
index 0000000000..1190d6d774
--- /dev/null
+++ b/tests/torchtune/datasets/test_interleaved.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import islice
+from typing import Any, Dict, Iterator
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from torchtune.data import AggregationType, Metric, MetricsAggregator
+from torchtune.datasets import InterleavedDataset, TuneIterableDataset
+
+
+class TestInterleavedDataset:
+    """Tests for multi-dataset interleaving functionality."""
+
+    def test_initialization_validation(self, dataset_factory, small_dataset_file):
+        """Tests that the dataset raises errors for invalid configurations, like duplicate names."""
+        # Test duplicate dataset names
+        ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate")
+        ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate")
+
+        with pytest.raises(ValueError, match="Duplicate dataset names detected"):
+            InterleavedDataset(datasets=[ds1, ds2], weights=[0.5, 0.5], seed=SEED)
+
+        # Test weight normalization (should work with warning)
+        ds3 = dataset_factory(small_dataset_file, dataset_name="ds3")
+        ds4 = dataset_factory(small_dataset_file, dataset_name="ds4")
+
+        with patch("logging.Logger.warning") as mock_warning:
+            interleaved = InterleavedDataset(
+                datasets=[ds3, ds4],
+                weights=[0.5, 1.5],  # Sum = 2.0 != 1.0
+                seed=SEED,
+                dataset_name="test_interleaved",
+            )
+
+            # Check that weights were normalized
+            assert torch.allclose(interleaved._weights, torch.tensor([0.25, 0.75]))
+            mock_warning.assert_called_once()
+
+        assert interleaved.dataset_name == "test_interleaved"
+
+    def test_sampling_ratios(
+        self, dataset_factory, small_dataset_file, medium_dataset_file
+    ):
+        """Tests that datasets are sampled according to their assigned weights."""
+        # Create two datasets with distinct ID ranges
+        # ds1 has IDs 0-22 (small dataset)
+        # ds2 has IDs 100-134 (medium dataset with offset)
+        ds1 = dataset_factory(small_dataset_file, dataset_name="ds1")
+
ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + + # Test with 70/30 weighting + weights = [0.7, 0.3] + interleaved = InterleavedDataset([ds1, ds2], weights, seed=SEED) + + # Collect 300 samples + sample_count = 300 + samples = list(islice(iter(interleaved), sample_count)) + + # Count samples by checking ID ranges + # ds1 has IDs < 100, ds2 has IDs >= 100 + ds1_count = sum(1 for s in samples if s["id"] < 100) + ds2_count = sum(1 for s in samples if s["id"] >= 100) + + assert ds1_count + ds2_count == sample_count + + # Check ratios are approximately correct + ds1_ratio = ds1_count / sample_count + ds2_ratio = ds2_count / sample_count + + # Allow 10% tolerance due to randomness + assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" + assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + + def test_metrics_aggregation( + self, dataset_factory, small_dataset_file, medium_dataset_file + ): + """Tests that metrics from all child datasets are collected and aggregated.""" + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + + interleaved = InterleavedDataset([ds1, ds2], [0.2, 0.8], seed=SEED) + aggregator = MetricsAggregator() + + # Process some samples + TOTAL_SAMPLES = 200 + for sample in islice(iter(interleaved), 200): + aggregator.update(sample["metrics"]) + + metrics = aggregator.get_metrics_for_logging() + + # Should have metrics from both datasets, with flat keys + assert "ds1/samples_seen" in metrics + assert "ds2/samples_seen" in metrics + + # Both datasets should have contributed samples + assert metrics["ds1/samples_seen"] > 0 + assert metrics["ds2/samples_seen"] > 0 + + # Total samples should equal what we processed + calculated_total_samples = ( + metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] + ) + assert calculated_total_samples == TOTAL_SAMPLES + + # Test that ratio is approximately correct + ds1_ratio = metrics["ds1/samples_seen"] / TOTAL_SAMPLES + ds2_ratio = metrics["ds2/samples_seen"] / TOTAL_SAMPLES + + # Allow 10% tolerance due to randomness + assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" + assert abs(ds2_ratio - 0.8) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.8" + + def test_checkpointing( + self, dataset_factory, small_dataset_file, medium_dataset_file + ): + """Tests that interleaved dataset checkpointing preserves sampling state.""" + + def create_interleaved(): + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + return InterleavedDataset([ds1, ds2], [0.7, 0.3], seed=SEED) + + # Original run + interleaved1 = create_interleaved() + loader1 = StatefulDataLoader( + interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator1 = MetricsAggregator() + + # Resumed run + interleaved2 = create_interleaved() + loader2 = StatefulDataLoader( + interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator2 = MetricsAggregator() + + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=10, + steps_after_checkpoint=20, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert ( + orig_post_ids == resumed_ids + ), "Resumed batches should be identical for deterministic run" + assert 
( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" + + \ No newline at end of file diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index a75e16780a..e1d7d687dd 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -32,11 +32,23 @@ QuestionAnswerTemplate, SummarizeTemplate, ) +from torchtune.data._metrics import ( + AggregationType, + Metric, + MetricTransform, + StandardMetricTransform, +) from torchtune.data._utils import format_content_with_images, load_image, truncate +from torchtune.data._aggregator import MetricsAggregator __all__ = [ + "AggregationType", "CROSS_ENTROPY_IGNORE_IDX", "GrammarErrorCorrectionTemplate", + "Metric", + "MetricsAggregator", + "MetricTransform", + "StandardMetricTransform", "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py new file mode 100644 index 0000000000..f6b962c84c --- /dev/null +++ b/torchtune/data/_aggregator.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import logging +from typing import Any, Dict, List, Tuple + +import torch +import torch.distributed as dist + +from torchtune.data._metrics import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +class MetricsAggregator: + """ + Aggregates metrics across datasets and distributed ranks. + + The internal state `_state` is a dictionary where the key is a tuple + of `(dataset_name, metric_name)` and the value is another dictionary + holding the metric's specific state (e.g., `{'type': AggregationType.SUM, 'value': 10}`). + + Usage: + aggregator = MetricsAggregator() + aggregator.update(metrics) + # Get logger-ready metrics {key: value} + metrics = aggregator.get_metrics_for_logging(prefix="train") # {"train/dataset1/tokens": 1234, ...} + """ + + def __init__(self, dist_window_size: int = 1000): + # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} + self._state: Dict[Tuple[str, str], Dict[str, Any]] = {} + + # For distributions, we keep a window of values to compute percentiles + self._dist_window_size = dist_window_size + + def update(self, metrics: List[Metric]) -> None: + """Update internal state with new metrics. 
+ + Args: + metrics: List of Metric objects + """ + for metric in metrics: + key = (metric.dataset_name, metric.name) + + if key not in self._state: + self._initialize_state(key, metric.agg_type) + + state = self._state[key] + + # Update based on aggregation type + if metric.agg_type == AggregationType.SUM: + state["value"] += metric.value + elif metric.agg_type == AggregationType.MAX: + if state["value"] is not None: + state["value"] = max(state["value"], metric.value) + else: + state["value"] = metric.value + elif metric.agg_type == AggregationType.MIN: + if state["value"] is not None: + state["value"] = min(state["value"], metric.value) + else: + state["value"] = metric.value + elif metric.agg_type == AggregationType.MEAN: + state["sum"] += metric.value + state["count"] += 1 + elif metric.agg_type == AggregationType.DISTRIBUTION: + state["values"].append(metric.value) + elif metric.agg_type == AggregationType.CATEGORICAL_COUNT: + state["counts"][metric.value] += 1 + + def _initialize_state( + self, key: Tuple[str, str], agg_type: AggregationType + ) -> None: + """Initialize state for a new metric.""" + self._state[key] = {"type": agg_type} + state = self._state[key] + + if agg_type == AggregationType.SUM: + state["value"] = 0.0 + elif agg_type in (AggregationType.MAX, AggregationType.MIN): + state["value"] = None + elif agg_type == AggregationType.MEAN: + state["sum"] = 0.0 + state["count"] = 0 + elif agg_type == AggregationType.DISTRIBUTION: + state["values"] = collections.deque(maxlen=self._dist_window_size) + elif agg_type == AggregationType.CATEGORICAL_COUNT: + state["counts"] = collections.Counter() + + def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: + """ + Returns aggregated metrics ready for logging to wandb/tensorboard. + + Args: + prefix: Optional prefix like "train" or "valid" for metric keys + + Returns: + Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value + Ready to be logged directly: wandb.log(metrics) + """ + # Always compute local metrics first + local_metrics = self._compute_local_metrics() + + # In distributed mode, perform reduction + if dist.is_initialized() and dist.get_world_size() > 1: + metrics = self._compute_distributed_metrics(local_metrics) + else: + metrics = local_metrics + + # Format for logging with proper key structure + return self._format_for_logging(metrics, prefix) + + def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: + """ + Compute metrics from current state. + + For distributions and categoricals, expands into multiple entries. + The dict format allows future extensions with additional fields. 
+ + Args: + None + + Returns: + Dictionary mapping (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} + """ + metrics = {} + + for (ds_name, metric_name), state in self._state.items(): + agg_type = state["type"] + + if agg_type in ( + AggregationType.SUM, + AggregationType.MAX, + AggregationType.MIN, + ): + # For sum, max, and min, we just need to return the value + metrics[(ds_name, metric_name)] = { + "value": state["value"], + "agg_type": agg_type, + } + + elif agg_type == AggregationType.MEAN: + if state["count"] > 0: + value = state["sum"] / state["count"] + metrics[(ds_name, metric_name)] = { + "value": value, + "agg_type": agg_type, + } + + elif agg_type == AggregationType.DISTRIBUTION: + # queue -> list + values = list(state["values"]) + + # Sort to get percentiles efficiently + sorted_values = sorted(values) + n = len(sorted_values) + + # Each stat becomes its own metric + # For percentiles, it is an approximattion by computing avg of averages + metrics[(ds_name, f"{metric_name}_mean")] = { + "value": sum(values) / n, + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_min")] = { + "value": sorted_values[0], + "agg_type": AggregationType.MIN, + } + metrics[(ds_name, f"{metric_name}_max")] = { + "value": sorted_values[-1], + "agg_type": AggregationType.MAX, + } + metrics[(ds_name, f"{metric_name}_p05")] = { + "value": sorted_values[max(0, int(0.05 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_p50")] = { + "value": sorted_values[max(0, int(0.5 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_p95")] = { + "value": sorted_values[max(0, int(0.95 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + + elif agg_type == AggregationType.CATEGORICAL_COUNT: + # Expand categorical counts into individual metrics + for category, count in state["counts"].items(): + metrics[(ds_name, f"{metric_name}_{category}_count")] = { + "value": count, + "agg_type": AggregationType.SUM, + } + + return metrics + + def _compute_distributed_metrics( + self, local_metrics: Dict[Tuple[str, str], Dict[str, Any]] + ) -> Dict[Tuple[str, str], Dict[str, Any]]: + """ + Performs distributed reduction on metrics. + + Strategy: + 1. Do a single all_gather_object to collect all metrics from all ranks + 2. Group metrics by key and aggregation type + 3. Apply the appropriate reduction operation locally + + This avoids complex tensor operations and handles all reduction in one pass. 
+ + Args: + local_metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + + Returns: + Reduced metrics in same format as input + + Example: + rank_1_metrics = + { + ("ds1", "metric1"): {"value": 10, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 20, "agg_type": AggregationType.MEAN}, + } + rank_2_metrics = + { + ("ds1", "metric1"): {"value": 30, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 40, "agg_type": AggregationType.MEAN}, + } + + # After reduction + result = + { + ("ds1", "metric1"): {"value": 40, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 30, "agg_type": AggregationType.MEAN}, + } + """ + world_size = dist.get_world_size() + + # Gather all metrics from all ranks in one operation + dist.barrier() + all_metrics = [None] * world_size + dist.all_gather_object(all_metrics, local_metrics) + + # Group values by key for reduction + grouped = collections.defaultdict(list) + for rank_metrics in all_metrics: + if rank_metrics: # It's possible a rank has no metrics + for key, metric_dict in rank_metrics.items(): + # A key is a tuple (dataset, metric) + grouped[key].append(metric_dict) + + # Reduce based on aggregation type + reduced = {} + if not grouped: + return reduced + + for key, metric_dicts in grouped.items(): + # All metrics for a key should have same type, just take first + values = [m["value"] for m in metric_dicts] + agg_type = metric_dicts[0]["agg_type"] + + # Start with copy of first dict to preserve any extra fields + result_dict = metric_dicts[0].copy() + + if agg_type == AggregationType.SUM: + result_dict["value"] = sum(values) + elif agg_type == AggregationType.MAX: + result_dict["value"] = max(values) + elif agg_type == AggregationType.MIN: + result_dict["value"] = min(values) + elif agg_type == AggregationType.MEAN: + result_dict["value"] = sum(values) / len(values) + + reduced[key] = result_dict + + return reduced + + def _format_for_logging( + self, metrics: Dict[Tuple[str, str], Dict[str, Any]], prefix: str + ) -> Dict[str, float]: + """ + Format metrics for wandb/tensorboard logging. + + Args: + metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + prefix: Optional prefix like "train" or "valid" + + Returns: + Flat dict with string keys like "train/dataset1/tokens_seen" -> float + """ + formatted = {} + + for (ds_name, metric_name), metric_dict in metrics.items(): + # Build key: "prefix/dataset/metric" or "dataset/metric" if no prefix + if prefix: + key = f"{prefix}/{ds_name}/{metric_name}" + else: + key = f"{ds_name}/{metric_name}" + + formatted[key] = metric_dict["value"] + + return formatted + + def state_dict(self) -> Dict[str, Any]: + """Serialize aggregator state. 
The state is almost directly serializable.""" + serializable_state = {} + for key, state in self._state.items(): + state_copy = state.copy() + + # Convert non-serializable types + if "values" in state_copy: + state_copy["values"] = list(state_copy["values"]) # deque → list + if "counts" in state_copy: + state_copy["counts"] = dict(state_copy["counts"]) # Counter → dict + + # Convert tuple key to string for JSON compatibility + # JSON doesn't support tuple keys, so we convert (dataset, metric) → "('dataset', 'metric')" + serializable_state[str(key)] = state_copy + return {"state": serializable_state, "dist_window_size": self._dist_window_size} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load aggregator state from checkpoint.""" + self._dist_window_size = state_dict["dist_window_size"] + + deserialized_state = {} + for key_str, state in state_dict["state"].items(): + # Convert string keys back to tuples + # "('dataset', 'metric')" → ('dataset', 'metric') + key = ast.literal_eval(key_str) + + # Re-wrap values in their original types + if state.get("type") == AggregationType.DISTRIBUTION: + state["values"] = collections.deque( + state["values"], maxlen=self._dist_window_size + ) + if state.get("type") == AggregationType.CATEGORICAL_COUNT: + state["counts"] = collections.Counter(state["counts"]) + + deserialized_state[key] = state + self._state = deserialized_state \ No newline at end of file diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py new file mode 100644 index 0000000000..f61d0e579e --- /dev/null +++ b/torchtune/data/_metrics.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, Optional, Protocol, Union + + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated.""" + + SUM = "sum" + MEAN = "mean" + DISTRIBUTION = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + + +@dataclass(frozen=True) +class Metric: + """A self-describing metric object.""" + + dataset_name: str + name: str + value: Union[int, float, str] + agg_type: AggregationType + + +class MetricTransform(Protocol): + """Protocol for metric transforms.""" + + def set_dataset_name(self, dataset_name: str) -> None: ... + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: ... + + +class StandardMetricTransform(MetricTransform): + """ + Attaches per-sample metrics for tracking training progress. + + This transform is responsible for generating metrics on a per-sample + basis (e.g., tokens per sample). The actual aggregation of these metrics + (eg calculating sum of samples seen) is handled by the + `MetricsAggregator`. This separation of concerns ensures that metrics are + correctly aggregated even with multiple dataloader workers and in a + distributed setting. + + Tracked metrics include: + - samples_seen: A count of samples processed. + - tokens_seen: The cumulative sum of all tokens processed. + - seq_len: A distribution of sequence lengths. 
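+
+    Example (illustrative; assumes a tokenized sample dict with a "tokens" field):
+        >>> transform = StandardMetricTransform()
+        >>> transform.set_dataset_name("alpaca")
+        >>> out = transform({"tokens": [1, 2, 3]})
+        >>> [m.name for m in out["metrics"]]
+        ['samples_seen', 'tokens_seen', 'seq_len']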
+ """ + + def __init__(self): + # dataset_name is set by the dataset using set_dataset_name + self.dataset_name: Optional[str] = None + self.new_metric: Optional[Callable] = None + + def set_dataset_name(self, dataset_name: str) -> None: + """Called by dataset to set the namespace for metrics. + The dataset name is used to differentiate multiple datasets stats, + e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen".""" + self.dataset_name = dataset_name + self.new_metric = partial(Metric, dataset_name=dataset_name) + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + if self.dataset_name is None or self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + metrics = [ + self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), + self.new_metric( + name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + ), + self.new_metric( + name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + ), + ] + + # Append to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample \ No newline at end of file diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index b0c7c11738..4ea863169d 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -5,18 +5,25 @@ # LICENSE file in the root directory of this source tree. from torchtune.datasets import multimodal -from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset +from torchtune.datasets._alpaca import ( + alpaca_cleaned_dataset, + alpaca_dataset, + alpaca_iterable_dataset, +) from torchtune.datasets._chat import chat_dataset from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset +from torchtune.datasets._interleaved import InterleavedDataset +from torchtune.datasets._iterable_base import TuneIterableDataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset -from torchtune.datasets._sft import SFTDataset -from torchtune.datasets._slimorca import slimorca_dataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._slimorca import slimorca_dataset, slimorca_iterable_dataset from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( text_completion_dataset, @@ -25,23 +32,29 @@ from torchtune.datasets._wikitext import wikitext_dataset __all__ = [ - "alpaca_dataset", "alpaca_cleaned_dataset", + "alpaca_dataset", + "alpaca_iterable_dataset", + "chat_dataset", + "cnn_dailymail_articles_dataset", + "ConcatDataset", "grammar_dataset", - "samsum_dataset", - "stack_exchange_paired_dataset", - "slimorca_dataset", + "hh_rlhf_helpful_dataset", + "HfIterableDataset", "instruct_dataset", + "InterleavedDataset", + "multimodal", + "PackedDataset", 
"preference_dataset", - "chat_dataset", + "PreferenceDataset", + "samsum_dataset", + "SFTDataset", + "sft_iterable_dataset", + "slimorca_dataset", + "slimorca_iterable_dataset", + "stack_exchange_paired_dataset", "text_completion_dataset", "TextCompletionDataset", - "cnn_dailymail_articles_dataset", - "PackedDataset", - "ConcatDataset", + "TuneIterableDataset", "wikitext_dataset", - "PreferenceDataset", - "SFTDataset", - "hh_rlhf_helpful_dataset", - "multimodal", ] diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 1ecee62f53..4225ab4bf5 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -9,9 +9,11 @@ from typing import Any, Callable, Optional, Union from torchtune.data._messages import AlpacaToMessages +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -101,3 +103,64 @@ def alpaca_dataset( original Alpaca dataset, `yahma/alpaca-cleaned `_. See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details. """ + + +def alpaca_iterable_dataset( + model_transform: ModelTokenizer, + *, + source: str = "tatsu-lab/alpaca", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = True, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + split: str = "train", + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for iterable version of Alpaca-style datasets. + + This returns an infinite iterable dataset that supports checkpointing + and metrics tracking, designed for step-based training. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + source (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. + column_map (Optional[dict[str, str]]): a mapping from the expected columns in the message transform + :class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be + "instruction", "input", and "output" and values should be the actual column names. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is True. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. + seed (int): Seed for shuffling. + dataset_name (Optional[str]): Name of the dataset for metrics tracking. If None, auto-generated. + filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. + split (str): ``split`` argument for ``datasets.load_dataset``. Default is "train". + **load_dataset_kwargs (dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. 
+ + Returns: + HfIterableDataset: iterable dataset configured with source data and transforms + + Example: + >>> from torchdata.stateful_dataloader import StatefulDataLoader + >>> alpaca_ds = alpaca_iterable_dataset(tokenizer=tokenizer) + >>> dataloader = StatefulDataLoader(alpaca_ds, batch_size=8) + >>> for batch in dataloader: + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + message_transform = AlpacaToMessages( + train_on_input=train_on_input, column_map=column_map + ) + + return sft_iterable_dataset( + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + dataset_name=dataset_name, + filter_fn=filter_fn, + split=split, + path=source, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py new file mode 100644 index 0000000000..9a206445d5 --- /dev/null +++ b/torchtune/datasets/_hf_iterable.py @@ -0,0 +1,271 @@ +import logging +from typing import Any, Callable, Dict, Iterator, List, Optional + +import torch +import torch.distributed as dist +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node + +from torchtune.data._metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.datasets._iterable_base import TuneIterableDataset + +logger = logging.getLogger(__name__) + + +class HfIterableDataset(TuneIterableDataset): + """HuggingFace dataset implementation with composable metrics. + + This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. + + This dataset is responsible for: + - Loading and sharding the dataset + - Shuffling at initialization and after each epoch + - Applying transforms + - Returning an infinite iterator over the dataset + + Args: + message_transform (Optional[Callable]): Transforms raw data into Message + model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. + output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually + does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. + metric_transform (Optional[Callable]): Takes the sample and computes metrics, e.g. token count. + If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. + seed (int): Seed for shuffling. + num_shards_per_rank (int): Target number of shards per worker (GPU). It will find a multiple + of world_size * dataloader_workers. + dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated + from the path, source, and split. + filter_fn (Optional[Callable]): Filter function to apply to the dataset. + filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to the filter function. + load_dataset_kwargs (Dict[str, Any]): Keyword arguments to pass to the load_dataset function. 
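+
+    Example (illustrative sketch; assumes a local JSON file with "id" and "tokens" columns):
+        >>> ds = HfIterableDataset(
+        ...     path="json",
+        ...     data_files="data.json",
+        ...     split="train",
+        ...     dataset_name="my_dataset",
+        ...     seed=42,
+        ...     shuffle_buffer_size=1000,
+        ...     num_shards_per_rank=16,
+        ... )
+        >>> sample = next(iter(ds))  # infinite stream; each sample carries a "metrics" list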
+ + """ + + def __init__( + self, + *, + message_transform: Optional[Callable] = None, + model_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + metric_transform: Optional[Callable] = None, + shuffle_buffer_size: Optional[int] = 1000, + weight: Optional[float] = 1.0, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + **load_dataset_kwargs, + ): + # Store configuration + self._shuffle_buffer_size = shuffle_buffer_size + self._seed = seed + self._message_transform = message_transform + self._model_transform = model_transform + self._output_transform = output_transform + self._weight = weight # TODO: make it a property? + + # Create default transform if not provided + self._metric_transform = metric_transform or StandardMetricTransform() + + # Auto-generate dataset name if not provided, ensuring it's always a string. + if dataset_name is None: + path = load_dataset_kwargs.get("path", None) + source = load_dataset_kwargs.get("source", None) + split = load_dataset_kwargs.get("split", None) + name_parts = [] + for item in [path, source, split]: + if item is not None: + name_parts.append(str(item).replace("/", "_")) + self._dataset_name: str = "_".join(name_parts) + else: + self._dataset_name: str = dataset_name + + # Set dataset name on the transform if it supports it + if hasattr(self._metric_transform, "set_dataset_name"): + self._metric_transform.set_dataset_name(self._dataset_name) + + # Internal state for resumption + self._num_epochs = 0 + + # Load and setup HF dataset + self._setup_hf_dataset( + load_dataset_kwargs, num_shards_per_rank, filter_fn, filter_kwargs + ) + + @property + def dataset_name(self) -> str: + return self._dataset_name + + def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Apply transforms if they exist, otherwise return sample unchanged.""" + if self._message_transform is not None: + sample = self._message_transform(sample) + if self._model_transform is not None: + sample = self._model_transform(sample) + if self._output_transform is not None: + sample = self._output_transform(sample) + if self._metric_transform is not None: + sample = self._metric_transform(sample) + return sample + + def _setup_hf_dataset( + self, + load_dataset_kwargs: Dict[str, Any], + num_shards_per_rank: int, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Configures the Hugging Face dataset, including sharding, filtering, and + transform mapping. This method is called only once during initialization + to avoid expensive re-computation on each epoch. + """ + + # Distributed setup + world_size, rank = 1, 0 + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Load and shard dataset + ds = load_dataset(**load_dataset_kwargs) + + # Use to_iterable_dataset for streaming datasets + if not load_dataset_kwargs.get("streaming", False): + + # Define number of shards based on (world_size, num of shards per GPU, dataloader workers) + # E.g. world_size=2, num_shards_per_rank=16, dataloader_workers=3 + # we will try 2*16 = 32 shards. Since 32 is not a multiple of 3, we will do 36 shards. + # Each rank gets 16 shards, each dataloader worker in that rankgets 6 shards. 
+ worker_info = torch.utils.data.get_worker_info() + num_dataloader_workers = worker_info.num_workers if worker_info else 1 + + # Calculate total workers + total_workers = world_size * num_dataloader_workers + + # Calculate desired shards + desired_shards = world_size * num_shards_per_rank + + # Find the smallest multiple of total_workers that is >= desired_shards + if desired_shards % total_workers == 0: + num_shards = desired_shards + else: + num_shards = total_workers * ( + (desired_shards + total_workers - 1) // total_workers + ) + + # If the dataset is not streaming and has a defined length, + # we cannot have num_shards > dataset_size. + if not load_dataset_kwargs.get("streaming", False) and hasattr( + ds, "__len__" + ): + dataset_size = len(ds) + if num_shards > dataset_size: + raise ValueError( + f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." + f"Please decrease num_shards_per_rank." + ) + + ds = ds.to_iterable_dataset(num_shards=num_shards) + + # Shuffle the dataset + if self._shuffle_buffer_size and self._shuffle_buffer_size > 0: + ds = ds.shuffle(seed=self._seed, buffer_size=self._shuffle_buffer_size) + + # Distribute across ranks + if world_size > 1: + ds = split_dataset_by_node(ds, rank=rank, world_size=world_size) + + # Apply filtering if specified + if filter_fn: + filter_kwargs = filter_kwargs or {} + ds = ds.filter(filter_fn, **filter_kwargs) + + self._ds = ds + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """Iterate through the dataset infinitely. + + It will restart from the beginning after exhausting the dataset. + + If shuffle_buffer_size is set, it will shuffle the dataset at the beginning of each epoch + when set_epoch is called. + + An additional metric "num_epochs" is added to the sample. + """ + epoch_ds = self._ds + + while True: # Infinite iteration + epoch_seed = self._seed + self._num_epochs + epoch_ds.set_epoch(epoch_seed) + epoch_iterator = iter(epoch_ds) + samples_yielded = 0 + + try: + for sample in epoch_iterator: + # NOTE: We apply transforms here instead of using .map() call + # to work around https://github.com/huggingface/datasets/issues/7630 + # where .map() can cause incorrect resumption from a checkpoint. + sample = self._apply_transforms(sample) + + # Track the number of epochs completed for each dataset. This is + # especially useful when interleaving multiple datasets, but + # also necessary to track dataset-level metrics. + metric_num_epochs = Metric( + dataset_name=self.dataset_name, + name="num_epochs", + value=self._num_epochs, + agg_type=AggregationType.MAX, + ) + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].append(metric_num_epochs) + + samples_yielded += 1 + yield sample + + except StopIteration: + pass # Iterator is exhausted, which is expected. + except Exception as e: + logger.warning( + f"Dataset {self.dataset_name} encountered an unexpected error: {e}." + ) + raise + + # Check if we got zero samples - this might indicate an issue + if samples_yielded == 0: + logger.warning( + f"Dataset {self.dataset_name} epoch {self._num_epochs} yielded 0 samples - potential issue!" + ) + + # Epoch complete - increment and continue infinite loop + self._num_epochs += 1 + + # Reset to the base dataset for the next epoch's shuffling. + epoch_ds = self._ds + + def state_dict(self) -> Dict[str, Any]: + """ + The dataset returns its own state directly, without namespacing. 
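+
+        Illustrative shape of the returned dict (matching the implementation below):
+        {"num_epochs": int, "seed": int, "hf_dataset_state": <state of the underlying HF IterableDataset>}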
+ """ + hf_state = self._ds.state_dict() + state = { + "num_epochs": self._num_epochs, + "seed": self._seed, + "hf_dataset_state": hf_state, + } + return state + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load state from checkpoint, including restoring the state of the + Hugging Face IterableDataset. + """ + self._num_epochs = state_dict["num_epochs"] + hf_state = state_dict["hf_dataset_state"] + + # HF is responsible for resuming the dataset state + # where it last left off + self._ds.load_state_dict(hf_state) \ No newline at end of file diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py new file mode 100644 index 0000000000..cbfe36338c --- /dev/null +++ b/torchtune/datasets/_interleaved.py @@ -0,0 +1,115 @@ +import collections +import logging +import math +from typing import Any, Dict, Iterator, List + +import torch + +from torchtune.datasets._iterable_base import TuneIterableDataset + +logger = logging.getLogger(__name__) + + +class InterleavedDataset(TuneIterableDataset): + """Infinitely interleaves multiple TuneIterableDatasets according to a list of weights. + - The weights are normalized to sum to 1.0. + - This dataset is responsible for managing the state of its child datasets + to ensure correct checkpointing and resumption. + + Args: + datasets (List[TuneIterableDataset]): List of TuneIterableDatasets to interleave. + weights (List[float]): List of weights for each dataset. Must sum to 1.0. + seed (int): Seed for sampling. + dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + """ + + def __init__( + self, + datasets: List[TuneIterableDataset], + weights: List[float], + seed: int, + dataset_name: str = "interleaved_dataset", + ): + self._dataset_name = dataset_name + + # Preserve original order for weighted sampling + self._dataset_names = [ds.dataset_name for ds in datasets] + + # Create a name-to-dataset mapping for robust state management + self._datasets: Dict[str, TuneIterableDataset] = { + ds.dataset_name: ds for ds in datasets + } + + # Validate unique dataset names upfront - fail fast with clear error + names = self._dataset_names + if len(names) != len(set(names)): + duplicates = [ + name for name, count in collections.Counter(names).items() if count > 1 + ] + raise ValueError( + f"Duplicate dataset names detected: {duplicates}. All {names=}" + f"Please provide a unique 'dataset_name' for each dataset in the interleaved list." + ) + + self._sampling_generator = torch.Generator().manual_seed(seed) + + # Normalize weights to sum to 1 + #TODO: make it a property? rely on ds.weight? + total_weight = sum(weights) + self._weights = torch.tensor( + [w / total_weight for w in weights], dtype=torch.float + ) + if not math.isclose(total_weight, 1.0, rel_tol=1e-9): + logger.warning( + f"Interleaved dataset normalized weights to sum to 1.0. Found {total_weight=}. 
Previous {weights=}, new {self._weights.tolist()}" + ) + + @property + def dataset_name(self) -> str: + return self._dataset_name + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """Interleave samples from child infinite datasets""" + child_iters = {name: iter(ds) for name, ds in self._datasets.items()} + + while True: + # Sample which dataset to use + ds_idx = torch.multinomial( + self._weights, 1, replacement=True, generator=self._sampling_generator + ).item() + + # Sample an index, then get the name for safe lookup + ds_name = self._dataset_names[ds_idx] + + try: + sample = next(child_iters[ds_name]) + yield sample + except StopIteration: + # Per the design, child datasets must be infinite. + # We re-initialize to allow for continuous operation but warn loudly + # as this may indicate a design problem in the child dataset. + logger.warning( + f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. " + "This is unexpected for an infinite dataset. Re-initializing its iterator." + ) + child_iters[ds_name] = iter(self._datasets[ds_name]) + sample = next(child_iters[ds_name]) + yield sample + + def state_dict(self) -> Dict[str, Any]: + """Save state for the interleaver and its children.""" + # The parent is responsible for namespacing the child states. + child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} + return { + "sampling_generator_state": self._sampling_generator.get_state(), + "child_states": child_states, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load state for the interleaver and its children.""" + self._sampling_generator.set_state(state_dict["sampling_generator_state"]) + child_states = state_dict["child_states"] + for name, ds in self._datasets.items(): + if name in child_states: + # Pass the raw state dict to the child + ds.load_state_dict(child_states[name]) \ No newline at end of file diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py new file mode 100644 index 0000000000..725810541c --- /dev/null +++ b/torchtune/datasets/_iterable_base.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator + +from torch.utils.data import IterableDataset + + +class TuneIterableDataset(IterableDataset, ABC): + """ + Abstract base class for all torchtune iterable datasets. + It defines the minimal, consistent interface required for all dataset + implementations to ensure they are compatible with the training loop, + checkpointing, and metric logging systems. + """ + + @property + @abstractmethod + def dataset_name(self) -> str: + """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ + Returns an infinite iterator over the dataset. Each implementation is responsible + for its own iteration logic, including shuffling and making it an infinite stream. 
+ """ + pass + + @abstractmethod + def state_dict(self) -> Dict[str, Any]: + """Returns a state dictionary for checkpointing""" + pass + + @abstractmethod + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load state from a state dictionary, used when resuming from a checkpoint.""" + pass \ No newline at end of file diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 2e74ec66a0..70bfb75fd5 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Mapping, Optional, Dict import numpy as np from datasets import load_dataset @@ -14,6 +14,8 @@ from torchtune.data._messages import validate_messages from torchtune.modules.transforms import Transform +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset class SFTDataset(Dataset): @@ -178,3 +180,96 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: tokenized_dict = transformed_sample return tokenized_dict + + +class SFTOutputTransform(Transform): + """ + Output transform to be used in SFT recipes as an input to TuneIterableDataset. + It takes tokenized inputs with "tokens" and "mask" keys and + creates the "labels" key for SFT training. + + The labels are created by: + 1. Shifting tokens by 1 position (for autoregressive training) + 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX + 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end + """ + + def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: + # Create a copy to avoid modifying the original + tokenized_dict = dict(sample) + + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): + keys_str = ", ".join(tokenized_dict.keys()) + raise ValueError( + f"SFTOutputTransform expects 'tokens' and 'mask' keys. " + f"Got keys: {keys_str}" + ) + + # Create labels for SFT training + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"][1:], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"][1:], + ) + ) + tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + + return tokenized_dict + + +def sft_iterable_dataset( + model_transform: Transform, + *, + message_transform: Transform, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + **load_dataset_kwargs: Dict[str, Any], +) -> HfIterableDataset: + """ + Creates an SFT-ready iterable dataset with appropriate output transform. 
+ + Args: + model_transform (Transform): Usually the tokenizer + message_transform (Transform): Transform to convert raw data to messages + shuffle_buffer_size (Optional[int]): Buffer size for shuffling + seed (int): Random seed for shuffling + num_shards_per_rank (int): Target shards per worker + dataset_name (Optional[str]): Name for metrics namespacing + filter_fn (Optional[Callable]): Filter function + filter_kwargs (Optional[Dict[str, Any]]): Filter function kwargs + **load_dataset_kwargs: Args passed to load_dataset + + Returns: + HfIterableDataset: Configured for SFT training + + Example: + >>> from torchtune.data import AlpacaToMessages + >>> message_transform = AlpacaToMessages(train_on_input=False) + >>> ds = sft_iterable_dataset( + ... message_transform=message_transform, + ... model_transform=tokenizer, + ... path="tatsu-lab/alpaca" + ... ) + """ + + output_transform = SFTOutputTransform() + + return HfIterableDataset( + message_transform=message_transform, + model_transform=model_transform, + output_transform=output_transform, + metric_transform=StandardMetricTransform(), + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index ac49b56d63..77667aa579 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -7,9 +7,11 @@ from typing import Any, Callable, Optional, Union from torchtune.data import ShareGPTToMessages +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -94,3 +96,68 @@ def slimorca_dataset( ) return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len) return ds + + +def slimorca_iterable_dataset( + model_transform: ModelTokenizer, + *, + source: str = "Open-Orca/SlimOrca-Dedup", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = False, + new_system_prompt: Optional[str] = None, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for SlimOrca-style conversational datasets using iterable approach. + + This creates an infinite iterable dataset that automatically shards and shuffles data, + making it suitable for step-based training without explicit epoch boundaries. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + column_map (Optional[dict[str, str]]): mapping from expected "conversations" column + to actual column name in dataset. If None, uses default "conversations". + train_on_input (bool): Whether to train on input or mask it. Default is False. + new_system_prompt (Optional[str]): If specified, prepend system message to every sample. + shuffle_buffer_size (Optional[int]): Size of shuffle buffer. If None or 0, no shuffling. + seed (int): Seed for shuffling. Default is 42. 
+ num_shards_per_rank (int): Target number of shards per worker. Default is 64. + dataset_name (Optional[str]): Name for metrics. If None, auto-generated from source. + filter_fn (Optional[Callable]): Filter function to apply to dataset. + filter_kwargs (Optional[dict[str, Any]]): Kwargs for filter function. + **load_dataset_kwargs: Additional kwargs for load_dataset. + + Returns: + HfIterableDataset: Configured iterable dataset + + Example: + >>> from torchtune.datasets import slimorca_iterable_dataset + >>> ds = slimorca_iterable_dataset(shuffle_buffer_size=1000) + >>> for sample in ds: + >>> print(sample["tokens"][:10]) # First 10 tokens + """ + message_transform = ShareGPTToMessages( + train_on_input=train_on_input, + column_map=column_map, + new_system_prompt=new_system_prompt, + ) + + return sft_iterable_dataset( + source=source, + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 39b8989284..5f29e038c6 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -28,6 +28,7 @@ from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer from torchtune.training.checkpointing._utils import get_most_recent_checkpoint from torchtune.training.memory import OptimizerInBackwardWrapper +from torchtune.data import MetricsAggregator log = utils.get_logger("DEBUG") import torchdata @@ -47,6 +48,7 @@ class TrainingProgress: total_training_steps: Optional[int] = None dataloader_state_dict: Optional[dict[str, Any]] = None val_dataloader_state_dict: Optional[dict[str, Any]] = None + metrics_aggregator_state_dict: Optional[dict[str, Any]] = None def state_dict(self) -> dict[str, object]: return { @@ -58,6 +60,7 @@ def state_dict(self) -> dict[str, object]: "total_training_steps": self.total_training_steps, training.DATALOADER_KEY: self.dataloader_state_dict, training.VAL_DATALOADER_KEY: self.val_dataloader_state_dict, + "metrics_aggregator_state_dict": self.metrics_aggregator_state_dict, } @@ -442,6 +445,7 @@ def load_distributed_checkpoint( adapter_config: Optional[dict[str, Any]] = None, dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader] = None, single_device: bool = False, + metrics_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ This method is used to resume training from a distributed checkpoint state. 
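The new `metrics_aggregator` plumbing above mirrors the `state_dict()` / `load_state_dict()` pair on `MetricsAggregator` shown elsewhere in this patch series. A minimal sketch of the round trip that the `metrics_aggregator_state_dict` entry enables, assuming only the APIs visible in these patches (variable names are illustrative):

```python
from torchtune.data import MetricsAggregator

aggregator = MetricsAggregator()
# during training, each step calls: aggregator.update(batch.pop("metrics"))

# saving: this is what gets stored under TrainingProgress.metrics_aggregator_state_dict
saved_state = aggregator.state_dict()

# resuming: a fresh aggregator restores its counters so logged metrics continue seamlessly
resumed = MetricsAggregator()
resumed.load_state_dict(saved_state)
```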
@@ -459,6 +463,7 @@ def load_distributed_checkpoint( checkpoint_dict: dict[str, Any] = {} model_state_dict = model.state_dict() optim_state_dict = optimizer.state_dict() + metrics_aggregator_state_dict = metrics_aggregator.state_dict() if metrics_aggregator else {} # Hack to properly initialize the learning rate scheduler # TODO: Find a better way to do this, possibly by including the following @@ -481,6 +486,7 @@ def load_distributed_checkpoint( "steps_run": 0, "total_training_steps": 0, training.DATALOADER_KEY: dataloader.state_dict() if dataloader else {}, + "metrics_aggregator_state_dict": metrics_aggregator_state_dict, } ) From 2212b19d861ed07381def0b8a93680549a888617 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:59:02 -0400 Subject: [PATCH 15/48] update tests --- tests/torchtune/datasets/test_interleaved.py | 80 ++++++++++- .../torchtune/datasets/test_iterable_utils.py | 126 ++++++++++++++++++ 2 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 tests/torchtune/datasets/test_iterable_utils.py diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 1190d6d774..3c7a1c6fae 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,15 +4,89 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import tempfile +from pathlib import Path from itertools import islice from typing import Any, Dict, Iterator from unittest.mock import patch import pytest import torch +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform +from torchtune.datasets import InterleavedDataset, HfIterableDataset + +# Import test utilities +from .test_iterable_utils import collate_with_metrics, generate_ckpt + +# Test Constants +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +SEED = 42 +BATCH_SIZE = 5 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. 
+ + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + ) + -from torchtune.data import AggregationType, Metric, MetricsAggregator -from torchtune.datasets import InterleavedDataset, TuneIterableDataset +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def medium_dataset_file(tmp_data_dir): + path = tmp_data_dir / "medium_data.json" + create_test_json_file(path, MEDIUM_DATASET_SIZE, offset=100) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + **kwargs + ) + return _create_dataset class TestInterleavedDataset: @@ -90,7 +164,7 @@ def test_metrics_aggregation( # Process some samples TOTAL_SAMPLES = 200 - for sample in islice(iter(interleaved), 200): + for sample in islice(iter(interleaved), TOTAL_SAMPLES): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging() diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py new file mode 100644 index 0000000000..8d4d6d7849 --- /dev/null +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Dict, List, Optional + +from torch.utils.data import DataLoader +from torchtune.data import padded_collate_sft +from torchtune.data._metrics import MetricsAggregator + + +def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Use torchtune's standard SFT collate function + collated = padded_collate_sft(clean_batch) + collated["metrics"] = all_metrics + return collated + + +def generate_ckpt( + dataloader: DataLoader, + aggregator: MetricsAggregator, + steps_before_checkpoint: int, + steps_after_checkpoint: int, + resume_dataloader: Optional[DataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, +) -> Dict[str, Any]: + """ + Generates a checkpoint by running through data and saving checkpoint mid-stream. + Optionally, a second dataloader and aggregator can be given to resume from ckpt + and run steps_after_checkpoint to match the first one. + + Args: + dataloader: The dataloader to test + aggregator: The metrics aggregator to use + steps_before_checkpoint: Number of steps to run before saving checkpoint + steps_after_checkpoint: Number of steps to run after checkpoint + resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. + resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. + + Returns dict with batches/metrics from both pre and post checkpoint runs. + """ + iterator = iter(dataloader) + + # Collect batches before and after checkpoint + batches = [] + checkpoint_state = None + metrics_at_checkpoint = {} + + total_steps = steps_before_checkpoint + steps_after_checkpoint + + for idx, batch in enumerate(iterator): + batches.append(batch) + + # Process metrics + if "metrics" in batch: + aggregator.update(batch.pop("metrics")) + + # Save checkpoint state after steps_before_checkpoint + if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based + checkpoint_state = { + "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), + } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") + + # Stop after total steps + if idx == total_steps - 1: + break + + # Split batches + pre_checkpoint_batches = batches[:steps_before_checkpoint] + post_checkpoint_batches = batches[steps_before_checkpoint:] + + # Resume with new instances if provided + resumed_batches = [] + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances + resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) + resume_iterator = iter(resume_dataloader) + + # Collect only the post-checkpoint batches when resuming + for idx, batch in enumerate(resume_iterator): + resumed_batches.append(batch) + + # Process metrics + if "metrics" in batch: + resume_aggregator.update(batch.pop("metrics")) + + # Stop after steps_after_checkpoint + if idx == steps_after_checkpoint - 1: + break + + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + + return { + # Original run + "pre_checkpoint_batches": pre_checkpoint_batches, + 
"post_checkpoint_batches": post_checkpoint_batches, + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + # Resumed run + "resumed_batches": resumed_batches, + "resumed_metrics": resumed_metrics, + # Internal state for loading - only if someone needs to manually load + "_checkpoint_state": checkpoint_state, + } \ No newline at end of file From 2eb68b6d00b301bdde940e8cbfd4efef6a5cba60 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:40:43 -0700 Subject: [PATCH 16/48] linter --- recipes/full_finetune_distributed.py | 148 +++++++++++------- .../torchtune/data/test_metrics_aggregator.py | 12 +- .../torchtune/data/test_metrics_transform.py | 18 +-- tests/torchtune/datasets/test_hf_iterable.py | 86 +++++----- tests/torchtune/datasets/test_interleaved.py | 32 ++-- .../torchtune/datasets/test_iterable_utils.py | 27 ++-- torchtune/data/__init__.py | 14 +- torchtune/data/_aggregator.py | 51 +++--- torchtune/data/_metrics.py | 13 +- torchtune/datasets/__init__.py | 2 +- torchtune/datasets/_alpaca.py | 5 +- torchtune/datasets/_hf_iterable.py | 30 ++-- torchtune/datasets/_interleaved.py | 34 ++-- torchtune/datasets/_iterable_base.py | 16 +- torchtune/datasets/_sft.py | 38 ++--- torchtune/datasets/_slimorca.py | 13 +- .../checkpointing/_checkpoint_client.py | 6 +- 17 files changed, 307 insertions(+), 238 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index f34ccc6a7e..a4eb87e7d6 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -9,11 +9,11 @@ import time from functools import partial -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from warnings import warn import torch -from omegaconf import DictConfig, ListConfig +from omegaconf import dictConfig, listConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -22,11 +22,10 @@ from torch.optim import Optimizer from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import padded_collate_packed, MetricsAggregator -from torchtune.datasets import ConcatDataset, InterleavedDataset +from torchtune.data import MetricsAggregator, padded_collate_packed +from torchtune.datasets import InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils @@ -120,7 +119,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): has example commands for how to kick-off training. Args: - cfg (DictConfig): OmegaConf object parsed from yaml file + cfg (dictConfig): OmegaConf object parsed from yaml file Raises: ValueError: If ``dtype`` is set to fp16. @@ -130,7 +129,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. 
""" - def __init__(self, cfg: DictConfig) -> None: + def __init__(self, cfg: dictConfig) -> None: device_type = cfg.device self._device = utils.get_device(device=device_type) self._dtype = training.get_dtype(cfg.dtype, device=self._device) @@ -274,12 +273,12 @@ def __init__(self, cfg: DictConfig) -> None: seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) ) self.global_step = 0 - + # Step-based training support self.num_training_steps = cfg.num_training_steps self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) self._metrics_aggregator = None # Will be initialized in setup - + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. @@ -304,14 +303,16 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: "Are you sure you passed in the right recipe checkpoint?" ) from e - def setup(self, cfg: DictConfig) -> None: + def setup(self, cfg: dictConfig) -> None: """ Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. """ if cfg.get("dataset_val") is not None: - raise NotImplementedError("Validation is not supported yet with iterable datasets.") - + raise NotImplementedError( + "Validation is not supported yet with iterable datasets." + ) + if self.fsdp_cpu_offload: # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU @@ -333,7 +334,7 @@ def setup(self, cfg: DictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool - if isinstance(compile, DictConfig): + if isinstance(compile, dictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) @@ -470,7 +471,7 @@ def setup(self, cfg: DictConfig) -> None: def _setup_lr_scheduler( self, - cfg_lr_scheduler: Optional[DictConfig], + cfg_lr_scheduler: Optional[dictConfig], num_training_steps: int, last_epoch: int, ) -> Optional[Optimizer]: @@ -479,7 +480,7 @@ def _setup_lr_scheduler( It supports both standard optimization and optimizer-in-backward cases. Args: - cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + cfg_lr_scheduler (Optional[dictConfig]): The learning rate scheduler configuration. num_training_steps (int): The total number of training steps. last_epoch (int): The index of the last epoch. 
@@ -518,14 +519,14 @@ def _setup_lr_scheduler( return lr_scheduler def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None + self, cfg_profiler: Optional[dictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler """ # Missing profiler section in config, assume disabled if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) + cfg_profiler = dictConfig({"enabled": False}) # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: @@ -552,7 +553,7 @@ def _setup_profiler( def _setup_model( self, - cfg_model: DictConfig, + cfg_model: dictConfig, enable_activation_checkpointing: bool, enable_activation_offloading: bool, activation_offloading_use_streams: bool, @@ -710,7 +711,7 @@ def _setup_model( def _setup_optimizer( self, - cfg_optimizer: DictConfig, + cfg_optimizer: dictConfig, optimizer_in_bwd: bool = False, opt_state_dict: Optional[dict[str, Any]] = None, ) -> Optional[Optimizer]: @@ -763,7 +764,7 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: Union[DictConfig, ListConfig], + cfg_dataset: Union[dictConfig, listConfig], batch_size: int, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, @@ -776,7 +777,7 @@ def _setup_data( iterable_datasets = [] weights = [] cfg_dataset_list = cfg_dataset - if not isinstance(cfg_dataset_list, ListConfig): + if not isinstance(cfg_dataset_list, listConfig): cfg_dataset_list = [cfg_dataset_list] for ds_cfg in cfg_dataset_list: @@ -787,25 +788,25 @@ def _setup_data( # 2. Interleave datasets if any if len(iterable_datasets) > 1: ds = InterleavedDataset( - datasets=iterable_datasets, + datasets=iterable_datasets, weights=weights, seed=self.seed, ) else: ds = iterable_datasets[0] - + # 3. Apply packing # TODO: follow up PR packed = False # 4. Define a collate function wrapper to handle metrics base_collate_fn = ( - padded_collate_packed - if packed - else _get_component_from_path(collate_fn) + padded_collate_packed if packed else _get_component_from_path(collate_fn) ) - def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + def _collate_with_metrics_wrapper( + batch: list[dict[str, Any]] + ) -> dict[str, Any]: # TODO: handling of metrics should prob be done in collate_fn. # putting this here for now to avoid making more changes to this PR. all_metrics = [] @@ -814,7 +815,7 @@ def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any] if "metrics" in sample: all_metrics.extend(sample.pop("metrics")) clean_batch.append(sample) - + collated_batch = base_collate_fn(clean_batch) collated_batch["metrics"] = all_metrics return collated_batch @@ -901,12 +902,16 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, - optimizer=(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), - training_progress = TrainingProgress( + optimizer=( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + training_progress=TrainingProgress( seed=self.seed, - epochs_run=0, # TODO: not needed. To be deprecated. - total_epochs=1, # TODO: not needed. To be deprecated. - max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. + epochs_run=0, # TODO: not needed. To be deprecated. + total_epochs=1, # TODO: not needed. To be deprecated. 
+ max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. steps_run=self.global_step, total_training_steps=self.num_training_steps, dataloader_state_dict=self._dataloader.state_dict(), @@ -915,10 +920,10 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): if self._val_dataloader is not None else {} ), - #FIXME: add to load_ckpt and TrainingProgress too + # FIXME: add to load_ckpt and TrainingProgress too metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), - epoch=epoch, # TODO: not needed. To be deprecated. + epoch=epoch, # TODO: not needed. To be deprecated. step=step, single_device=False, full_tensors=full_tensors, @@ -945,9 +950,11 @@ def train(self) -> None: num_tokens = 0 self._profiler.start() - - pbar = tqdm(initial=self.global_step, total=self.num_training_steps, desc="Training") - + + pbar = tqdm( + initial=self.global_step, total=self.num_training_steps, desc="Training" + ) + dataloader_iter = iter(self._dataloader) batch_count = 0 @@ -955,18 +962,19 @@ def train(self) -> None: try: batch = next(dataloader_iter) except StopIteration: - self._logger.warning("Dataloader iterator exhausted unexpectedly. Ending training.") + self._logger.warning( + "Dataloader iterator exhausted unexpectedly. Ending training." + ) break - + if "metrics" in batch: self._metrics_aggregator.update(batch.pop("metrics")) - - # Start tracking CUDA memory for active steps for just the first epoch + + # Start tracking CUDA memory for active steps for just the first epoch if ( self._is_rank_zero and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps + self.profiler_warmup_steps + and batch_count == self.profiler_wait_steps + self.profiler_warmup_steps and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -975,9 +983,7 @@ def train(self) -> None: # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() + current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum() num_tokens += current_num_tokens with self.train_context( @@ -1038,32 +1044,54 @@ def train(self) -> None: loss_to_log = running_loss.detach().item() / num_tokens pbar.update(1) - pbar.set_description(f"Step: {self.global_step}|Loss: {loss_to_log:.4f}") - + pbar.set_description( + f"Step: {self.global_step}|Loss: {loss_to_log:.4f}" + ) + # Log per-step metrics - if self.global_step % self._log_every_n_steps == 0 and self._is_rank_zero: + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, - "lr": get_lr(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), - "tokens_per_second_per_gpu": (num_tokens / self.parallel_dims.non_data_parallel_size) / (time_per_step * self.world_size), + "lr": get_lr( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + "tokens_per_second_per_gpu": ( + num_tokens / self.parallel_dims.non_data_parallel_size + ) + / (time_per_step * self.world_size), } if self._log_peak_memory_stats: log_dict.update(training.get_memory_stats(device=self._device)) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict(log_dict, step=self.global_step) - + # Log dataset metrics # #TODO: it requires all_gather. Should we keep a separate log_freq for this? 
- if self.global_step % self._dataset_metrics_log_freq == 0 and self._is_rank_zero: - dataset_metrics = self._metrics_aggregator.get_metrics_for_logging(prefix="train") + if ( + self.global_step % self._dataset_metrics_log_freq == 0 + and self._is_rank_zero + ): + dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( + prefix="train" + ) self._metric_logger.log_dict(dataset_metrics, step=self.global_step) - + # Save checkpoint if specified by user - if self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0: - self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=False) - + if ( + self.save_every_n_steps is not None + and self.global_step % self.save_every_n_steps == 0 + ): + self.save_checkpoint( + epoch=0, step=self.global_step, full_tensors=False + ) + # Reset running stats for the next step running_loss = 0 num_tokens = 0 @@ -1102,7 +1130,7 @@ def cleanup(self) -> None: @config.parse -def recipe_main(cfg: DictConfig) -> None: +def recipe_main(cfg: dictConfig) -> None: """ Entry point for the recipe. diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 382d968704..69a2f32967 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,9 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections import pytest -from unittest.mock import patch from torchtune.data import AggregationType, Metric, MetricsAggregator @@ -63,7 +61,9 @@ def test_distribution_metrics(self): assert result["train/test/dist_metric_mean"] == 5.5 assert result["train/test/dist_metric_min"] == 1 assert result["train/test/dist_metric_max"] == 10 - assert result["train/test/dist_metric_p50"] == 5 # Median of 1-10 is 5 (index 4, value 5) + assert ( + result["train/test/dist_metric_p50"] == 5 + ) # Median of 1-10 is 5 (index 4, value 5) def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -107,7 +107,7 @@ def test_state_management(self): def test_multiple_datasets(self): """Test that metrics from multiple datasets are correctly namespaced.""" aggregator = MetricsAggregator() - + metrics = [ Metric("dataset1", "samples", 100, AggregationType.SUM), Metric("dataset2", "samples", 200, AggregationType.SUM), @@ -117,7 +117,7 @@ def test_multiple_datasets(self): aggregator.update(metrics) result = aggregator.get_metrics_for_logging(prefix="train") - + assert result["train/dataset1/samples"] == 100 assert result["train/dataset2/samples"] == 200 assert result["train/dataset1/tokens"] == 1000 @@ -146,4 +146,4 @@ def test_prefix_handling(self): # Test without prefix result_no_prefix = aggregator.get_metrics_for_logging() assert result_no_prefix["test_ds/metric1"] == 42 - assert result_no_prefix["test_ds/metric2"] == 84 \ No newline at end of file + assert result_no_prefix["test_ds/metric2"] == 84 diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 1eed534e42..2c8f3023b2 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -6,7 +6,7 @@ import pytest -from torchtune.data import AggregationType, Metric, StandardMetricTransform +from torchtune.data import AggregationType, StandardMetricTransform class TestStandardMetricTransform: @@ -16,7 +16,7 @@ def test_dataset_name_not_set_raises_error(self): """Test 
that using transform without setting dataset name raises error.""" transform = StandardMetricTransform() sample = {"tokens": [1, 2, 3]} - + with pytest.raises(RuntimeError, match="set_dataset_name"): transform(sample) @@ -24,31 +24,31 @@ def test_basic_metrics_generation(self): """Test that transform generates expected metrics for a sample.""" transform = StandardMetricTransform() transform.set_dataset_name("test_dataset") - + sample = {"tokens": [1, 2, 3, 4, 5]} result = transform(sample) - + # Should preserve original sample data assert result["tokens"] == [1, 2, 3, 4, 5] - + # Should add metrics assert "metrics" in result metrics = result["metrics"] assert len(metrics) == 3 - + # Check each metric for metric in metrics: if metric.name == "samples_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 1 assert metric.agg_type == AggregationType.SUM - + elif metric.name == "tokens_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.SUM - + elif metric.name == "seq_len": assert metric.dataset_name == "test_dataset" assert metric.value == 5 - assert metric.agg_type == AggregationType.DISTRIBUTION \ No newline at end of file + assert metric.agg_type == AggregationType.DISTRIBUTION diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 4cf303c6fd..a263258ae8 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,19 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections -import tempfile -from pathlib import Path from itertools import islice -from typing import Any, Callable, Dict, List, Optional -from unittest.mock import Mock, patch +from pathlib import Path +from typing import Any, Optional import pytest -import torch -from torch.nn.utils.rnn import pad_sequence from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform, padded_collate_sft +from torchtune.data import ( + MetricsAggregator, + padded_collate_sft, + StandardMetricTransform, +) from torchtune.datasets import HfIterableDataset @@ -47,7 +46,7 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None ) -def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" # Extract metrics first all_metrics = [] @@ -73,21 +72,24 @@ def generate_ckpt( steps_after_checkpoint: int, resume_dataloader: Optional[StatefulDataLoader] = None, resume_aggregator: Optional[MetricsAggregator] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. Optionally, a second dataloader and aggregator can be given to resume from ckpt and run steps_after_checkpoint to match the first one. Args: - dataloader: The dataloader to test - aggregator: The metrics aggregator to use - steps_before_checkpoint: Number of steps to run before saving checkpoint - steps_after_checkpoint: Number of steps to run after checkpoint - resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator: Optional new aggregator to test resuming. 
If None, returns empty resumed_metrics. - - Returns dict with batches/metrics from both pre and post checkpoint runs. + dataloader (StatefulDataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use + steps_before_checkpoint (int): Number of steps to run before saving checkpoint + steps_after_checkpoint (int): Number of steps to run after checkpoint + resume_dataloader (Optional[StatefulDataLoader]): Optional new dataloader to test resuming. + If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. + + Returns: + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. """ iterator = iter(dataloader) @@ -179,11 +181,12 @@ def small_dataset_file(tmp_data_dir): @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( data_file: str, dataset_name: str = "test_dataset", shuffle: bool = False, - **kwargs + **kwargs, ) -> HfIterableDataset: return HfIterableDataset( path="json", @@ -194,8 +197,9 @@ def _create_dataset( shuffle_buffer_size=10 if shuffle else 0, metric_transform=StandardMetricTransform(), num_shards_per_rank=2, - **kwargs + **kwargs, ) + return _create_dataset @@ -223,7 +227,7 @@ def test_default_dataset_name(self, small_dataset_file): path="json", data_files=small_dataset_file, split="train", - dataset_name = "my_dataset", + dataset_name="my_dataset", seed=SEED, metric_transform=StandardMetricTransform(), num_shards_per_rank=4, @@ -288,8 +292,8 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Get samples from two passes through the dataset - epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] @@ -303,8 +307,8 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Collect full epochs to compare - epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] @@ -317,23 +321,35 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ), f"Shuffled epochs should be shuffled differently, got {first_epoch_samples} and {second_epoch_samples}" # But should contain the same set of IDs - assert set(first_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" - assert set(second_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" + assert set(first_epoch_samples) == set( + range(SMALL_DATASET_SIZE) + ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" + assert set(second_epoch_samples) == set( + range(SMALL_DATASET_SIZE) + ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" def test_epoch_tracking(self, dataset_factory, small_dataset_file): """Test that epoch number is correctly tracked across dataset restarts.""" dataset = dataset_factory(small_dataset_file, shuffle=False) - + # Two epoch samples - epoch_samples = 
islice(iter(dataset), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # All should have epoch 0 - epoch_values = [epoch_metric.value for epoch_metric in first_epoch_samples["metrics"]] - assert all(epoch_value == 0 for epoch_value in epoch_values), f"Epoch values should be 0, got {epoch_values}" - + epoch_values = [ + epoch_metric.value for epoch_metric in first_epoch_samples["metrics"] + ] + assert all( + epoch_value == 0 for epoch_value in epoch_values + ), f"Epoch values should be 0, got {epoch_values}" + # All should have epoch 1 - epoch_values = [epoch_metric.value for epoch_metric in second_epoch_samples["metrics"]] - assert all(epoch_value == 1 for epoch_value in epoch_values), f"Epoch values should be 1, got {epoch_values}" + epoch_values = [ + epoch_metric.value for epoch_metric in second_epoch_samples["metrics"] + ] + assert all( + epoch_value == 1 for epoch_value in epoch_values + ), f"Epoch values should be 1, got {epoch_values}" diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 3c7a1c6fae..e06ef670c1 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,18 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import tempfile -from pathlib import Path from itertools import islice -from typing import Any, Dict, Iterator +from pathlib import Path from unittest.mock import patch import pytest import torch from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform -from torchtune.datasets import InterleavedDataset, HfIterableDataset +from torchtune.data import MetricsAggregator, StandardMetricTransform +from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities from .test_iterable_utils import collate_with_metrics, generate_ckpt @@ -69,11 +67,12 @@ def medium_dataset_file(tmp_data_dir): @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( data_file: str, dataset_name: str = "test_dataset", shuffle: bool = False, - **kwargs + **kwargs, ) -> HfIterableDataset: return HfIterableDataset( path="json", @@ -84,8 +83,9 @@ def _create_dataset( shuffle_buffer_size=10 if shuffle else 0, metric_transform=StandardMetricTransform(), num_shards_per_rank=2, - **kwargs + **kwargs, ) + return _create_dataset @@ -107,10 +107,10 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( - datasets=[ds3, ds4], - weights=[0.5, 1.5], + datasets=[ds3, ds4], + weights=[0.5, 1.5], seed=SEED, - dataset_name="test_interleaved" # Sum = 2.0 != 1.0 + dataset_name="test_interleaved", # Sum = 2.0 != 1.0 ) # Check that weights were normalized @@ -163,8 +163,8 @@ def test_metrics_aggregation( aggregator = MetricsAggregator() # Process some samples - TOTAL_SAMPLES = 200 - for sample in islice(iter(interleaved), TOTAL_SAMPLES): + total_samples = 200 + for sample in islice(iter(interleaved), total_samples): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging() @@ -181,11 +181,11 @@ def 
test_metrics_aggregation( calculated_total_samples = ( metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] ) - assert calculated_total_samples == TOTAL_SAMPLES + assert calculated_total_samples == total_samples # Test that ratio is approximately correct - ds1_ratio = metrics["ds1/samples_seen"] / TOTAL_SAMPLES - ds2_ratio = metrics["ds2/samples_seen"] / TOTAL_SAMPLES + ds1_ratio = metrics["ds1/samples_seen"] / total_samples + ds2_ratio = metrics["ds2/samples_seen"] / total_samples # Allow 10% tolerance due to randomness assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" @@ -232,5 +232,3 @@ def create_interleaved(): assert ( result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" - - \ No newline at end of file diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index 8d4d6d7849..5a4fbda8a5 100644 --- a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from torch.utils.data import DataLoader from torchtune.data import padded_collate_sft from torchtune.data._metrics import MetricsAggregator -def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" all_metrics = [] clean_batch = [] @@ -36,21 +36,24 @@ def generate_ckpt( steps_after_checkpoint: int, resume_dataloader: Optional[DataLoader] = None, resume_aggregator: Optional[MetricsAggregator] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. Optionally, a second dataloader and aggregator can be given to resume from ckpt and run steps_after_checkpoint to match the first one. Args: - dataloader: The dataloader to test - aggregator: The metrics aggregator to use - steps_before_checkpoint: Number of steps to run before saving checkpoint - steps_after_checkpoint: Number of steps to run after checkpoint - resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. - - Returns dict with batches/metrics from both pre and post checkpoint runs. + dataloader (DataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use + steps_before_checkpoint (int): Number of steps to run before saving checkpoint + steps_after_checkpoint (int): Number of steps to run after checkpoint + resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. + If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. + + Returns: + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. 
""" iterator = iter(dataloader) @@ -123,4 +126,4 @@ def generate_ckpt( "resumed_metrics": resumed_metrics, # Internal state for loading - only if someone needs to manually load "_checkpoint_state": checkpoint_state, - } \ No newline at end of file + } diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index e1d7d687dd..09292b9ba9 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torchtune.data._aggregator import MetricsAggregator from torchtune.data._collate import ( left_pad_sequence, padded_collate, @@ -24,6 +25,12 @@ ShareGPTToMessages, validate_messages, ) +from torchtune.data._metrics import ( + AggregationType, + Metric, + MetricTransform, + StandardMetricTransform, +) from torchtune.data._prompt_templates import ( ChatMLTemplate, GrammarErrorCorrectionTemplate, @@ -32,14 +39,7 @@ QuestionAnswerTemplate, SummarizeTemplate, ) -from torchtune.data._metrics import ( - AggregationType, - Metric, - MetricTransform, - StandardMetricTransform, -) from torchtune.data._utils import format_content_with_images, load_image, truncate -from torchtune.data._aggregator import MetricsAggregator __all__ = [ "AggregationType", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index f6b962c84c..9c933ddfa3 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -7,9 +7,8 @@ import ast import collections import logging -from typing import Any, Dict, List, Tuple +from typing import Any -import torch import torch.distributed as dist from torchtune.data._metrics import AggregationType, Metric @@ -34,16 +33,16 @@ class MetricsAggregator: def __init__(self, dist_window_size: int = 1000): # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} - self._state: Dict[Tuple[str, str], Dict[str, Any]] = {} + self._state: dict[tuple[str, str], dict[str, Any]] = {} # For distributions, we keep a window of values to compute percentiles self._dist_window_size = dist_window_size - def update(self, metrics: List[Metric]) -> None: + def update(self, metrics: list[Metric]) -> None: """Update internal state with new metrics. Args: - metrics: List of Metric objects + metrics (list[Metric]): list of Metric objects """ for metric in metrics: key = (metric.dataset_name, metric.name) @@ -75,7 +74,7 @@ def update(self, metrics: List[Metric]) -> None: state["counts"][metric.value] += 1 def _initialize_state( - self, key: Tuple[str, str], agg_type: AggregationType + self, key: tuple[str, str], agg_type: AggregationType ) -> None: """Initialize state for a new metric.""" self._state[key] = {"type": agg_type} @@ -93,15 +92,15 @@ def _initialize_state( elif agg_type == AggregationType.CATEGORICAL_COUNT: state["counts"] = collections.Counter() - def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: + def get_metrics_for_logging(self, prefix: str = "") -> dict[str, float]: """ Returns aggregated metrics ready for logging to wandb/tensorboard. 
Args: - prefix: Optional prefix like "train" or "valid" for metric keys + prefix (str): Optional prefix like "train" or "valid" for metric keys Returns: - Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value + dict[str, float]: Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value Ready to be logged directly: wandb.log(metrics) """ # Always compute local metrics first @@ -116,18 +115,16 @@ def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: # Format for logging with proper key structure return self._format_for_logging(metrics, prefix) - def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: + def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: """ Compute metrics from current state. For distributions and categoricals, expands into multiple entries. The dict format allows future extensions with additional fields. - Args: - None - Returns: - Dictionary mapping (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} + dict[tuple[str, str], dict[str, Any]]: dictionary mapping + (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} """ metrics = {} @@ -199,8 +196,8 @@ def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: return metrics def _compute_distributed_metrics( - self, local_metrics: Dict[Tuple[str, str], Dict[str, Any]] - ) -> Dict[Tuple[str, str], Dict[str, Any]]: + self, local_metrics: dict[tuple[str, str], dict[str, Any]] + ) -> dict[tuple[str, str], dict[str, Any]]: """ Performs distributed reduction on metrics. @@ -212,10 +209,11 @@ def _compute_distributed_metrics( This avoids complex tensor operations and handles all reduction in one pass. Args: - local_metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + local_metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping + (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} Returns: - Reduced metrics in same format as input + dict[tuple[str, str], dict[str, Any]]: Reduced metrics in same format as input Example: rank_1_metrics = @@ -278,17 +276,18 @@ def _compute_distributed_metrics( return reduced def _format_for_logging( - self, metrics: Dict[Tuple[str, str], Dict[str, Any]], prefix: str - ) -> Dict[str, float]: + self, metrics: dict[tuple[str, str], dict[str, Any]], prefix: str + ) -> dict[str, float]: """ Format metrics for wandb/tensorboard logging. Args: - metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - prefix: Optional prefix like "train" or "valid" + metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping + (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + prefix (str): Optional prefix like "train" or "valid" Returns: - Flat dict with string keys like "train/dataset1/tokens_seen" -> float + dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float """ formatted = {} @@ -303,7 +302,7 @@ def _format_for_logging( return formatted - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Serialize aggregator state. 
The state is almost directly serializable.""" serializable_state = {} for key, state in self._state.items(): @@ -320,7 +319,7 @@ def state_dict(self) -> Dict[str, Any]: serializable_state[str(key)] = state_copy return {"state": serializable_state, "dist_window_size": self._dist_window_size} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load aggregator state from checkpoint.""" self._dist_window_size = state_dict["dist_window_size"] @@ -339,4 +338,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: state["counts"] = collections.Counter(state["counts"]) deserialized_state[key] = state - self._state = deserialized_state \ No newline at end of file + self._state = deserialized_state diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py index f61d0e579e..7a38febb1e 100644 --- a/torchtune/data/_metrics.py +++ b/torchtune/data/_metrics.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Dict, Optional, Protocol, Union +from typing import Any, Callable, Optional, Protocol, Union class AggregationType(Enum): @@ -34,8 +34,11 @@ class Metric: class MetricTransform(Protocol): """Protocol for metric transforms.""" - def set_dataset_name(self, dataset_name: str) -> None: ... - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: ... + def set_dataset_name(self, dataset_name: str) -> None: + ... + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + ... class StandardMetricTransform(MetricTransform): @@ -67,7 +70,7 @@ def set_dataset_name(self, dataset_name: str) -> None: self.dataset_name = dataset_name self.new_metric = partial(Metric, dataset_name=dataset_name) - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if self.dataset_name is None or self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." 
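As a quick reference for the contract enforced above, a usage sketch of `StandardMetricTransform`: the dataset name must be set before the first call, and each call appends per-sample metrics. The dataset name here is only an example, and the expected values follow `test_metrics_transform.py` earlier in this patch:

```python
from torchtune.data import StandardMetricTransform

transform = StandardMetricTransform()
transform.set_dataset_name("alpaca")  # calling the transform before this raises RuntimeError

sample = transform({"tokens": [1, 2, 3, 4, 5]})
for metric in sample["metrics"]:
    print(metric.dataset_name, metric.name, metric.value, metric.agg_type)
# per the tests: samples_seen=1 (SUM), tokens_seen=5 (SUM), seq_len=5 (DISTRIBUTION)
```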
@@ -92,4 +95,4 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: if "metrics" not in sample: sample["metrics"] = [] sample["metrics"].extend(metrics) - return sample \ No newline at end of file + return sample diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 4ea863169d..f5ecbb95ea 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -22,7 +22,7 @@ from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.datasets._slimorca import slimorca_dataset, slimorca_iterable_dataset from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 4225ab4bf5..bae7613729 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -9,11 +9,10 @@ from typing import Any, Callable, Optional, Union from torchtune.data._messages import AlpacaToMessages -from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -152,7 +151,7 @@ def alpaca_iterable_dataset( message_transform = AlpacaToMessages( train_on_input=train_on_input, column_map=column_map ) - + return sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 9a206445d5..0be4c0cc53 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -1,5 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import logging -from typing import Any, Callable, Dict, Iterator, List, Optional +from typing import Any, Callable, Iterator, Optional import torch import torch.distributed as dist @@ -37,8 +43,8 @@ class HfIterableDataset(TuneIterableDataset): dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated from the path, source, and split. filter_fn (Optional[Callable]): Filter function to apply to the dataset. - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to the filter function. - load_dataset_kwargs (Dict[str, Any]): Keyword arguments to pass to the load_dataset function. + filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. + load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. 
""" @@ -55,7 +61,7 @@ def __init__( num_shards_per_rank: int = 64, dataset_name: Optional[str] = None, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, + filter_kwargs: Optional[dict[str, Any]] = None, **load_dataset_kwargs, ): # Store configuration @@ -64,7 +70,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight # TODO: make it a property? + self._weight = weight # TODO: make it a property? # Create default transform if not provided self._metric_transform = metric_transform or StandardMetricTransform() @@ -98,7 +104,7 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name - def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" if self._message_transform is not None: sample = self._message_transform(sample) @@ -112,10 +118,10 @@ def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: def _setup_hf_dataset( self, - load_dataset_kwargs: Dict[str, Any], + load_dataset_kwargs: dict[str, Any], num_shards_per_rank: int, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, + filter_kwargs: Optional[dict[str, Any]] = None, ): """ Configures the Hugging Face dataset, including sharding, filtering, and @@ -185,7 +191,7 @@ def _setup_hf_dataset( self._ds = ds - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """Iterate through the dataset infinitely. It will restart from the beginning after exhausting the dataset. @@ -246,7 +252,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: # Reset to the base dataset for the next epoch's shuffling. epoch_ds = self._ds - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """ The dataset returns its own state directly, without namespacing. """ @@ -258,7 +264,7 @@ def state_dict(self) -> Dict[str, Any]: } return state - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """ Load state from checkpoint, including restoring the state of the Hugging Face IterableDataset. @@ -268,4 +274,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # HF is responsible for resuming the dataset state # where it last left off - self._ds.load_state_dict(hf_state) \ No newline at end of file + self._ds.load_state_dict(hf_state) diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index cbfe36338c..53185993dd 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -1,7 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import collections import logging import math -from typing import Any, Dict, Iterator, List +from typing import Any, dict, Iterator import torch @@ -17,16 +23,19 @@ class InterleavedDataset(TuneIterableDataset): to ensure correct checkpointing and resumption. Args: - datasets (List[TuneIterableDataset]): List of TuneIterableDatasets to interleave. - weights (List[float]): List of weights for each dataset. Must sum to 1.0. 
+ datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. + weights (list[float]): list of weights for each dataset. Must sum to 1.0. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + + Raises: + ValueError: If duplicate dataset names are detected in the provided datasets. """ def __init__( self, - datasets: List[TuneIterableDataset], - weights: List[float], + datasets: list[TuneIterableDataset], + weights: list[float], seed: int, dataset_name: str = "interleaved_dataset", ): @@ -36,7 +45,7 @@ def __init__( self._dataset_names = [ds.dataset_name for ds in datasets] # Create a name-to-dataset mapping for robust state management - self._datasets: Dict[str, TuneIterableDataset] = { + self._datasets: dict[str, TuneIterableDataset] = { ds.dataset_name: ds for ds in datasets } @@ -54,21 +63,22 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) # Normalize weights to sum to 1 - #TODO: make it a property? rely on ds.weight? + # TODO: make it a property? rely on ds.weight? total_weight = sum(weights) self._weights = torch.tensor( [w / total_weight for w in weights], dtype=torch.float ) if not math.isclose(total_weight, 1.0, rel_tol=1e-9): logger.warning( - f"Interleaved dataset normalized weights to sum to 1.0. Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" + f"Interleaved dataset normalized weights to sum to 1.0. " + f"Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" ) @property def dataset_name(self) -> str: return self._dataset_name - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" child_iters = {name: iter(ds) for name, ds in self._datasets.items()} @@ -96,7 +106,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: sample = next(child_iters[ds_name]) yield sample - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Save state for the interleaver and its children.""" # The parent is responsible for namespacing the child states. child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} @@ -105,11 +115,11 @@ def state_dict(self) -> Dict[str, Any]: "child_states": child_states, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state for the interleaver and its children.""" self._sampling_generator.set_state(state_dict["sampling_generator_state"]) child_states = state_dict["child_states"] for name, ds in self._datasets.items(): if name in child_states: # Pass the raw state dict to the child - ds.load_state_dict(child_states[name]) \ No newline at end of file + ds.load_state_dict(child_states[name]) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 725810541c..9dac9ee0b1 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -1,5 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator +from typing import Any, dict, Iterator from torch.utils.data import IterableDataset @@ -19,7 +25,7 @@ def dataset_name(self) -> str: pass @abstractmethod - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """ Returns an infinite iterator over the dataset. Each implementation is responsible for its own iteration logic, including shuffling and making it an infinite stream. @@ -27,11 +33,11 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: pass @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Returns a state dictionary for checkpointing""" pass @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" - pass \ No newline at end of file + pass diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 70bfb75fd5..04f78a9911 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional, Dict +from typing import Any, Callable, Mapping, Optional import numpy as np from datasets import load_dataset @@ -12,11 +12,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages - -from torchtune.modules.transforms import Transform from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset +from torchtune.modules.transforms import Transform + class SFTDataset(Dataset): """ @@ -187,24 +187,24 @@ class SFTOutputTransform(Transform): Output transform to be used in SFT recipes as an input to TuneIterableDataset. It takes tokenized inputs with "tokens" and "mask" keys and creates the "labels" key for SFT training. - + The labels are created by: 1. Shifting tokens by 1 position (for autoregressive training) 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end """ - + def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: # Create a copy to avoid modifying the original tokenized_dict = dict(sample) - + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) raise ValueError( f"SFTOutputTransform expects 'tokens' and 'mask' keys. 
" f"Got keys: {keys_str}" ) - + # Create labels for SFT training tokenized_dict["labels"] = list( np.where( @@ -215,12 +215,12 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: ) tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) - + return tokenized_dict def sft_iterable_dataset( - model_transform: Transform, + model_transform: Transform, *, message_transform: Transform, shuffle_buffer_size: Optional[int] = 1000, @@ -228,26 +228,26 @@ def sft_iterable_dataset( num_shards_per_rank: int = 64, dataset_name: Optional[str] = None, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, - **load_dataset_kwargs: Dict[str, Any], + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ Creates an SFT-ready iterable dataset with appropriate output transform. - + Args: model_transform (Transform): Usually the tokenizer message_transform (Transform): Transform to convert raw data to messages - shuffle_buffer_size (Optional[int]): Buffer size for shuffling + shuffle_buffer_size (Optional[int]): Buffer size for shuffling seed (int): Random seed for shuffling num_shards_per_rank (int): Target shards per worker dataset_name (Optional[str]): Name for metrics namespacing filter_fn (Optional[Callable]): Filter function - filter_kwargs (Optional[Dict[str, Any]]): Filter function kwargs - **load_dataset_kwargs: Args passed to load_dataset - + filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs + **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset + Returns: HfIterableDataset: Configured for SFT training - + Example: >>> from torchtune.data import AlpacaToMessages >>> message_transform = AlpacaToMessages(train_on_input=False) @@ -259,11 +259,11 @@ def sft_iterable_dataset( """ output_transform = SFTOutputTransform() - + return HfIterableDataset( message_transform=message_transform, model_transform=model_transform, - output_transform=output_transform, + output_transform=output_transform, metric_transform=StandardMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, seed=seed, diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index 77667aa579..5a5e9bc94f 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -7,11 +7,10 @@ from typing import Any, Callable, Optional, Union from torchtune.data import ShareGPTToMessages -from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -121,7 +120,7 @@ def slimorca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". column_map (Optional[dict[str, str]]): mapping from expected "conversations" column to actual column name in dataset. If None, uses default "conversations". train_on_input (bool): Whether to train on input or mask it. Default is False. 
@@ -132,13 +131,13 @@ def slimorca_iterable_dataset( dataset_name (Optional[str]): Name for metrics. If None, auto-generated from source. filter_fn (Optional[Callable]): Filter function to apply to dataset. filter_kwargs (Optional[dict[str, Any]]): Kwargs for filter function. - **load_dataset_kwargs: Additional kwargs for load_dataset. + **load_dataset_kwargs (dict[str, Any]): Additional kwargs for load_dataset. Returns: HfIterableDataset: Configured iterable dataset Example: - >>> from torchtune.datasets import slimorca_iterable_dataset + >>> from torchtune.datasets import slimorca_iterable_dataset >>> ds = slimorca_iterable_dataset(shuffle_buffer_size=1000) >>> for sample in ds: >>> print(sample["tokens"][:10]) # First 10 tokens @@ -148,9 +147,9 @@ def slimorca_iterable_dataset( column_map=column_map, new_system_prompt=new_system_prompt, ) - + return sft_iterable_dataset( - source=source, + source=source, message_transform=message_transform, model_transform=model_transform, shuffle_buffer_size=shuffle_buffer_size, diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 5f29e038c6..05fd46e395 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -19,6 +19,7 @@ StateDictOptions, ) from torchtune import config, training, utils +from torchtune.data import MetricsAggregator from torchtune.modules.optim import OptimizerInBackward from torchtune.modules.peft import ( get_adapter_state_dict, @@ -28,7 +29,6 @@ from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer from torchtune.training.checkpointing._utils import get_most_recent_checkpoint from torchtune.training.memory import OptimizerInBackwardWrapper -from torchtune.data import MetricsAggregator log = utils.get_logger("DEBUG") import torchdata @@ -463,7 +463,9 @@ def load_distributed_checkpoint( checkpoint_dict: dict[str, Any] = {} model_state_dict = model.state_dict() optim_state_dict = optimizer.state_dict() - metrics_aggregator_state_dict = metrics_aggregator.state_dict() if metrics_aggregator else {} + metrics_aggregator_state_dict = ( + metrics_aggregator.state_dict() if metrics_aggregator else {} + ) # Hack to properly initialize the learning rate scheduler # TODO: Find a better way to do this, possibly by including the following From 2e51e04f01150a251f1defb49de6991e1d8f8256 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 13:56:58 -0700 Subject: [PATCH 17/48] tests pass --- tests/torchtune/datasets/test_hf_iterable.py | 170 +++--------------- tests/torchtune/datasets/test_interleaved.py | 3 +- .../torchtune/datasets/test_iterable_utils.py | 25 ++- torchtune/datasets/_interleaved.py | 2 +- torchtune/datasets/_iterable_base.py | 2 +- 5 files changed, 53 insertions(+), 149 deletions(-) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index a263258ae8..83144f0ae9 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -6,18 +6,15 @@ from itertools import islice from pathlib import Path -from typing import Any, Optional import pytest + from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import ( - MetricsAggregator, - padded_collate_sft, - StandardMetricTransform, -) +from torchtune.data import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset +from 
.test_iterable_utils import collate_with_metrics, generate_ckpt # Test Constants - Avoid perfect divisions SMALL_DATASET_SIZE = 23 @@ -42,129 +39,10 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None token_len = (i % 3) + 1 tokens = list(range(sample_id, sample_id + token_len)) f.write( - f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' ) -def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" - # Extract metrics first - all_metrics = [] - clean_batch = [] - for sample in batch: - if "metrics" in sample: - all_metrics.extend(sample.pop("metrics")) - clean_batch.append(sample) - - if not clean_batch: - return {"metrics": all_metrics} - - # Use torchtune's padded_collate_sft as base collator - collated_batch = padded_collate_sft(clean_batch) - collated_batch["metrics"] = all_metrics - return collated_batch - - -def generate_ckpt( - dataloader: StatefulDataLoader, - aggregator: MetricsAggregator, - steps_before_checkpoint: int, - steps_after_checkpoint: int, - resume_dataloader: Optional[StatefulDataLoader] = None, - resume_aggregator: Optional[MetricsAggregator] = None, -) -> dict[str, Any]: - """ - Generates a checkpoint by running through data and saving checkpoint mid-stream. - Optionally, a second dataloader and aggregator can be given to resume from ckpt - and run steps_after_checkpoint to match the first one. - - Args: - dataloader (StatefulDataLoader): The dataloader to test - aggregator (MetricsAggregator): The metrics aggregator to use - steps_before_checkpoint (int): Number of steps to run before saving checkpoint - steps_after_checkpoint (int): Number of steps to run after checkpoint - resume_dataloader (Optional[StatefulDataLoader]): Optional new dataloader to test resuming. - If None, returns empty resumed_batches. - resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. - If None, returns empty resumed_metrics. - - Returns: - dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. 
- """ - iterator = iter(dataloader) - - # Collect batches before and after checkpoint - batches = [] - checkpoint_state = None - metrics_at_checkpoint = {} - - total_steps = steps_before_checkpoint + steps_after_checkpoint - - for idx, batch in enumerate(iterator): - batches.append(batch) - - # Process metrics - if "metrics" in batch: - aggregator.update(batch.pop("metrics")) - - # Save checkpoint state after steps_before_checkpoint - if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based - checkpoint_state = { - "loader": dataloader.state_dict(), - "aggregator": aggregator.state_dict(), - } - metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") - - # Stop after total steps - if idx == total_steps - 1: - break - - # Split batches - pre_checkpoint_batches = batches[:steps_before_checkpoint] - post_checkpoint_batches = batches[steps_before_checkpoint:] - - # Resume with new instances if provided - resumed_batches = [] - resumed_metrics = {} - - if ( - resume_dataloader is not None - and resume_aggregator is not None - and checkpoint_state is not None - ): - # Test resuming with new instances - resume_dataloader.load_state_dict(checkpoint_state["loader"]) - resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) - resume_iterator = iter(resume_dataloader) - - # Collect only the post-checkpoint batches when resuming - for idx, batch in enumerate(resume_iterator): - resumed_batches.append(batch) - - # Process metrics - if "metrics" in batch: - resume_aggregator.update(batch.pop("metrics")) - - # Stop after steps_after_checkpoint - if idx == steps_after_checkpoint - 1: - break - - resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") - - return { - # Original run - "pre_checkpoint_batches": pre_checkpoint_batches, - "post_checkpoint_batches": post_checkpoint_batches, - "metrics_at_checkpoint": metrics_at_checkpoint, - "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), - # Resumed run - "resumed_batches": resumed_batches, - "resumed_metrics": resumed_metrics, - # Internal state for loading - only if someone needs to manually load - "_checkpoint_state": checkpoint_state, - } - - @pytest.fixture def tmp_data_dir(tmp_path): """Provide temporary directory for test data files.""" @@ -292,14 +170,16 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Get samples from two passes through the dataset - epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # Unshuffled should have same order in both epochs - assert first_epoch_samples == list(range(SMALL_DATASET_SIZE)) - assert second_epoch_samples == list(range(SMALL_DATASET_SIZE)) + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + assert first_epoch_ids == list(range(SMALL_DATASET_SIZE)) + assert second_epoch_ids == list(range(SMALL_DATASET_SIZE)) # Test shuffled dataset shuffled_ds = dataset_factory( @@ -307,48 +187,56 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Collect full epochs to compare - epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = 
epoch_samples[SMALL_DATASET_SIZE:] # Shuffled epochs should have different order - assert first_epoch_samples != list( + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + assert first_epoch_ids != list( range(SMALL_DATASET_SIZE) - ), f"Shuffled should not be sorted, got {first_epoch_samples}" + ), f"Shuffled should not be sorted, got {first_epoch_ids}" assert ( - first_epoch_samples != second_epoch_samples - ), f"Shuffled epochs should be shuffled differently, got {first_epoch_samples} and {second_epoch_samples}" + first_epoch_ids != second_epoch_ids + ), f"Shuffled epochs should be shuffled differently, got {first_epoch_ids} and {second_epoch_ids}" # But should contain the same set of IDs - assert set(first_epoch_samples) == set( + assert set(first_epoch_ids) == set( range(SMALL_DATASET_SIZE) - ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" - assert set(second_epoch_samples) == set( + ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}" + assert set(second_epoch_ids) == set( range(SMALL_DATASET_SIZE) - ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" + ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}" def test_epoch_tracking(self, dataset_factory, small_dataset_file): """Test that epoch number is correctly tracked across dataset restarts.""" dataset = dataset_factory(small_dataset_file, shuffle=False) # Two epoch samples - epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(dataset), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # All should have epoch 0 + first_epoch_metrics = [] + for sample in first_epoch_samples: + first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - epoch_metric.value for epoch_metric in first_epoch_samples["metrics"] + metric.value for metric in first_epoch_metrics if metric.name == "epoch" ] assert all( epoch_value == 0 for epoch_value in epoch_values ), f"Epoch values should be 0, got {epoch_values}" # All should have epoch 1 + second_epoch_metrics = [] + for sample in second_epoch_samples: + second_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - epoch_metric.value for epoch_metric in second_epoch_samples["metrics"] + metric.value for metric in second_epoch_metrics if metric.name == "epoch" ] assert all( epoch_value == 1 for epoch_value in epoch_values diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index e06ef670c1..96139fd868 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest + import torch from torchdata.stateful_dataloader import StatefulDataLoader @@ -40,7 +41,7 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None token_len = (i % 3) + 1 tokens = list(range(sample_id, sample_id + token_len)) f.write( - f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' ) diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index 5a4fbda8a5..e160345bc1 100644 --- 
a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -6,13 +6,14 @@ from typing import Any, Optional +import torch + from torch.utils.data import DataLoader -from torchtune.data import padded_collate_sft -from torchtune.data._metrics import MetricsAggregator +from torchtune.data import MetricsAggregator def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" + """Simple collate that extracts metrics and pads tokens.""" all_metrics = [] clean_batch = [] for sample in batch: @@ -23,8 +24,22 @@ def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: if not clean_batch: return {"metrics": all_metrics} - # Use torchtune's standard SFT collate function - collated = padded_collate_sft(clean_batch) + # Simple padding for tokens + ids = torch.tensor([item["id"] for item in clean_batch]) + tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(item["tokens"]) for item in clean_batch], + batch_first=True, + padding_value=-1, # Use -1 for padding to distinguish from valid IDs + ) + collated = { + "id": ids, + "tokens": tokens, + } + + # Add text field for non-tensor data + if "text" in clean_batch[0]: + collated["text"] = [item["text"] for item in clean_batch] + collated["metrics"] = all_metrics return collated diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 53185993dd..0245d4e94e 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -7,7 +7,7 @@ import collections import logging import math -from typing import Any, dict, Iterator +from typing import Any, Iterator import torch diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 9dac9ee0b1..f0821dc3f1 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from abc import ABC, abstractmethod -from typing import Any, dict, Iterator +from typing import Any, Iterator from torch.utils.data import IterableDataset From 93fa7436aa4ad294b6aea813c8d69528494e1d5c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 07:16:11 -0700 Subject: [PATCH 18/48] it works --- recipes/configs/llama3_2/3B_full.yaml | 10 ++-- recipes/full_finetune_distributed.py | 56 +++++++++---------- .../torchtune/data/test_metrics_aggregator.py | 46 +++++++-------- tests/torchtune/datasets/test_interleaved.py | 16 +++--- torchtune/data/_aggregator.py | 19 ++++--- torchtune/datasets/_alpaca.py | 6 +- torchtune/datasets/_sft.py | 3 + torchtune/datasets/_slimorca.py | 10 +++- 8 files changed, 86 insertions(+), 80 deletions(-) diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 5534b305ac..f825e9194e 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -28,7 +28,7 @@ tokenizer: # Dataloader dataloader: - batch_size: 4 + batch_size: 16 # num_workers and pin_memory can be added here if needed # Dataset - now a list to support multiple weighted sources @@ -36,16 +36,18 @@ dataset: - _component_: torchtune.datasets.slimorca_iterable_dataset shuffle_buffer_size: 1000 weight: 0.8 + split: train[:5%] # simular 1 epoch quickly - _component_: torchtune.datasets.alpaca_iterable_dataset shuffle_buffer_size: 1000 weight: 0.2 + split: train[:5%] # simular 1 epoch quickly # Packing (TBD by follow up PR) # packing: # _component_: torchtune.datasets.packing.SFTPacking # max_seq_len: 8192 -seed: null +seed: 42 # Validation not supported yet with iterable datasets @@ -76,7 +78,7 @@ loss: num_training_steps: 100 # Total number of training steps to run save_every_n_steps: 200 # Save a checkpoint every N steps. Using 200 to avoid ckpt. gradient_accumulation_steps: 1 -dataset_metrics_log_freq: 10 # Log dataset-specific metrics every N steps +dataset_metrics_log_freq: 5 # Log dataset-specific metrics every N steps # Environment device: cuda @@ -91,7 +93,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.WandBLogger log_dir: ${output_dir}/logs log_every_n_steps: 1 log_peak_memory_stats: True diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index a4eb87e7d6..5b65a6e9f5 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -13,7 +13,7 @@ from warnings import warn import torch -from omegaconf import dictConfig, listConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -119,7 +119,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): has example commands for how to kick-off training. Args: - cfg (dictConfig): OmegaConf object parsed from yaml file + cfg (DictConfig): OmegaConf object parsed from yaml file Raises: ValueError: If ``dtype`` is set to fp16. @@ -129,7 +129,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. 
""" - def __init__(self, cfg: dictConfig) -> None: + def __init__(self, cfg: DictConfig) -> None: device_type = cfg.device self._device = utils.get_device(device=device_type) self._dtype = training.get_dtype(cfg.dtype, device=self._device) @@ -303,7 +303,7 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: "Are you sure you passed in the right recipe checkpoint?" ) from e - def setup(self, cfg: dictConfig) -> None: + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. @@ -334,7 +334,7 @@ def setup(self, cfg: dictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool - if isinstance(compile, dictConfig): + if isinstance(compile, DictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) @@ -431,7 +431,7 @@ def setup(self, cfg: dictConfig) -> None: collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, - batch_size=cfg.batch_size, + cfg_dataloader=cfg.dataloader, collate_fn=collate_name, dataloader_state_dict=( state_dict[training.DATALOADER_KEY] @@ -443,10 +443,9 @@ def setup(self, cfg: dictConfig) -> None: # Setup validation dataloader if validation dataset is provided self._val_dataloader = None if cfg.get("dataset_val") is not None: - batch_size_val = cfg.get("batch_size_val", cfg.batch_size) self._val_dataloader = self._setup_data( cfg_dataset=cfg.dataset_val, - batch_size=batch_size_val, + cfg_dataloader=cfg.get("dataloader_val", None), collate_fn=collate_name, dataloader_state_dict=( state_dict[training.VAL_DATALOADER_KEY] @@ -471,7 +470,7 @@ def setup(self, cfg: dictConfig) -> None: def _setup_lr_scheduler( self, - cfg_lr_scheduler: Optional[dictConfig], + cfg_lr_scheduler: Optional[DictConfig], num_training_steps: int, last_epoch: int, ) -> Optional[Optimizer]: @@ -480,7 +479,7 @@ def _setup_lr_scheduler( It supports both standard optimization and optimizer-in-backward cases. Args: - cfg_lr_scheduler (Optional[dictConfig]): The learning rate scheduler configuration. + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. num_training_steps (int): The total number of training steps. last_epoch (int): The index of the last epoch. 
@@ -519,14 +518,14 @@ def _setup_lr_scheduler( return lr_scheduler def _setup_profiler( - self, cfg_profiler: Optional[dictConfig] = None + self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler """ # Missing profiler section in config, assume disabled if cfg_profiler is None: - cfg_profiler = dictConfig({"enabled": False}) + cfg_profiler = DictConfig({"enabled": False}) # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: @@ -553,7 +552,7 @@ def _setup_profiler( def _setup_model( self, - cfg_model: dictConfig, + cfg_model: DictConfig, enable_activation_checkpointing: bool, enable_activation_offloading: bool, activation_offloading_use_streams: bool, @@ -711,7 +710,7 @@ def _setup_model( def _setup_optimizer( self, - cfg_optimizer: dictConfig, + cfg_optimizer: DictConfig, optimizer_in_bwd: bool = False, opt_state_dict: Optional[dict[str, Any]] = None, ) -> Optional[Optimizer]: @@ -764,8 +763,8 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: Union[dictConfig, listConfig], - batch_size: int, + cfg_dataset: Union[DictConfig, ListConfig], + cfg_dataloader: DictConfig, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, ) -> StatefulDataLoader: @@ -777,7 +776,7 @@ def _setup_data( iterable_datasets = [] weights = [] cfg_dataset_list = cfg_dataset - if not isinstance(cfg_dataset_list, listConfig): + if not isinstance(cfg_dataset_list, ListConfig): cfg_dataset_list = [cfg_dataset_list] for ds_cfg in cfg_dataset_list: @@ -823,8 +822,8 @@ def _collate_with_metrics_wrapper( # 5. Create DataLoader dataloader = StatefulDataLoader( dataset=ds, - batch_size=batch_size, collate_fn=_collate_with_metrics_wrapper, + **cfg_dataloader, ) if dataloader_state_dict is not None: @@ -898,7 +897,7 @@ def validate(self) -> dict[str, float]: self._model.train() return log_dict - def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): + def save_checkpoint(self, *, epoch: int, full_tensors: bool): """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, @@ -924,7 +923,6 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), epoch=epoch, # TODO: not needed. To be deprecated. - step=step, single_device=False, full_tensors=full_tensors, dir_prefix=self.checkpoint_dir_prefix, @@ -1074,23 +1072,21 @@ def train(self) -> None: # Log dataset metrics # #TODO: it requires all_gather. Should we keep a separate log_freq for this? 
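                # Editor's note (illustrative, not part of the original patch):
                # get_metrics_for_logging() performs a collective all_gather across
                # ranks, so every rank must call it; only rank zero should write to
                # the logger. Logged keys follow the aggregator's
                # "{prefix}_{ds_name}/{metric_name}" template, e.g. with a
                # hypothetical dataset name and values:
                #   {"train_alpaca/samples_seen": 512, "train_alpaca/tokens_seen": 98304}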
- if ( - self.global_step % self._dataset_metrics_log_freq == 0 - and self._is_rank_zero - ): + if self.global_step % self._dataset_metrics_log_freq == 0: dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) - self._metric_logger.log_dict(dataset_metrics, step=self.global_step) + if self._is_rank_zero: + self._metric_logger.log_dict( + dataset_metrics, step=self.global_step + ) # Save checkpoint if specified by user if ( self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0 ): - self.save_checkpoint( - epoch=0, step=self.global_step, full_tensors=False - ) + self.save_checkpoint(epoch=0, full_tensors=False) # Reset running stats for the next step running_loss = 0 @@ -1121,7 +1117,7 @@ def train(self) -> None: self.validate() self._profiler.stop() - self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=True) + self.save_checkpoint(epoch=0, full_tensors=True) def cleanup(self) -> None: if self._is_rank_zero: @@ -1130,7 +1126,7 @@ def cleanup(self) -> None: @config.parse -def recipe_main(cfg: dictConfig) -> None: +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 69a2f32967..a9fda513a4 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -36,13 +36,13 @@ def test_aggregation_types(self, agg_type, test_values, expected): ] aggregator.update(metrics) - result = aggregator.get_metrics_for_logging() + result = aggregator.get_metrics_for_logging(prefix="train") if agg_type == AggregationType.CATEGORICAL_COUNT: for category, count in expected.items(): - assert result[f"test/metric_{category}_count"] == count + assert result[f"train_test/metric_{category}_count"] == count else: - assert result["test/metric"] == expected + assert result["train_test/metric"] == expected def test_distribution_metrics(self): """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" @@ -58,11 +58,11 @@ def test_distribution_metrics(self): result = aggregator.get_metrics_for_logging(prefix="train") # Verify distribution statistics - assert result["train/test/dist_metric_mean"] == 5.5 - assert result["train/test/dist_metric_min"] == 1 - assert result["train/test/dist_metric_max"] == 10 + assert result["train_test/dist_metric_mean"] == 5.5 + assert result["train_test/dist_metric_min"] == 1 + assert result["train_test/dist_metric_max"] == 10 assert ( - result["train/test/dist_metric_p50"] == 5 + result["train_test/dist_metric_p50"] == 5 ) # Median of 1-10 is 5 (index 4, value 5) def test_state_management(self): @@ -84,8 +84,8 @@ def test_state_management(self): aggregator2.load_state_dict(state) # Both should have identical metrics - metrics1 = aggregator1.get_metrics_for_logging() - metrics2 = aggregator2.get_metrics_for_logging() + metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + metrics2 = aggregator2.get_metrics_for_logging(prefix="train") assert metrics1 == metrics2 # Continue updating both - should remain identical @@ -96,13 +96,13 @@ def test_state_management(self): aggregator1.update(additional_metrics) aggregator2.update(additional_metrics) - final_metrics1 = aggregator1.get_metrics_for_logging() - final_metrics2 = aggregator2.get_metrics_for_logging() + final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + final_metrics2 = 
aggregator2.get_metrics_for_logging(prefix="train") assert final_metrics1 == final_metrics2 # Verify expected values - assert final_metrics1["ds1/counter"] == 15 # 10 + 5 - assert final_metrics1["ds1/average"] == 10.0 # (5 + 15) / 2 + assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2 def test_multiple_datasets(self): """Test that metrics from multiple datasets are correctly namespaced.""" @@ -118,15 +118,15 @@ def test_multiple_datasets(self): result = aggregator.get_metrics_for_logging(prefix="train") - assert result["train/dataset1/samples"] == 100 - assert result["train/dataset2/samples"] == 200 - assert result["train/dataset1/tokens"] == 1000 - assert result["train/dataset2/tokens"] == 2000 + assert result["train_dataset1/samples"] == 100 + assert result["train_dataset2/samples"] == 200 + assert result["train_dataset1/tokens"] == 1000 + assert result["train_dataset2/tokens"] == 2000 def test_empty_aggregator(self): """Test that empty aggregator returns empty metrics.""" aggregator = MetricsAggregator() - result = aggregator.get_metrics_for_logging() + result = aggregator.get_metrics_for_logging(prefix="train") assert result == {} def test_prefix_handling(self): @@ -140,10 +140,10 @@ def test_prefix_handling(self): # Test with prefix result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") - assert result_with_prefix["validation/test_ds/metric1"] == 42 - assert result_with_prefix["validation/test_ds/metric2"] == 84 + assert result_with_prefix["validation_test_ds/metric1"] == 42 + assert result_with_prefix["validation_test_ds/metric2"] == 84 - # Test without prefix + # Test without prefix (uses default "data") result_no_prefix = aggregator.get_metrics_for_logging() - assert result_no_prefix["test_ds/metric1"] == 42 - assert result_no_prefix["test_ds/metric2"] == 84 + assert result_no_prefix["data_test_ds/metric1"] == 42 + assert result_no_prefix["data_test_ds/metric2"] == 84 diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 96139fd868..e1ded110ac 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -168,25 +168,25 @@ def test_metrics_aggregation( for sample in islice(iter(interleaved), total_samples): aggregator.update(sample["metrics"]) - metrics = aggregator.get_metrics_for_logging() + metrics = aggregator.get_metrics_for_logging(prefix="train") # Should have metrics from both datasets, with flat keys - assert "ds1/samples_seen" in metrics - assert "ds2/samples_seen" in metrics + assert "train_ds1/samples_seen" in metrics + assert "train_ds2/samples_seen" in metrics # Both datasets should have contributed samples - assert metrics["ds1/samples_seen"] > 0 - assert metrics["ds2/samples_seen"] > 0 + assert metrics["train_ds1/samples_seen"] > 0 + assert metrics["train_ds2/samples_seen"] > 0 # Total samples should equal what we processed calculated_total_samples = ( - metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] + metrics["train_ds1/samples_seen"] + metrics["train_ds2/samples_seen"] ) assert calculated_total_samples == total_samples # Test that ratio is approximately correct - ds1_ratio = metrics["ds1/samples_seen"] / total_samples - ds2_ratio = metrics["ds2/samples_seen"] / total_samples + ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples + ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples # Allow 10% tolerance due to randomness assert 
abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index 9c933ddfa3..313c5bdc73 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -92,7 +92,7 @@ def _initialize_state( elif agg_type == AggregationType.CATEGORICAL_COUNT: state["counts"] = collections.Counter() - def get_metrics_for_logging(self, prefix: str = "") -> dict[str, float]: + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: """ Returns aggregated metrics ready for logging to wandb/tensorboard. @@ -237,7 +237,6 @@ def _compute_distributed_metrics( world_size = dist.get_world_size() # Gather all metrics from all ranks in one operation - dist.barrier() all_metrics = [None] * world_size dist.all_gather_object(all_metrics, local_metrics) @@ -276,7 +275,10 @@ def _compute_distributed_metrics( return reduced def _format_for_logging( - self, metrics: dict[tuple[str, str], dict[str, Any]], prefix: str + self, + metrics: dict[tuple[str, str], dict[str, Any]], + prefix: str, + template: str = r"{prefix}_{ds_name}/{metric_name}", ) -> dict[str, float]: """ Format metrics for wandb/tensorboard logging. @@ -285,6 +287,7 @@ def _format_for_logging( metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} prefix (str): Optional prefix like "train" or "valid" + template (str): Template for metric key. Use {prefix}, {ds_name}, and {metric_name} as placeholders. Returns: dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float @@ -292,12 +295,10 @@ def _format_for_logging( formatted = {} for (ds_name, metric_name), metric_dict in metrics.items(): - # Build key: "prefix/dataset/metric" or "dataset/metric" if no prefix - if prefix: - key = f"{prefix}/{ds_name}/{metric_name}" - else: - key = f"{ds_name}/{metric_name}" - + # Use regex format to build key + key = template.format( + prefix=prefix, ds_name=ds_name, metric_name=metric_name + ) formatted[key] = metric_dict["value"] return formatted diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index bae7613729..4326b7024f 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -107,7 +107,7 @@ def alpaca_dataset( def alpaca_iterable_dataset( model_transform: ModelTokenizer, *, - source: str = "tatsu-lab/alpaca", + path: str = "tatsu-lab/alpaca", column_map: Optional[dict[str, str]] = None, train_on_input: bool = True, shuffle_buffer_size: Optional[int] = 1000, @@ -125,7 +125,7 @@ def alpaca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. + path (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. column_map (Optional[dict[str, str]]): a mapping from the expected columns in the message transform :class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be "instruction", "input", and "output" and values should be the actual column names. 
@@ -160,6 +160,6 @@ def alpaca_iterable_dataset( dataset_name=dataset_name, filter_fn=filter_fn, split=split, - path=source, + path=path, **load_dataset_kwargs, ) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 04f78a9911..6dabee9bb6 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -222,6 +222,7 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: def sft_iterable_dataset( model_transform: Transform, *, + weight: int = 1, message_transform: Transform, shuffle_buffer_size: Optional[int] = 1000, seed: int = 42, @@ -236,6 +237,7 @@ def sft_iterable_dataset( Args: model_transform (Transform): Usually the tokenizer + weight (int): Weight of the dataset. Used for sampling when interleaving datasets. message_transform (Transform): Transform to convert raw data to messages shuffle_buffer_size (Optional[int]): Buffer size for shuffling seed (int): Random seed for shuffling @@ -266,6 +268,7 @@ def sft_iterable_dataset( output_transform=output_transform, metric_transform=StandardMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, + weight=weight, seed=seed, num_shards_per_rank=num_shards_per_rank, dataset_name=dataset_name, diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index 5a5e9bc94f..0346e5b73a 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -100,7 +100,8 @@ def slimorca_dataset( def slimorca_iterable_dataset( model_transform: ModelTokenizer, *, - source: str = "Open-Orca/SlimOrca-Dedup", + path: str = "Open-Orca/SlimOrca-Dedup", + split: str = "train", column_map: Optional[dict[str, str]] = None, train_on_input: bool = False, new_system_prompt: Optional[str] = None, @@ -120,7 +121,9 @@ def slimorca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + path (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a + subset of a given split, e.g. ``split="train[:10%]"``. Default is "train column_map (Optional[dict[str, str]]): mapping from expected "conversations" column to actual column name in dataset. If None, uses default "conversations". train_on_input (bool): Whether to train on input or mask it. Default is False. @@ -149,7 +152,8 @@ def slimorca_iterable_dataset( ) return sft_iterable_dataset( - source=source, + path=path, + split=split, message_transform=message_transform, model_transform=model_transform, shuffle_buffer_size=shuffle_buffer_size, From aa9e6f417bd82af6fa78cbd8b1e2233751ab6981 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 07:46:10 -0700 Subject: [PATCH 19/48] remove code --- torchtune/datasets/_hf_iterable.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 0be4c0cc53..4b7c04d04c 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -201,12 +201,11 @@ def __iter__(self) -> Iterator[dict[str, Any]]: An additional metric "num_epochs" is added to the sample. 
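        Illustrative usage (added for clarity; ``ds`` stands in for any
        constructed ``HfIterableDataset``) -- since the stream is infinite,
        the consumer bounds the iteration itself, as the tests in this
        patch do with ``islice``:

            >>> from itertools import islice
            >>> samples = list(islice(iter(ds), 8))  # take 8 samples from the infinite stream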
""" - epoch_ds = self._ds while True: # Infinite iteration epoch_seed = self._seed + self._num_epochs - epoch_ds.set_epoch(epoch_seed) - epoch_iterator = iter(epoch_ds) + self._ds.set_epoch(epoch_seed) + epoch_iterator = iter(self._ds) samples_yielded = 0 try: @@ -249,9 +248,6 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Epoch complete - increment and continue infinite loop self._num_epochs += 1 - # Reset to the base dataset for the next epoch's shuffling. - epoch_ds = self._ds - def state_dict(self) -> dict[str, Any]: """ The dataset returns its own state directly, without namespacing. From 55be7756e0fd03b493dde46691925825f5cb3948 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 13:55:38 -0700 Subject: [PATCH 20/48] adjust pack to have metrics --- torchtune/datasets/_iterable_packed.py | 81 +++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 9 deletions(-) diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 173ba5a908..a788da8db3 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -15,6 +15,7 @@ ) from torchdata.stateful_dataloader import Stateful from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.data._metrics import AggregationType, Metric from torchtune.datasets import TuneIterableDataset from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) SampleType = TypeVar("SampleType") -PackType = dict[str, torch.Tensor] +PackType = dict[str, torch.Tensor | list[Metric]] class PackingStrategy(ABC, Generic[SampleType]): @@ -39,6 +40,16 @@ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX) self.padding_idx = padding_idx self.ignore_idx = ignore_idx + @abstractmethod + def set_dataset_name(self, dataset_name: str) -> None: + """ + Sets the dataset name on the strategy. + + Args: + dataset_name (str): The name of the dataset. + """ + pass + @abstractmethod def create_empty_pack(self) -> dict[str, list[Any]]: """ @@ -188,6 +199,7 @@ class IterablePackedDataset( strategy (PackingStrategy[SampleType]): The PackingStrategy to use for packing. target_tokens_per_pack (int): The target number of tokens per pack. buffer_size (int): The size of the buffer to use for packing. + dataset_name (str): The name of the dataset. If None, a defaults to IterablePackedDataset. 
""" def __init__( @@ -196,14 +208,23 @@ def __init__( strategy: PackingStrategy[SampleType], target_tokens_per_pack: int, buffer_size: int = 50, + dataset_name: str = "IterablePackedDataset", ): self.dataset = dataset self.strategy = strategy self.target_tokens_per_pack = target_tokens_per_pack self.buffer_size = buffer_size + self._dataset_name = dataset_name + + # Set dataset name on strategy if it supports it + self.strategy.set_dataset_name(dataset_name) self._reset_packer_state() + @property + def dataset_name(self) -> str: + return self._dataset_name + def _reset_packer_state(self) -> None: """Resets the packer's internal state for a new or resumed iteration.""" # buffer: deque of (sample, size) tuples that have not been added to a pack yet @@ -400,20 +421,26 @@ class TextPackingStrategy(PackingStrategy[dict[str, list[int]]]): def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + self.dataset_name = "packed_dataset" # Default name + + def set_dataset_name(self, dataset_name: str) -> None: + """Set the dataset name for metrics.""" + self.dataset_name = dataset_name - def create_empty_pack(self) -> dict[str, list[int]]: + def create_empty_pack(self) -> dict[str, list]: return { "tokens": [], "labels": [], "document_ids": [], "input_pos": [], + "metrics": [], } def get_sample_size(self, sample: dict[str, list[int]]) -> int: return len(sample["tokens"]) def add_sample_to_pack( - self, pack: dict[str, list[int]], sample: dict[str, list[int]], next_doc_id: int + self, pack: dict[str, list], sample: dict[str, list[int]], next_doc_id: int ) -> int: seq_len = len(sample["tokens"]) @@ -423,11 +450,15 @@ def add_sample_to_pack( pack["document_ids"].extend([next_doc_id] * seq_len) pack["input_pos"].extend(range(seq_len)) # input_pos restarts for each doc + # Handle metrics if they exist in the sample + if "metrics" in sample: + pack["metrics"].extend(sample["metrics"]) + # Increment doc ID for the next sample return 1 def finalize_pack( - self, pack: dict[str, list[int]], target_tokens_per_pack: int, next_doc_id: int + self, pack: dict[str, list], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: current_size = len(pack["tokens"]) num_padding = target_tokens_per_pack - current_size @@ -438,13 +469,25 @@ def finalize_pack( pack["input_pos"].extend([0] * num_padding) pack["document_ids"].extend([next_doc_id] * num_padding) - return { + # Add pct_of_tokens_padded metric + padding_metric = Metric( + dataset_name=self.dataset_name, + name="pct_of_tokens_padded", + value=round(num_padding * 100 / len(pack["tokens"]), 2), + agg_type=AggregationType.MEAN, + ) + pack["metrics"].append(padding_metric) + + result = { "tokens": torch.tensor(pack["tokens"], dtype=torch.long), "labels": torch.tensor(pack["labels"], dtype=torch.long), "document_ids": torch.tensor(pack["document_ids"], dtype=torch.long), "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), + "metrics": pack["metrics"], } + return result + def _mask_mod( self, b: int, @@ -483,8 +526,13 @@ class DPOPackingStrategy(PackingStrategy[dict[str, list[int]]]): def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + self.dataset_name = "packed_dataset" # Default name - def create_empty_pack(self) -> dict[str, list[int]]: + def set_dataset_name(self, dataset_name: str) -> None: + """Set the dataset name for metrics.""" + self.dataset_name = 
dataset_name + + def create_empty_pack(self) -> dict[str, list]: return { "tokens": [], "labels": [], @@ -492,6 +540,7 @@ def create_empty_pack(self) -> dict[str, list[int]]: "input_pos": [], "chosen_response_mask": [], "rejected_response_mask": [], + "metrics": [], } def get_sample_size(self, sample: dict[str, list[int]]) -> int: @@ -503,7 +552,7 @@ def get_sample_size(self, sample: dict[str, list[int]]) -> int: ) def add_sample_to_pack( - self, pack: dict[str, list[int]], sample: dict[str, list[int]], next_doc_id: int + self, pack: dict[str, list], sample: dict[str, list[int]], next_doc_id: int ) -> int: # Assign a unique doc ID triplet for (prompt, chosen, rejected) prompt_doc_id = next_doc_id @@ -539,12 +588,16 @@ def add_sample_to_pack( pack["chosen_response_mask"].extend([False] * len(rejected_ids)) pack["rejected_response_mask"].extend([True] * len(rejected_ids)) + # Handle metrics if they exist in the sample + if "metrics" in sample: + pack["metrics"].extend(sample["metrics"]) + # Advance the document ID counter by 3 for the next DPO sample. return 3 def finalize_pack( - self, pack: dict[str, list[int]], target_tokens_per_pack: int, next_doc_id: int - ) -> dict[str, torch.Tensor]: + self, pack: dict[str, list], target_tokens_per_pack: int, next_doc_id: int + ) -> PackType: current_size = len(pack["tokens"]) num_padding = target_tokens_per_pack - current_size @@ -556,6 +609,15 @@ def finalize_pack( pack["rejected_response_mask"].extend([False] * num_padding) pack["document_ids"].extend([next_doc_id] * num_padding) + # Add pct_of_tokens_padded metric + padding_metric = Metric( + dataset_name=self.dataset_name, + name="pct_of_tokens_padded", + value=round(num_padding * 100 / len(pack["tokens"]), 2), + agg_type=AggregationType.MEAN, + ) + pack["metrics"].append(padding_metric) + return { "tokens": torch.tensor(pack["tokens"], dtype=torch.long), "labels": torch.tensor(pack["labels"], dtype=torch.long), @@ -567,6 +629,7 @@ def finalize_pack( "rejected_response_mask": torch.tensor( pack["rejected_response_mask"], dtype=torch.bool ), + "metrics": pack["metrics"], } def _mask_mod( From 382c4e9bd6038e6b54b74e18645932ac52fe5089 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 14:18:04 -0700 Subject: [PATCH 21/48] remove comment --- torchtune/datasets/_iterable_packed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index a788da8db3..7687673ad6 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -216,7 +216,7 @@ def __init__( self.buffer_size = buffer_size self._dataset_name = dataset_name - # Set dataset name on strategy if it supports it + # Set dataset name on the strategy self.strategy.set_dataset_name(dataset_name) self._reset_packer_state() From 5b188ed9c66c80796679a2f8161b3353b8226380 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 1 Jul 2025 21:18:04 -0400 Subject: [PATCH 22/48] update metrics to use handlers --- recipes/full_finetune_distributed.py | 3 +- .../torchtune/data/test_metrics_aggregator.py | 2 +- .../torchtune/data/test_metrics_transform.py | 10 +- tests/torchtune/datasets/test_hf_iterable.py | 2 +- tests/torchtune/datasets/test_interleaved.py | 2 +- torchtune/data/__init__.py | 12 - torchtune/data/_aggregator.py | 5 +- torchtune/data/_metrics.py | 98 ---- torchtune/data/metrics/__init__.py | 39 ++ .../data/metrics/_metric_agg_handlers.py | 433 ++++++++++++++++++ 
torchtune/data/metrics/_metric_aggregator.py | 271 +++++++++++ torchtune/data/metrics/_metric_transform.py | 124 +++++ torchtune/data/metrics/readme.md | 176 +++++++ torchtune/datasets/_hf_iterable.py | 2 +- torchtune/datasets/_sft.py | 2 +- .../checkpointing/_checkpoint_client.py | 2 +- 16 files changed, 1059 insertions(+), 124 deletions(-) delete mode 100644 torchtune/data/_metrics.py create mode 100644 torchtune/data/metrics/__init__.py create mode 100644 torchtune/data/metrics/_metric_agg_handlers.py create mode 100644 torchtune/data/metrics/_metric_aggregator.py create mode 100644 torchtune/data/metrics/_metric_transform.py create mode 100644 torchtune/data/metrics/readme.md diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 5b65a6e9f5..4c41a81a5b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -24,7 +24,8 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import MetricsAggregator, padded_collate_packed +from torchtune.data import padded_collate_packed +from torchtune.data.metrics import MetricsAggregator from torchtune.datasets import InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index a9fda513a4..0691d9c32d 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -6,7 +6,7 @@ import pytest -from torchtune.data import AggregationType, Metric, MetricsAggregator +from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator class TestMetricsAggregator: diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 2c8f3023b2..8a4a86d7dd 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -6,15 +6,15 @@ import pytest -from torchtune.data import AggregationType, StandardMetricTransform +from torchtune.data.metrics import AggregationType, DefaultTrainingMetricTransform -class TestStandardMetricTransform: - """Tests for StandardMetricTransform functionality.""" +class TestDefaultTrainingMetricTransform: + """Tests for DefaultTrainingMetricTransform functionality.""" def test_dataset_name_not_set_raises_error(self): """Test that using transform without setting dataset name raises error.""" - transform = StandardMetricTransform() + transform = DefaultTrainingMetricTransform() sample = {"tokens": [1, 2, 3]} with pytest.raises(RuntimeError, match="set_dataset_name"): @@ -22,7 +22,7 @@ def test_dataset_name_not_set_raises_error(self): def test_basic_metrics_generation(self): """Test that transform generates expected metrics for a sample.""" - transform = StandardMetricTransform() + transform = DefaultTrainingMetricTransform() transform.set_dataset_name("test_dataset") sample = {"tokens": [1, 2, 3, 4, 5]} diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 83144f0ae9..067fdc1294 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -11,7 +11,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import MetricsAggregator, 
StandardMetricTransform +from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset from .test_iterable_utils import collate_with_metrics, generate_ckpt diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index e1ded110ac..d8afcd2263 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -13,7 +13,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 09292b9ba9..a75e16780a 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data._aggregator import MetricsAggregator from torchtune.data._collate import ( left_pad_sequence, padded_collate, @@ -25,12 +24,6 @@ ShareGPTToMessages, validate_messages, ) -from torchtune.data._metrics import ( - AggregationType, - Metric, - MetricTransform, - StandardMetricTransform, -) from torchtune.data._prompt_templates import ( ChatMLTemplate, GrammarErrorCorrectionTemplate, @@ -42,13 +35,8 @@ from torchtune.data._utils import format_content_with_images, load_image, truncate __all__ = [ - "AggregationType", "CROSS_ENTROPY_IGNORE_IDX", "GrammarErrorCorrectionTemplate", - "Metric", - "MetricsAggregator", - "MetricTransform", - "StandardMetricTransform", "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index 313c5bdc73..66162f826b 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -11,7 +11,7 @@ import torch.distributed as dist -from torchtune.data._metrics import AggregationType, Metric +from torchtune.data.metrics import AggregationType, Metric logger = logging.getLogger(__name__) @@ -159,7 +159,8 @@ def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: n = len(sorted_values) # Each stat becomes its own metric - # For percentiles, it is an approximattion by computing avg of averages + # so that we can all gather O(5) values across ranks + # instead of the entire distribution metrics[(ds_name, f"{metric_name}_mean")] = { "value": sum(values) / n, "agg_type": AggregationType.MEAN, diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py deleted file mode 100644 index 7a38febb1e..0000000000 --- a/torchtune/data/_metrics.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from dataclasses import dataclass -from enum import Enum -from functools import partial -from typing import Any, Callable, Optional, Protocol, Union - - -class AggregationType(Enum): - """Defines how a metric's value should be aggregated.""" - - SUM = "sum" - MEAN = "mean" - DISTRIBUTION = "distribution" - CATEGORICAL_COUNT = "categorical_count" - MAX = "max" - MIN = "min" - - -@dataclass(frozen=True) -class Metric: - """A self-describing metric object.""" - - dataset_name: str - name: str - value: Union[int, float, str] - agg_type: AggregationType - - -class MetricTransform(Protocol): - """Protocol for metric transforms.""" - - def set_dataset_name(self, dataset_name: str) -> None: - ... - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - ... - - -class StandardMetricTransform(MetricTransform): - """ - Attaches per-sample metrics for tracking training progress. - - This transform is responsible for generating metrics on a per-sample - basis (e.g., tokens per sample). The actual aggregation of these metrics - (eg calculating sum of samples seen) is handled by the - `MetricsAggregator`. This separation of concerns ensures that metrics are - correctly aggregated even with multiple dataloader workers and in a - distributed setting. - - Tracked metrics include: - - samples_seen: A count of samples processed. - - tokens_seen: The cumulative sum of all tokens processed. - - seq_len: A distribution of sequence lengths. - """ - - def __init__(self): - # dataset_name is set by the dataset using set_dataset_name - self.dataset_name: Optional[str] = None - self.new_metric: Optional[Callable] = None - - def set_dataset_name(self, dataset_name: str) -> None: - """Called by dataset to set the namespace for metrics. - The dataset name is used to differentiate multiple datasets stats, - e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen".""" - self.dataset_name = dataset_name - self.new_metric = partial(Metric, dataset_name=dataset_name) - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if self.dataset_name is None or self.new_metric is None: - raise RuntimeError( - "set_dataset_name() must be called before using the transform." - ) - - # Determine token key - token_key = "tokens" if "tokens" in sample else "input_ids" - token_len = len(sample.get(token_key, [])) - - # Create metrics for this sample - metrics = [ - self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), - self.new_metric( - name="tokens_seen", value=token_len, agg_type=AggregationType.SUM - ), - self.new_metric( - name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION - ), - ] - - # Append to existing metrics list or create new one - if "metrics" not in sample: - sample["metrics"] = [] - sample["metrics"].extend(metrics) - return sample diff --git a/torchtune/data/metrics/__init__.py b/torchtune/data/metrics/__init__.py new file mode 100644 index 0000000000..17e359d697 --- /dev/null +++ b/torchtune/data/metrics/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtune.data.metrics._metric_aggregator import MetricsAggregator +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_transform import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) + +__all__ = [ + "AggregationType", + "AggregationHandler", + "CategoricalCountAggHandler", + "DefaultTrainingMetricTransform", + "DistributionAggHandler", + "MaxAggHandler", + "MeanAggHandler", + "Metric", + "MetricState", + "MetricsAggregator", + "MetricTransform", + "MinAggHandler", + "SumAggHandler", +] diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py new file mode 100644 index 0000000000..1a1557c803 --- /dev/null +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -0,0 +1,433 @@ +import logging +from abc import ABC, abstractmethod +from collections import Counter, deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Union + +import torch + +from torchtune.data.metrics._metric_transform import Metric, AggregationType + +logger = logging.getLogger(__name__) + +@dataclass +class MetricState: + """Mutable state object representing aggregated metric for (dataset, metric) on a single rank. + + Args: + dataset_name (str): Name of the dataset. + metric_name (str): Name of the metric. + value (float): Current aggregated value, whose meaning depends on the aggregation type + (e.g., running sum, current max). + agg_type (AggregationType): Aggregation type. + metadata (dict[str, Any]): Additional state like count, list of values, etc. + """ + dataset_name: str + metric_name: str + value: float + agg_type: AggregationType + metadata: dict[str, Any] = field(default_factory=dict) + +class AggregationHandler(ABC): + """Base class for handling metric aggregation in MetricsAggregator. + + This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). + Each handler is responsible for: + - Initializing the state for a new (dataset, metric) pair. + - Updating the state with new values. + - Finalizing the value for local (single-rank) logging. + - Reducing the values from all ranks in a distributed setting. + - Serializing and deserializing the metric state for checkpointing. + """ + + @abstractmethod + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + """Create a new MetricState for a (dataset_name, metric_name) pair. + + Args: + dataset_name (str): Name of the dataset. Especially useful when tracking multiple datasets. + metric_name (str): Name of the metric. + agg_type (AggregationType): Aggregation type. + + Returns: + MetricState: New MetricState for this (dataset_name, metric_name) pair. + """ + pass + + @abstractmethod + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + """Update cumulative MetricState with new metric info. + + Args: + local_agg_metric (MetricState): Cumulative state of the aggregation for this metric in the local rank. + metric (Metric): Input metric info. + """ + pass + + @abstractmethod + def finalize_local_agg( + self, local_agg_metric: MetricState + ) -> Union[MetricState, list[MetricState]]: + """ + Computes the final value from the locally aggregated state. 
+ + In a distributed setting, this is called before the reduction step. + This method can also expand a single metric into multiple, for instance, + a distribution into mean, min, max, and percentiles. + + Args: + local_agg_metric (MetricState): The locally aggregated metric state to finalize. + + Returns: + A single `MetricState` or a list of them if the metric expands. + """ + pass + + @abstractmethod + def finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> MetricState: + """ + Merge MetricStates from all ranks into final result. + + Args: + local_agg_metrics (list[MetricState]): list of MetricStates for this (dataset_name, metric_name) pair. + + Returns: + MetricState: Final result for this (dataset_name, metric_name) pair. + """ + pass + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert handler-specific metadata to serializable format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Override this when using non-serializable types like deque or Counter. + For example, convert deque to list, Counter to dict. + """ + return metadata.copy() + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Restore handler-specific metadata from serialized format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Override this to reverse the serialize_metadata transformation. + For example, convert list back to deque, dict back to Counter. + """ + return metadata.copy() + + +class SumAggHandler(AggregationHandler): + """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value += metric.value + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + if not local_agg_metrics: + raise ValueError("Cannot aggregate empty list of metrics") + + total = sum(metric.value for metric in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MaxAggHandler(AggregationHandler): + """AggHandler for MAX aggregation. 
Tracks maximum value across all updates.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float('-inf'), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value = max(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + max_value = max(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=max_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MinAggHandler(AggregationHandler): + """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float('inf'), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value = min(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + min_value = min(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=min_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MeanAggHandler(AggregationHandler): + """AggHandler for MEAN aggregation. Maintains running sum and count to compute average.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"sum": 0.0, "count": 0}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["sum"] += metric.value + local_agg_metric.metadata["count"] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + count = local_agg_metric.metadata["count"] + local_agg_metric.value = local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) + total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) + + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total_sum / total_count if total_count > 0 else 0.0, + agg_type=local_agg_metrics[0].agg_type, + metadata={"sum": total_sum, "count": total_count}, + ) + + +class DistributionAggHandler(AggregationHandler): + """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values + and expands into multiple statistical metrics (mean, min, max, percentiles, std). 
+ + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + reasonable approximation for monitoring purposes. + """ + + def __init__(self, window_size: int = 1000): + """Initialize handler with specified window size for value retention. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + """ + if window_size <= 0: + raise ValueError(f"window_size must be positive, got {window_size}") + self.window_size = window_size + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"values": deque(maxlen=self.window_size)} + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["values"].append(metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + values = list(local_agg_metric.metadata["values"]) + if not values: + return [] + + return self._compute_distribution_stats(local_agg_metric, values) + + def _compute_distribution_stats( + self, local_agg_metric: MetricState, values: list[float] + ) -> list[MetricState]: + """Compute statistical metrics from distribution values using torch for efficiency.""" + if not values: + return [] + + # Use float64 for precision matching python's float + values_tensor = torch.tensor(values, dtype=torch.float64) + n = len(values_tensor) + + # Compute all stats from the tensor + sum_val = torch.sum(values_tensor).item() + mean_val = sum_val / n + min_val = torch.min(values_tensor).item() + max_val = torch.max(values_tensor).item() + + # Compute all percentiles in one go + percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) + p05_val, p50_val, p95_val = torch.quantile(values_tensor, percentile_definitions).tolist() + + # Return multiple MetricStates with proper agg_types for distributed reduction + # NOTE: Percentiles use MEAN aggregation which approximates global percentiles + # by averaging local percentiles. 
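+        # For example, if rank 0's local window gives p50 = 10 and rank 1's gives p50 = 20,
+        # the reduced p50 reported for this metric is (10 + 20) / 2 = 15. That only approximates
+        # the true global median, but it keeps the distributed reduction to a handful of scalars.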
+ metrics = [ + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_mean", + value=mean_val, + agg_type=AggregationType.MEAN, + metadata={"sum": sum_val, "count": n}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_min", + value=min_val, + agg_type=AggregationType.MIN, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_max", + value=max_val, + agg_type=AggregationType.MAX, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p05", + value=p05_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p05_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p50", + value=p50_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p50_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p95", + value=p95_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p95_val, "count": 1}, + ), + ] + # Standard deviation is only well-defined for n > 1 + if n > 1: + std_val = torch.std(values_tensor).item() + metrics.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_std", + value=std_val, + agg_type=AggregationType.MEAN, + metadata={"sum": std_val, "count": 1}, + ) + ) + return metrics + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.DISTRIBUTION are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert deque to list for serialization.""" + serialized = metadata.copy() + if "values" in serialized: + serialized["values"] = list(serialized["values"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert list back to deque.""" + deserialized = metadata.copy() + if "values" in deserialized: + deserialized["values"] = deque(deserialized["values"], maxlen=self.window_size) + return deserialized + + +class CategoricalCountAggHandler(AggregationHandler): + """AggHandler for CATEGORICAL_COUNT aggregation. 
Counts occurrences of categorical values + and expands into individual count metrics for each category.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"counts": Counter()} + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["counts"][metric.value] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + # Expand categorical counts into individual metrics + results = [] + for category, count in local_agg_metric.metadata["counts"].items(): + results.append(MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_{category}_count", + value=count, + agg_type=AggregationType.SUM + )) + return results + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.CATEGORICAL_COUNT are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert Counter to dict for serialization.""" + serialized = metadata.copy() + if "counts" in serialized: + serialized["counts"] = dict(serialized["counts"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert dict back to Counter.""" + deserialized = metadata.copy() + if "counts" in deserialized: + deserialized["counts"] = Counter(deserialized["counts"]) + return deserialized \ No newline at end of file diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py new file mode 100644 index 0000000000..c07f0dea36 --- /dev/null +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +from collections import defaultdict +from typing import Any, tuple + +import torch.distributed as dist + +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MetricState, + MaxAggHandler, + MeanAggHandler, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_transform import Metric, AggregationType + +class MetricsAggregator: + """Aggregates metrics across datasets and distributed ranks using pluggable handlers. + + Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) + has its own handler. Maintains only one state per (dataset, metric) pair. + + When preparing for logging, uses a two-phase approach: + 1. Local aggregation: Each rank aggregates its metrics independently + 2. Distributed reduction: Results combined across ranks + + The aggregator is checkpointable and restores from state_dict for training resumption. + + Args: + dist_window_size (int): Window size for DistributionAggHandler tracking. + + Example: + >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + >>> + >>> # Create aggregator + >>> aggregator = MetricsAggregator() + >>> + >>> # Sample metrics from different batches + >>> batch1_metrics = [ + ... 
Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> batch2_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> # Update with metrics + >>> aggregator.update(batch1_metrics) + >>> aggregator.update(batch2_metrics) + >>> + >>> # Get final results + >>> results = aggregator.get_metrics_for_logging(prefix="train") + >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + """ + + def __init__(self, dist_window_size: int = 1000): + if dist_window_size <= 0: + raise ValueError(f"dist_window_size must be positive, got {dist_window_size}") + + # Storage: {(dataset, metric): MetricState} - O(unique metrics) not O(samples) + self._metric_states: dict[tuple[str, str], MetricState] = {} + self._dist_window_size = dist_window_size + + # Create handler registry - all handlers initialized upfront + self._handlers: dict[AggregationType, AggregationHandler] = { + AggregationType.SUM: SumAggHandler(), + AggregationType.MAX: MaxAggHandler(), + AggregationType.MIN: MinAggHandler(), + AggregationType.MEAN: MeanAggHandler(), + AggregationType.DISTRIBUTION: DistributionAggHandler(dist_window_size), + AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), + } + + def register_handler(self, agg_type: AggregationType, handler: AggregationHandler) -> None: + """Register custom aggregation handler for specified type. + + Args: + agg_type (AggregationType): The aggregation type to handle + handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface + """ + self._handlers[agg_type] = handler + + def update(self, metrics: list[Metric]) -> None: + """Update (dataset_name, metric_name) metric state with new values. + + Args: + metrics (list[Metric]): List of metrics to update the state with + """ + for metric in metrics: + metric_key = (metric.dataset_name, metric.name) + handler = self._handlers.get(metric.agg_type) + + if handler is None: + raise ValueError(f"No handler registered for aggregation type: {metric.agg_type}") + + if metric_key not in self._metric_states: + self._metric_states[metric_key] = handler.initialize_metric_state( + metric.dataset_name, metric.name, metric.agg_type + ) + + local_agg_metric = self._metric_states[metric_key] + handler.update(local_agg_metric, metric) # Mutates local_agg_metric + + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: + """Get final metrics for logging in standard format. + + Args: + prefix (str): Prefix for metric names in the returned dictionary + + Returns: + dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" + and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, + `metric_name="loss"`, the key would be `train_alpaca/loss`. + """ + final_results = self._compute_unified_metrics() + + return { + f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value + for result in final_results + } + + def _compute_unified_metrics(self) -> list[MetricState]: + """ + Compute metrics handling both local and distributed cases uniformly. 
+ + Returns: + list[MetricState]: Final results ready for logging + """ + # Step 1: Get local results from all handlers (may expand distributions/categoricals) + prepared_results = [] + for local_agg_metric in self._metric_states.values(): + handler = self._handlers[local_agg_metric.agg_type] + prepared = handler.finalize_local_agg(local_agg_metric) + if isinstance(prepared, list): # Distribution/categorical expands to multiple + prepared_results.extend(prepared) + else: + prepared_results.append(prepared) + + # Step 2: Apply distributed reduction if needed + if dist.is_initialized() and dist.get_world_size() > 1: + prepared_results = self._finalize_dist_agg(prepared_results) + + return prepared_results + + def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[MetricState]: + """Apply distributed reduction to local results. + + Args: + local_agg_metrics (list[MetricState]): (dataset_name, metric_name) metric pairs from this rank + + Returns: + list[MetricState]: Reduced results combining all ranks + """ + world_size = dist.get_world_size() + + # Gather all results from all ranks + all_results = [None] * world_size + dist.all_gather_object(all_results, local_agg_metrics) + + # Group by (dataset_name, metric_name) for reduction + grouped = defaultdict(list) + for rank_results in all_results: + if rank_results: # Handle ranks with no metrics + for result in rank_results: + result_key = (result.dataset_name, result.metric_name) + grouped[result_key].append(result) + + # Apply handler-specific distributed reduction + reduced_results = [] + for result_key, results_list in grouped.items(): + if not results_list: + continue # Skip empty groups + + # All results for a key should have same agg_type + agg_type = results_list[0].agg_type + handler = self._handlers[agg_type] + reduced_result = handler.finalize_dist_agg(results_list) + reduced_results.append(reduced_result) + + return reduced_results + + def state_dict(self) -> dict[str, Any]: + """Serialize aggregator state for checkpointing. + + Returns: + dict[str, Any]: Serializable dictionary containing all aggregator state + """ + serializable_state = {} + required_agg_types = set() # Track aggregation types used in saved states + + for metric_key, local_agg_metric in self._metric_states.items(): + # Get handler for this result's aggregation type + handler = self._handlers[local_agg_metric.agg_type] + required_agg_types.add(local_agg_metric.agg_type) + + # Convert MetricState to serializable dict + result_dict = { + "dataset_name": local_agg_metric.dataset_name, + "metric_name": local_agg_metric.metric_name, + "value": local_agg_metric.value, + "agg_type": local_agg_metric.agg_type, + "metadata": handler.serialize_metadata(local_agg_metric.metadata) + } + + # Convert tuple key to string for JSON compatibility + serializable_state[str(metric_key)] = result_dict + + return { + "state": serializable_state, + "dist_window_size": self._dist_window_size, + "required_agg_types": list(required_agg_types) # Save which handlers are needed + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load aggregator state from checkpoint. 
+ + Args: + state_dict (dict[str, Any]): Dictionary containing serialized aggregator state + + Raises: + ValueError: If required handlers are missing after checkpoint restore + """ + self._dist_window_size = state_dict.get("dist_window_size", 1000) + + # Sanity check: Ensure all required handlers are available + required_agg_types = state_dict.get("required_agg_types", []) + missing_handlers = [] + for agg_type in required_agg_types: + if agg_type not in self._handlers: + missing_handlers.append(agg_type) + + if missing_handlers: + raise ValueError( + f"Missing handlers for aggregation types: {missing_handlers}. " + f"Custom handlers must be re-registered before checkpoint restore." + ) + + deserialized_state = {} + for key_str, result_dict in state_dict["state"].items(): + # Convert string keys back to tuples + metric_key = ast.literal_eval(key_str) + + # Get handler for this aggregation type + agg_type = result_dict["agg_type"] + handler = self._handlers[agg_type] + + # Restore metadata using handler-specific deserialization + metadata = handler.deserialize_metadata(result_dict["metadata"]) + + # Create MetricState from dict + local_agg_metric = MetricState( + dataset_name=result_dict["dataset_name"], + metric_name=result_dict["metric_name"], + value=result_dict["value"], + agg_type=result_dict["agg_type"], + metadata=metadata + ) + + deserialized_state[metric_key] = local_agg_metric + + self._metric_states = deserialized_state diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py new file mode 100644 index 0000000000..0e252cb2e6 --- /dev/null +++ b/torchtune/data/metrics/_metric_transform.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Callable, Optional, Union + +from torchtune.modules.transforms import Transform + +@dataclass(frozen=True) +class Metric: + dataset_name: str + name: str + value: Union[int, float, str] + agg_type: "AggregationType" + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated.""" + + SUM = "sum" + MEAN = "mean" + DISTRIBUTION = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + +class MetricTransform(Transform): + """Applied to each sample to generate per-sample metrics for training tracking. + + Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader + workers and in distributed settings.""" + + def __init__(self): + # dataset_name is set by the dataset using set_dataset_name + self.dataset_name: Optional[str] = None + self.new_metric: Optional[Callable] = None + + def set_dataset_name(self, dataset_name: str) -> None: + """Called by dataset to set the namespace for metrics. + + The dataset name is used to differentiate multiple datasets stats, + e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". 
+ + Args: + dataset_name (str): Name of the dataset for metric namespacing + """ + self.dataset_name = dataset_name + # Create a partial to make it easier to create new metrics + self.new_metric = partial(Metric, dataset_name=dataset_name) + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + """Generate metrics for a single sample. Must be implemented by subclasses. + + Args: + sample (dict[str, Any]): The sample dictionary to generate metrics from + + Returns: + list[Metric]: List of metrics generated for this sample + """ + raise NotImplementedError( + "Subclasses must implement _generate_metrics method" + ) + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Apply transform to sample, adding generated metrics. """ + if self.dataset_name is None or self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Generate metrics for this sample + metrics = self._generate_metrics(sample) + + # Add to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample + + +class DefaultTrainingMetricTransform(MetricTransform): + """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. + + For details about MetricTransform base class behavior, see the parent class docstring. + + Tracked metrics: + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> + >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens + >>> metrics = transform._generate_metrics(sample) + >>> # Creates: + >>> # [ + >>> # Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) + >>> # ] + """ + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + return [ + self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), + self.new_metric( + name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + ), + self.new_metric( + name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + ), + ] diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md new file mode 100644 index 0000000000..5b79c5e98b --- /dev/null +++ b/torchtune/data/metrics/readme.md @@ -0,0 +1,176 @@ +# TorchTune Metrics Module + +## Overview + +The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. 
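+
+As a quick orientation before the details below, here is a minimal end-to-end sketch of that flow.
+It chains the two public pieces by hand; the inline list of sample dicts stands in for a dataset/dataloader,
+and `print` stands in for a metric logger:
+
+```python
+from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator
+
+# Per-sample metric generation (normally done inside the dataset)
+transform = DefaultTrainingMetricTransform()
+transform.set_dataset_name("alpaca")
+
+aggregator = MetricsAggregator()
+
+for sample in [{"tokens": [1, 2, 3]}, {"tokens": [4, 5]}]:  # stand-in for a dataloader
+    sample = transform(sample)
+    aggregator.update(sample["metrics"])
+
+# Ready-to-log dict, e.g. {"train_alpaca/samples_seen": 2.0, "train_alpaca/tokens_seen": 5.0, ...}
+print(aggregator.get_metrics_for_logging(prefix="train"))
+```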
+ +## Architecture Overview + +``` +┌────────────────────────────────────────────────────┐ +│ Training Loop │ +└─────────────────────┬──────────────────────────────┘ + │ +┌─────────────────────▼──────────────────────────────┐ +│ MetricTransform │ +│ • Applied to each sample │ +│ • Generates per-sample metrics │ +│ • Examples: tokens_seen, seq_len, samples_seen │ +└─────────────────────┬──────────────────────────────┘ + │ list[Metric] +┌─────────────────────▼──────────────────────────────┐ +│ MetricsAggregator │ +│ • Aggregates metrics across samples and ranks │ +│ • Uses pluggable AggregationHandlers │ +│ • Handles distributed reduction │ +└─────────────────────┬──────────────────────────────┘ + │ {prefix_dataset/metric: value} +┌─────────────────────▼──────────────────────────────┐ +│ Logging System │ +│ • W&B, TensorBoard, etc. │ +│ • Gets formatted metrics ready for logging │ +└────────────────────────────────────────────────────┘ +``` + +## File Structure + +- **`_metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes +- **`_metric_agg_handlers.py`**: Aggregation strategy implementations +- **`_metric_aggregator.py`**: Main aggregator orchestrating the handlers + +## Customizing metrics + +- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics +- **Handler registration**: Register custom handlers for specialized aggregation needs + +####### +## TODO +## Move this from here to website docs +####### + +## Core Components + +### 1. MetricTransform +Generates per-sample metrics during data processing. + +**Key Features:** +- Applied to each sample in the dataset +- Creates `Metric` objects with dataset name, metric name, value, and aggregation type +- Handles dataset namespacing for multi-dataset scenarios + +**Example Usage:** +```python +from torchtune.data.metrics import DefaultTrainingMetricTransform, AggregationType + +transform = DefaultTrainingMetricTransform() +transform.set_dataset_name("alpaca") + +# Applied to each sample +sample = {"tokens": [1, 2, 3, 4, 5]} +sample = transform(sample) +# sample["metrics"] now contains: +# [ +# Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) +# ] +``` + +### 2. MetricsAggregator +Efficiently aggregates metrics across samples and distributed ranks. + +**Key Features:** +- Handler-based strategy pattern for different aggregation types +- Distributed-aware with automatic rank reduction +- Checkpointable state for training resumption +- Keep track of (metric, dataset) pairs + +**Aggregation Types (at the time of writing):** +- `SUM`: Cumulative totals (e.g., total tokens processed) +- `MEAN`: Running averages (e.g., average loss) +- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) +- `DISTRIBUTION`: Statistical summaries (mean, min, max, percentiles) +- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. 
num of samples from a given category) + +**Example Usage:** +```python +from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + +# Create aggregator +aggregator = MetricsAggregator() + +# Sample metrics from different batches +batch1_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +batch2_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +# Update with metrics +aggregator.update(batch1_metrics) +aggregator.update(batch2_metrics) + +# Get final results +results = aggregator.get_metrics_for_logging(prefix="train") +# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} +``` + +### 3. AggregationHandlers +Pluggable strategies for different aggregation patterns. + +``` +AggregationHandler (ABC) +├── SumAggHandler # value += metric.value +├── MeanAggHandler # tracks sum and count +├── MaxAggHandler # value = max(value, metric.value) +├── MinAggHandler # value = min(value, metric.value) +├── DistributionAggHandler # maintains value window + stats +└── CategoricalCountAggHandler # Counter for categories +``` + +**Custom Handler Example:** +```python +class CustomAggHandler(AggregationHandler): + def initialize_metric_state(self, dataset_name, metric_name, agg_type): + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=, # should change + agg_type=agg_type, + metadata={} # may need to change + ) + + def update(self, local_agg_metric, metric): + ... + + def finalize_local_agg(self, local_agg_metric): + ... + + def finalize_dist_agg(self, local_agg_metrics): + ... + +# Register with aggregator +aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) +``` + +## Distributed Training Support + +The metrics system automatically handles distributed environments: + +1. **Local Aggregation**: Each rank aggregates its own metrics +2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` +3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) 
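+
+A minimal sketch of multi-rank usage (assuming a process group is already initialized, e.g. by
+`torchrun`); each rank updates its own aggregator with local metrics, and the cross-rank reduction
+happens inside `get_metrics_for_logging`:
+
+```python
+import torch.distributed as dist
+
+from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator
+
+rank = dist.get_rank()  # assumes dist.init_process_group() was already called
+
+aggregator = MetricsAggregator()
+# Rank 0 contributes 100 tokens, rank 1 contributes 200
+aggregator.update(
+    [Metric("alpaca", "tokens_seen", 100 * (rank + 1), AggregationType.SUM)]
+)
+
+# all_gather_object + type-aware reduction happen here; every rank gets the same dict
+metrics = aggregator.get_metrics_for_logging(prefix="train")
+# With world_size=2: {"train_alpaca/tokens_seen": 300.0}
+```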
+ +**Distributed Flow:** +``` +Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] +Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] + ↓ + AllGather + Reduce + ↓ + Final Results [(ds1, metric1), (ds1, metric2)] +``` \ No newline at end of file diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 0be4c0cc53..a9495827cf 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -12,7 +12,7 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from torchtune.data._metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.data.metrics import AggregationType, Metric, StandardMetricTransform from torchtune.datasets._iterable_base import TuneIterableDataset logger = logging.getLogger(__name__) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 6dabee9bb6..72289c14dc 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -12,7 +12,7 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages -from torchtune.data._metrics import StandardMetricTransform +from torchtune.data.metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.modules.transforms import Transform diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 05fd46e395..d943ad697f 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -19,7 +19,7 @@ StateDictOptions, ) from torchtune import config, training, utils -from torchtune.data import MetricsAggregator +from torchtune.data.metrics import MetricsAggregator from torchtune.modules.optim import OptimizerInBackward from torchtune.modules.peft import ( get_adapter_state_dict, From 2eab08db738be364b0b9edfe4218cb6c6fb8f281 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 1 Jul 2025 21:21:32 -0400 Subject: [PATCH 23/48] remove file after refactoring --- torchtune/data/_aggregator.py | 343 ---------------------------------- 1 file changed, 343 deletions(-) delete mode 100644 torchtune/data/_aggregator.py diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py deleted file mode 100644 index 66162f826b..0000000000 --- a/torchtune/data/_aggregator.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import ast -import collections -import logging -from typing import Any - -import torch.distributed as dist - -from torchtune.data.metrics import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -class MetricsAggregator: - """ - Aggregates metrics across datasets and distributed ranks. - - The internal state `_state` is a dictionary where the key is a tuple - of `(dataset_name, metric_name)` and the value is another dictionary - holding the metric's specific state (e.g., `{'type': AggregationType.SUM, 'value': 10}`). 
- - Usage: - aggregator = MetricsAggregator() - aggregator.update(metrics) - # Get logger-ready metrics {key: value} - metrics = aggregator.get_metrics_for_logging(prefix="train") # {"train/dataset1/tokens": 1234, ...} - """ - - def __init__(self, dist_window_size: int = 1000): - # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} - self._state: dict[tuple[str, str], dict[str, Any]] = {} - - # For distributions, we keep a window of values to compute percentiles - self._dist_window_size = dist_window_size - - def update(self, metrics: list[Metric]) -> None: - """Update internal state with new metrics. - - Args: - metrics (list[Metric]): list of Metric objects - """ - for metric in metrics: - key = (metric.dataset_name, metric.name) - - if key not in self._state: - self._initialize_state(key, metric.agg_type) - - state = self._state[key] - - # Update based on aggregation type - if metric.agg_type == AggregationType.SUM: - state["value"] += metric.value - elif metric.agg_type == AggregationType.MAX: - if state["value"] is not None: - state["value"] = max(state["value"], metric.value) - else: - state["value"] = metric.value - elif metric.agg_type == AggregationType.MIN: - if state["value"] is not None: - state["value"] = min(state["value"], metric.value) - else: - state["value"] = metric.value - elif metric.agg_type == AggregationType.MEAN: - state["sum"] += metric.value - state["count"] += 1 - elif metric.agg_type == AggregationType.DISTRIBUTION: - state["values"].append(metric.value) - elif metric.agg_type == AggregationType.CATEGORICAL_COUNT: - state["counts"][metric.value] += 1 - - def _initialize_state( - self, key: tuple[str, str], agg_type: AggregationType - ) -> None: - """Initialize state for a new metric.""" - self._state[key] = {"type": agg_type} - state = self._state[key] - - if agg_type == AggregationType.SUM: - state["value"] = 0.0 - elif agg_type in (AggregationType.MAX, AggregationType.MIN): - state["value"] = None - elif agg_type == AggregationType.MEAN: - state["sum"] = 0.0 - state["count"] = 0 - elif agg_type == AggregationType.DISTRIBUTION: - state["values"] = collections.deque(maxlen=self._dist_window_size) - elif agg_type == AggregationType.CATEGORICAL_COUNT: - state["counts"] = collections.Counter() - - def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: - """ - Returns aggregated metrics ready for logging to wandb/tensorboard. - - Args: - prefix (str): Optional prefix like "train" or "valid" for metric keys - - Returns: - dict[str, float]: Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value - Ready to be logged directly: wandb.log(metrics) - """ - # Always compute local metrics first - local_metrics = self._compute_local_metrics() - - # In distributed mode, perform reduction - if dist.is_initialized() and dist.get_world_size() > 1: - metrics = self._compute_distributed_metrics(local_metrics) - else: - metrics = local_metrics - - # Format for logging with proper key structure - return self._format_for_logging(metrics, prefix) - - def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: - """ - Compute metrics from current state. - - For distributions and categoricals, expands into multiple entries. - The dict format allows future extensions with additional fields. 
- - Returns: - dict[tuple[str, str], dict[str, Any]]: dictionary mapping - (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} - """ - metrics = {} - - for (ds_name, metric_name), state in self._state.items(): - agg_type = state["type"] - - if agg_type in ( - AggregationType.SUM, - AggregationType.MAX, - AggregationType.MIN, - ): - # For sum, max, and min, we just need to return the value - metrics[(ds_name, metric_name)] = { - "value": state["value"], - "agg_type": agg_type, - } - - elif agg_type == AggregationType.MEAN: - if state["count"] > 0: - value = state["sum"] / state["count"] - metrics[(ds_name, metric_name)] = { - "value": value, - "agg_type": agg_type, - } - - elif agg_type == AggregationType.DISTRIBUTION: - # queue -> list - values = list(state["values"]) - - # Sort to get percentiles efficiently - sorted_values = sorted(values) - n = len(sorted_values) - - # Each stat becomes its own metric - # so that we can all gather O(5) values across ranks - # instead of the entire distribution - metrics[(ds_name, f"{metric_name}_mean")] = { - "value": sum(values) / n, - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_min")] = { - "value": sorted_values[0], - "agg_type": AggregationType.MIN, - } - metrics[(ds_name, f"{metric_name}_max")] = { - "value": sorted_values[-1], - "agg_type": AggregationType.MAX, - } - metrics[(ds_name, f"{metric_name}_p05")] = { - "value": sorted_values[max(0, int(0.05 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_p50")] = { - "value": sorted_values[max(0, int(0.5 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_p95")] = { - "value": sorted_values[max(0, int(0.95 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - - elif agg_type == AggregationType.CATEGORICAL_COUNT: - # Expand categorical counts into individual metrics - for category, count in state["counts"].items(): - metrics[(ds_name, f"{metric_name}_{category}_count")] = { - "value": count, - "agg_type": AggregationType.SUM, - } - - return metrics - - def _compute_distributed_metrics( - self, local_metrics: dict[tuple[str, str], dict[str, Any]] - ) -> dict[tuple[str, str], dict[str, Any]]: - """ - Performs distributed reduction on metrics. - - Strategy: - 1. Do a single all_gather_object to collect all metrics from all ranks - 2. Group metrics by key and aggregation type - 3. Apply the appropriate reduction operation locally - - This avoids complex tensor operations and handles all reduction in one pass. 
- - Args: - local_metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping - (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - - Returns: - dict[tuple[str, str], dict[str, Any]]: Reduced metrics in same format as input - - Example: - rank_1_metrics = - { - ("ds1", "metric1"): {"value": 10, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 20, "agg_type": AggregationType.MEAN}, - } - rank_2_metrics = - { - ("ds1", "metric1"): {"value": 30, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 40, "agg_type": AggregationType.MEAN}, - } - - # After reduction - result = - { - ("ds1", "metric1"): {"value": 40, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 30, "agg_type": AggregationType.MEAN}, - } - """ - world_size = dist.get_world_size() - - # Gather all metrics from all ranks in one operation - all_metrics = [None] * world_size - dist.all_gather_object(all_metrics, local_metrics) - - # Group values by key for reduction - grouped = collections.defaultdict(list) - for rank_metrics in all_metrics: - if rank_metrics: # It's possible a rank has no metrics - for key, metric_dict in rank_metrics.items(): - # A key is a tuple (dataset, metric) - grouped[key].append(metric_dict) - - # Reduce based on aggregation type - reduced = {} - if not grouped: - return reduced - - for key, metric_dicts in grouped.items(): - # All metrics for a key should have same type, just take first - values = [m["value"] for m in metric_dicts] - agg_type = metric_dicts[0]["agg_type"] - - # Start with copy of first dict to preserve any extra fields - result_dict = metric_dicts[0].copy() - - if agg_type == AggregationType.SUM: - result_dict["value"] = sum(values) - elif agg_type == AggregationType.MAX: - result_dict["value"] = max(values) - elif agg_type == AggregationType.MIN: - result_dict["value"] = min(values) - elif agg_type == AggregationType.MEAN: - result_dict["value"] = sum(values) / len(values) - - reduced[key] = result_dict - - return reduced - - def _format_for_logging( - self, - metrics: dict[tuple[str, str], dict[str, Any]], - prefix: str, - template: str = r"{prefix}_{ds_name}/{metric_name}", - ) -> dict[str, float]: - """ - Format metrics for wandb/tensorboard logging. - - Args: - metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping - (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - prefix (str): Optional prefix like "train" or "valid" - template (str): Template for metric key. Use {prefix}, {ds_name}, and {metric_name} as placeholders. - - Returns: - dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float - """ - formatted = {} - - for (ds_name, metric_name), metric_dict in metrics.items(): - # Use regex format to build key - key = template.format( - prefix=prefix, ds_name=ds_name, metric_name=metric_name - ) - formatted[key] = metric_dict["value"] - - return formatted - - def state_dict(self) -> dict[str, Any]: - """Serialize aggregator state. 
The state is almost directly serializable.""" - serializable_state = {} - for key, state in self._state.items(): - state_copy = state.copy() - - # Convert non-serializable types - if "values" in state_copy: - state_copy["values"] = list(state_copy["values"]) # deque → list - if "counts" in state_copy: - state_copy["counts"] = dict(state_copy["counts"]) # Counter → dict - - # Convert tuple key to string for JSON compatibility - # JSON doesn't support tuple keys, so we convert (dataset, metric) → "('dataset', 'metric')" - serializable_state[str(key)] = state_copy - return {"state": serializable_state, "dist_window_size": self._dist_window_size} - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load aggregator state from checkpoint.""" - self._dist_window_size = state_dict["dist_window_size"] - - deserialized_state = {} - for key_str, state in state_dict["state"].items(): - # Convert string keys back to tuples - # "('dataset', 'metric')" → ('dataset', 'metric') - key = ast.literal_eval(key_str) - - # Re-wrap values in their original types - if state.get("type") == AggregationType.DISTRIBUTION: - state["values"] = collections.deque( - state["values"], maxlen=self._dist_window_size - ) - if state.get("type") == AggregationType.CATEGORICAL_COUNT: - state["counts"] = collections.Counter(state["counts"]) - - deserialized_state[key] = state - self._state = deserialized_state From 58491f1f90aabdab242b59bf79094b9cc3ad82b9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 10:05:50 -0400 Subject: [PATCH 24/48] add distributed tsts --- .../torchtune/data/test_metrics_aggregator.py | 159 ++++++++++++++++++ tests/torchtune/datasets/test_hf_iterable.py | 101 +++++++++++ tests/torchtune/datasets/test_interleaved.py | 110 ++++++++++++ 3 files changed, 370 insertions(+) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 0691d9c32d..bbfb9821cd 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator @@ -147,3 +150,159 @@ def test_prefix_handling(self): result_no_prefix = aggregator.get_metrics_for_logging() assert result_no_prefix["data_test_ds/metric1"] == 42 assert result_no_prefix["data_test_ds/metric2"] == 84 + + +class TestDistributedMetricsAggregator(FSDPTest): + """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_all_aggregation_types(self): + """ + Test that all aggregation types work correctly in distributed setting. + Each rank contributes different values to ensure proper reduction across ranks. 
+ """ + aggregator = MetricsAggregator() + rank = dist.get_rank() + + # Each rank contributes different values to test cross-rank aggregation + base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20 + + metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 10, AggregationType.MAX), + Metric("test", "min_metric", base_value // 2, AggregationType.MIN), + ] + + # DISTRIBUTION: Each rank adds 5 values for distribution statistics + # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] + for i in range(5): + metrics.append( + Metric("test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION) + ) + + # CATEGORICAL_COUNT: Different categories per rank to test counting + # rank 0: 3 of cat_A, 2 of cat_B + # rank 1: 1 of cat_A, 4 of cat_C + if rank == 0: + metrics.extend([ + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), + ]) + else: + metrics.extend([ + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + ]) + + # Update aggregator and get results + aggregator.update(metrics) + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify aggregation results across all ranks + # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MAX: rank 0 has 100, rank 1 has 200 -> max 200 + # MIN: rank 0 has 5, rank 1 has 10 -> min 5 + assert result["train_test/sum_metric"] == 30 + assert result["train_test/mean_metric"] == 20 + assert result["train_test/max_metric"] == 200 + assert result["train_test/min_metric"] == 5 + + # DISTRIBUTION: Combined values [0,1,2,3,4,10,11,12,13,14] + # Mean should be average of local means: (2 + 12) / 2 = 7 + assert result["train_test/dist_metric_mean"] == 7 + assert result["train_test/dist_metric_min"] == 0 + assert result["train_test/dist_metric_max"] == 14 + + # CATEGORICAL_COUNT: Total counts across ranks + # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 + assert result["train_test/cat_metric_cat_A_count"] == 4 + assert result["train_test/cat_metric_cat_B_count"] == 2 + assert result["train_test/cat_metric_cat_C_count"] == 4 + + @gpu_test(gpu_count=2) + def test_distributed_state_dict_resumption(self): + """ + Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. 
+ Verifies: + - State can be saved after partial updates across ranks + - State can be restored consistently across ranks + - Continued updates after restore produce identical results + - Distributed aggregation works correctly after restoration + """ + rank = dist.get_rank() + + # Phase 1: Create aggregator and add initial metrics + aggregator1 = MetricsAggregator() + + # Each rank contributes different initial values + base_value = rank * 100 # rank 0: 0, rank 1: 100 + + initial_metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 2, AggregationType.MAX), + ] + + # Add some DISTRIBUTION values - each rank adds 3 values + for i in range(3): + initial_metrics.append( + Metric("test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION) + ) + + # Add CATEGORICAL_COUNT values + if rank == 0: + initial_metrics.extend([ + Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), + ]) + else: + initial_metrics.extend([ + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + ]) + + aggregator1.update(initial_metrics) + + # Save state_dict after initial update + state_dict = aggregator1.state_dict() + + # Phase 2: Create new aggregator and restore from state_dict + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state_dict) + + # Verify both aggregators produce identical results after restore + result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") + result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") + assert result1 == result2, ( + f"Rank {rank}: Aggregators differ after state_dict restore" + ) + + # Phase 3: Add more metrics to both aggregators + additional_metrics = [ + Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), + Metric("test", "min_metric", rank * 1000, AggregationType.MIN), + ] + + # Update both aggregators with additional metrics + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + # Phase 4: Verify final results are identical across both aggregators + final_result1 = aggregator1.get_metrics_for_logging(prefix="final") + final_result2 = aggregator2.get_metrics_for_logging(prefix="final") + assert final_result1 == final_result2, ( + f"Rank {rank}: Final results differ after continued updates" + ) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 067fdc1294..ba644515c2 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,10 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
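The state_dict resumption test above depends on the aggregator being checkpointable mid-run. A single-process sketch of the same round trip, using only the API shown in this patch:

```python
from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator

agg1 = MetricsAggregator()
agg1.update([Metric("alpaca", "tokens_seen", 100, AggregationType.SUM)])

ckpt = agg1.state_dict()  # serializable: tuple keys become strings, deques/Counters become lists/dicts

agg2 = MetricsAggregator()
agg2.load_state_dict(ckpt)
agg2.update([Metric("alpaca", "tokens_seen", 50, AggregationType.SUM)])

# The restored aggregator continues from the checkpointed running sum.
assert agg2.get_metrics_for_logging(prefix="train")["train_alpaca/tokens_seen"] == 150
```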
+import math +import tempfile +import shutil from itertools import islice from pathlib import Path import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test from torchdata.stateful_dataloader import StatefulDataLoader @@ -241,3 +247,98 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): assert all( epoch_value == 1 for epoch_value in epoch_values ), f"Epoch values should be 1, got {epoch_values}" + + +class TestDistributedHfIterableDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_epoch_boundary_checkpointing(self): + """ + Test epoch boundary handling with checkpointing in distributed setting. + Ensures proper handling of: + - Checkpointing at 0.9, 1.0, and 2.5 epoch boundaries + - Correct sample distribution across epochs + - Proper state restoration after checkpointing + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="epoch_test_") + else: + temp_dir = "" + + # Broadcast temp directory path to all ranks + temp_dir_list = [temp_dir] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + medium_dataset_file = tmp_path / "medium_data.json" + + # Only rank 0 creates the data file, all ranks read from it + if rank == 0: + create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE) + dist.barrier() # Wait for file creation + + # Test multiple epoch boundaries + for num_epochs in [0.9, 1.0, 2.5]: + def create_loader_and_aggregator(): + dataset = HfIterableDataset( + path="json", + data_files=str(medium_dataset_file), + split="train", + dataset_name="epoch_test", + seed=SEED, + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + ) + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics, num_workers=0 + ) + return loader, MetricsAggregator() + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + # Calculate steps to reach desired epoch boundary + samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() + total_samples = int(samples_per_rank * num_epochs) + total_steps = total_samples // BATCH_SIZE + + if total_steps < 2: + raise ValueError(f"Not enough steps for meaningful test: {total_steps}") + + # Split steps between before and after checkpoint + steps_before = max(1, total_steps // 2) + steps_after = total_steps - steps_before + + result = generate_ckpt( + loader1, aggregator1, steps_before, steps_after, + resume_dataloader=loader2, resume_aggregator=aggregator2 + ) + + # Verify deterministic resumption - critical for distributed training + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic resume for {num_epochs} epochs. " + f"This indicates checkpoint/resume state is not properly preserved." 
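The epoch-boundary test above leans on `StatefulDataLoader` being able to snapshot and restore iteration progress. A rough sketch of that pattern in isolation (plain map-style data, no torchtune dataset involved), assuming StatefulDataLoader's documented mid-epoch resumption behavior:

```python
from itertools import islice

from torchdata.stateful_dataloader import StatefulDataLoader

data = list(range(10))
loader = StatefulDataLoader(data, batch_size=2)

it = iter(loader)
before = [b.tolist() for b in islice(it, 2)]  # consume two batches: [[0, 1], [2, 3]]
state = loader.state_dict()                   # snapshot taken mid-epoch

resumed_loader = StatefulDataLoader(data, batch_size=2)
resumed_loader.load_state_dict(state)
after = [b.tolist() for b in resumed_loader]
# Expected to continue from the third batch: [[4, 5], [6, 7], [8, 9]]
```

The `generate_ckpt` helper used in these tests applies the same idea to both the dataloader and the MetricsAggregator, then compares post-checkpoint batches against the resumed run.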
+ ) + + # Verify epoch metric is correctly tracked + final_metrics = result["final_metrics"] + expected_epoch = math.floor(num_epochs - 1e-9) # -1e-9 so 1.0 epochs -> 0 + assert final_metrics[f"train_epoch_test/num_epochs"] == expected_epoch, ( + f"Epoch count incorrect for {num_epochs} epochs test scenario" + ) + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index d8afcd2263..d02912cd0f 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,11 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import tempfile +import shutil from itertools import islice from pathlib import Path from unittest.mock import patch import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test import torch from torchdata.stateful_dataloader import StatefulDataLoader @@ -233,3 +238,108 @@ def create_interleaved(): assert ( result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" + + +class TestDistributedInterleavedDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_interleaved_checkpointing(self): + """ + Test interleaved dataset checkpointing with distributed settings. + Assertions: + - Each rank processes non-overlapping data shards + - Sampling ratios (70/30) are maintained across ranks + - Checkpoint/resume produces identical batches (deterministic) + - Metrics correctly aggregate across ranks + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="interleaved_test_") + else: + temp_dir = None + + # Broadcast temp directory to all ranks + temp_dir_list = [temp_dir] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + def create_dataset(): + file1 = tmp_path / "ds1.json" + file2 = tmp_path / "ds2.json" + + # Only rank 0 creates the data files + if rank == 0: + create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 + create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100) # IDs 100-134 + dist.barrier() # Wait for file creation + + ds1 = HfIterableDataset( + path="json", data_files=str(file1), split="train", dataset_name="ds1", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + ) + ds2 = HfIterableDataset( + path="json", data_files=str(file2), split="train", dataset_name="ds2", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + ) + + # Create interleaved dataset with 70/30 weighting + return InterleavedDataset([ds1, ds2], [0.8, 0.2], seed=SEED) + + def create_dataloader(dataset): + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, + num_workers=0, # Avoid multiprocessing in distributed tests + collate_fn=collate_with_metrics + ) + return loader, MetricsAggregator() + + # Run checkpointing test with small number of steps + loader1, aggregator1 = create_dataloader(create_dataset()) + loader2, aggregator2 = create_dataloader(create_dataset()) + + result = generate_ckpt( + loader1, aggregator1, 3, 3, # 3 steps before, 
3 steps after checkpoint + resume_dataloader=loader2, resume_aggregator=aggregator2 + ) + + # Verify deterministic resumption + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic interleaved resume. " + f"This indicates sampling state is not properly preserved." + ) + assert result["final_metrics"] == result["resumed_metrics"], ( + "Final metrics don't match resumed metrics - aggregator state issue" + ) + + # Verify sampling ratio is approximately maintained (80/20 split) + all_ids = [] + for batch in result["pre_checkpoint_batches"] + result["post_checkpoint_batches"]: + all_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 + ds1_samples = sum(1 for id in all_ids if id < 100) + ds2_samples = sum(1 for id in all_ids if id >= 100) + total_samples = ds1_samples + ds2_samples + + if total_samples > 0: + ds1_ratio = ds1_samples / total_samples + assert 0.6 < ds1_ratio < 1.0, ( + f"Rank {rank}: Dataset sampling ratio {ds1_ratio:.2f} outside expected " + f"range for 80/20 split. Got {ds1_samples}, {ds2_samples} samples." + ) + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) From 96424d0df3868c45211e080a09e1a9e5ef59b0b8 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 09:30:01 -0700 Subject: [PATCH 25/48] tests pass --- .../torchtune/data/test_metrics_aggregator.py | 153 ++++++++---- tests/torchtune/datasets/test_hf_iterable.py | 48 ++-- tests/torchtune/datasets/test_interleaved.py | 60 +++-- .../torchtune/datasets/test_iterable_utils.py | 2 +- torchtune/data/metrics/__init__.py | 2 +- .../data/metrics/_metric_agg_handlers.py | 233 +++++++++++------- torchtune/data/metrics/_metric_aggregator.py | 159 ++++++------ torchtune/data/metrics/_metric_transform.py | 51 ++-- torchtune/data/metrics/readme.md | 16 +- torchtune/datasets/_hf_iterable.py | 8 +- torchtune/datasets/_sft.py | 4 +- 11 files changed, 452 insertions(+), 284 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index bbfb9821cd..b65c11f533 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -6,8 +6,8 @@ import pytest import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator @@ -64,9 +64,7 @@ def test_distribution_metrics(self): assert result["train_test/dist_metric_mean"] == 5.5 assert result["train_test/dist_metric_min"] == 1 assert result["train_test/dist_metric_max"] == 10 - assert ( - result["train_test/dist_metric_p50"] == 5 - ) # Median of 1-10 is 5 (index 4, value 5) + assert result["train_test/dist_metric_p50"] == 5.5 def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -182,28 +180,54 @@ def test_distributed_all_aggregation_types(self): # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] for i in range(5): metrics.append( - Metric("test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION) + Metric( + "test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION + ) ) # CATEGORICAL_COUNT: Different categories per rank to test counting # rank 0: 3 of 
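The p50 assertion above changes from 5 to 5.5 because `DistributionAggHandler` computes percentiles with `torch.quantile`, which linearly interpolates between values. For the window 1..10 that yields a median of 5.5, not 5. A quick check:

```python
import torch

values = torch.tensor(list(range(1, 11)), dtype=torch.float64)
q = torch.quantile(values, torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64))
print(q.tolist())  # [1.45, 5.5, 9.55] -> the interpolated median of 1..10 is 5.5
```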
cat_A, 2 of cat_B # rank 1: 1 of cat_A, 4 of cat_C if rank == 0: - metrics.extend([ - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), - ]) + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + ] + ) else: - metrics.extend([ - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - ]) + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + ] + ) # Update aggregator and get results aggregator.update(metrics) @@ -211,7 +235,7 @@ def test_distributed_all_aggregation_types(self): # Verify aggregation results across all ranks # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 - # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 # MAX: rank 0 has 100, rank 1 has 200 -> max 200 # MIN: rank 0 has 5, rank 1 has 10 -> min 5 assert result["train_test/sum_metric"] == 30 @@ -237,7 +261,7 @@ def test_distributed_state_dict_resumption(self): Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. 
Verifies: - State can be saved after partial updates across ranks - - State can be restored consistently across ranks + - State can be restored consistently across ranks - Continued updates after restore produce identical results - Distributed aggregation works correctly after restoration """ @@ -245,64 +269,95 @@ def test_distributed_state_dict_resumption(self): # Phase 1: Create aggregator and add initial metrics aggregator1 = MetricsAggregator() - + # Each rank contributes different initial values base_value = rank * 100 # rank 0: 0, rank 1: 100 - + initial_metrics = [ Metric("test", "sum_metric", base_value, AggregationType.SUM), Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), Metric("test", "max_metric", base_value * 2, AggregationType.MAX), ] - + # Add some DISTRIBUTION values - each rank adds 3 values for i in range(3): initial_metrics.append( - Metric("test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION) + Metric( + "test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION + ) ) - + # Add CATEGORICAL_COUNT values if rank == 0: - initial_metrics.extend([ - Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), - ]) + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) else: - initial_metrics.extend([ - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - ]) - + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + aggregator1.update(initial_metrics) - + # Save state_dict after initial update state_dict = aggregator1.state_dict() - + # Phase 2: Create new aggregator and restore from state_dict aggregator2 = MetricsAggregator() aggregator2.load_state_dict(state_dict) - + # Verify both aggregators produce identical results after restore result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") - assert result1 == result2, ( - f"Rank {rank}: Aggregators differ after state_dict restore" - ) - - # Phase 3: Add more metrics to both aggregators + assert ( + result1 == result2 + ), f"Rank {rank}: Aggregators differ after state_dict restore" + + # Phase 3: Add more metrics to both aggregators additional_metrics = [ Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), Metric("test", "min_metric", rank * 1000, AggregationType.MIN), ] - + # Update both aggregators with additional metrics aggregator1.update(additional_metrics) aggregator2.update(additional_metrics) - + # Phase 4: Verify final results are identical across both aggregators final_result1 = aggregator1.get_metrics_for_logging(prefix="final") final_result2 = aggregator2.get_metrics_for_logging(prefix="final") - assert final_result1 == final_result2, ( - f"Rank {rank}: Final results differ after continued updates" - ) + assert ( + final_result1 == final_result2 + ), f"Rank {rank}: Final results differ after continued 
updates" diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index ba644515c2..901234af6f 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -5,19 +5,19 @@ # LICENSE file in the root directory of this source tree. import math -import tempfile import shutil +import tempfile from itertools import islice from pathlib import Path import pytest import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator from torchtune.datasets import HfIterableDataset from .test_iterable_utils import collate_with_metrics, generate_ckpt @@ -79,7 +79,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -99,7 +99,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", # dataset_name not provided - should auto-generate seed=SEED, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", dataset_name="my_dataset", seed=SEED, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -287,6 +287,7 @@ def test_distributed_epoch_boundary_checkpointing(self): # Test multiple epoch boundaries for num_epochs in [0.9, 1.0, 2.5]: + def create_loader_and_aggregator(): dataset = HfIterableDataset( path="json", @@ -295,11 +296,14 @@ def create_loader_and_aggregator(): dataset_name="epoch_test", seed=SEED, shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, ) loader = StatefulDataLoader( - dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics, num_workers=0 + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, ) return loader, MetricsAggregator() @@ -310,21 +314,29 @@ def create_loader_and_aggregator(): samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() total_samples = int(samples_per_rank * num_epochs) total_steps = total_samples // BATCH_SIZE - + if total_steps < 2: - raise ValueError(f"Not enough steps for meaningful test: {total_steps}") + raise ValueError( + f"Not enough steps for meaningful test: {total_steps}" + ) # Split steps between before and after checkpoint steps_before = max(1, total_steps // 2) steps_after = total_steps - steps_before result = generate_ckpt( - loader1, aggregator1, steps_before, steps_after, - resume_dataloader=loader2, resume_aggregator=aggregator2 + loader1, + aggregator1, + steps_before, + steps_after, + resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption - critical for distributed training - orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] resumed_ids = [b["id"].tolist() for 
b in result["resumed_batches"]] assert orig_post_ids == resumed_ids, ( f"Rank {rank}: Non-deterministic resume for {num_epochs} epochs. " @@ -333,10 +345,12 @@ def create_loader_and_aggregator(): # Verify epoch metric is correctly tracked final_metrics = result["final_metrics"] - expected_epoch = math.floor(num_epochs - 1e-9) # -1e-9 so 1.0 epochs -> 0 - assert final_metrics[f"train_epoch_test/num_epochs"] == expected_epoch, ( - f"Epoch count incorrect for {num_epochs} epochs test scenario" - ) + expected_epoch = math.floor( + num_epochs - 1e-9 + ) # -1e-9 so 1.0 epochs -> 0 + assert ( + final_metrics["train_epoch_test/num_epochs"] == expected_epoch + ), f"Epoch count incorrect for {num_epochs} epochs test scenario" finally: # Clean up temp directory (only rank 0) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index d02912cd0f..98e9207047 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,21 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import tempfile import shutil +import tempfile from itertools import islice from pathlib import Path from unittest.mock import patch import pytest -import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest -from tests.test_utils import gpu_test import torch +import torch.distributed as dist +from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities @@ -87,7 +87,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -270,6 +270,7 @@ def test_distributed_interleaved_checkpointing(self): tmp_path = Path(temp_dir) try: + def create_dataset(): file1 = tmp_path / "ds1.json" file2 = tmp_path / "ds2.json" @@ -277,18 +278,28 @@ def create_dataset(): # Only rank 0 creates the data files if rank == 0: create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 - create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100) # IDs 100-134 + create_test_json_file( + file2, MEDIUM_DATASET_SIZE, offset=100 + ) # IDs 100-134 dist.barrier() # Wait for file creation ds1 = HfIterableDataset( - path="json", data_files=str(file1), split="train", dataset_name="ds1", + path="json", + data_files=str(file1), + split="train", + dataset_name="ds1", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, ) ds2 = HfIterableDataset( - path="json", data_files=str(file2), split="train", dataset_name="ds2", + path="json", + data_files=str(file2), + split="train", + dataset_name="ds2", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, ) # Create interleaved dataset with 70/30 weighting @@ -296,9 +307,10 @@ def 
create_dataset(): def create_dataloader(dataset): loader = StatefulDataLoader( - dataset, batch_size=BATCH_SIZE, + dataset, + batch_size=BATCH_SIZE, num_workers=0, # Avoid multiprocessing in distributed tests - collate_fn=collate_with_metrics + collate_fn=collate_with_metrics, ) return loader, MetricsAggregator() @@ -307,24 +319,32 @@ def create_dataloader(dataset): loader2, aggregator2 = create_dataloader(create_dataset()) result = generate_ckpt( - loader1, aggregator1, 3, 3, # 3 steps before, 3 steps after checkpoint - resume_dataloader=loader2, resume_aggregator=aggregator2 + loader1, + aggregator1, + 3, + 3, # 3 steps before, 3 steps after checkpoint + resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption - orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] assert orig_post_ids == resumed_ids, ( f"Rank {rank}: Non-deterministic interleaved resume. " f"This indicates sampling state is not properly preserved." ) - assert result["final_metrics"] == result["resumed_metrics"], ( - "Final metrics don't match resumed metrics - aggregator state issue" - ) + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics don't match resumed metrics - aggregator state issue" # Verify sampling ratio is approximately maintained (80/20 split) all_ids = [] - for batch in result["pre_checkpoint_batches"] + result["post_checkpoint_batches"]: + for batch in ( + result["pre_checkpoint_batches"] + result["post_checkpoint_batches"] + ): all_ids.extend(batch["id"].tolist()) # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index e160345bc1..28c6d8e464 100644 --- a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -9,7 +9,7 @@ import torch from torch.utils.data import DataLoader -from torchtune.data import MetricsAggregator +from torchtune.data.metrics import MetricsAggregator def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: diff --git a/torchtune/data/metrics/__init__.py b/torchtune/data/metrics/__init__.py index 17e359d697..778245f83a 100644 --- a/torchtune/data/metrics/__init__.py +++ b/torchtune/data/metrics/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data.metrics._metric_aggregator import MetricsAggregator from torchtune.data.metrics._metric_agg_handlers import ( AggregationHandler, CategoricalCountAggHandler, @@ -15,6 +14,7 @@ MinAggHandler, SumAggHandler, ) +from torchtune.data.metrics._metric_aggregator import MetricsAggregator from torchtune.data.metrics._metric_transform import ( AggregationType, DefaultTrainingMetricTransform, diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index 1a1557c803..c3415aba7d 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -1,21 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ import logging from abc import ABC, abstractmethod from collections import Counter, deque from dataclasses import dataclass, field -from enum import Enum from typing import Any, Union import torch -from torchtune.data.metrics._metric_transform import Metric, AggregationType +from torchtune.data.metrics._metric_transform import AggregationType, Metric logger = logging.getLogger(__name__) + @dataclass class MetricState: """Mutable state object representing aggregated metric for (dataset, metric) on a single rank. - - Args: + + Attributes: dataset_name (str): Name of the dataset. metric_name (str): Name of the metric. value (float): Current aggregated value, whose meaning depends on the aggregation type @@ -23,15 +29,17 @@ class MetricState: agg_type (AggregationType): Aggregation type. metadata (dict[str, Any]): Additional state like count, list of values, etc. """ + dataset_name: str metric_name: str value: float agg_type: AggregationType metadata: dict[str, Any] = field(default_factory=dict) + class AggregationHandler(ABC): """Base class for handling metric aggregation in MetricsAggregator. - + This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). Each handler is responsible for: - Initializing the state for a new (dataset, metric) pair. @@ -40,31 +48,33 @@ class AggregationHandler(ABC): - Reducing the values from all ranks in a distributed setting. - Serializing and deserializing the metric state for checkpointing. """ - + @abstractmethod - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: """Create a new MetricState for a (dataset_name, metric_name) pair. - + Args: dataset_name (str): Name of the dataset. Especially useful when tracking multiple datasets. metric_name (str): Name of the metric. agg_type (AggregationType): Aggregation type. - + Returns: MetricState: New MetricState for this (dataset_name, metric_name) pair. """ pass - + @abstractmethod def update(self, local_agg_metric: MetricState, metric: Metric) -> None: """Update cumulative MetricState with new metric info. - + Args: local_agg_metric (MetricState): Cumulative state of the aggregation for this metric in the local rank. metric (Metric): Input metric info. """ pass - + @abstractmethod def finalize_local_agg( self, local_agg_metric: MetricState @@ -83,17 +93,15 @@ def finalize_local_agg( A single `MetricState` or a list of them if the metric expands. """ pass - + @abstractmethod - def finalize_dist_agg( - self, local_agg_metrics: list[MetricState] - ) -> MetricState: + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: """ Merge MetricStates from all ranks into final result. - + Args: local_agg_metrics (list[MetricState]): list of MetricStates for this (dataset_name, metric_name) pair. - + Returns: MetricState: Final result for this (dataset_name, metric_name) pair. """ @@ -101,21 +109,27 @@ def finalize_dist_agg( def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert handler-specific metadata to serializable format. - + Args: metadata (dict[str, Any]): AggHandler-specific metadata. + Returns: + dict[str, Any]: Serializable metadata. + Override this when using non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict. 
""" return metadata.copy() - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Restore handler-specific metadata from serialized format. - + Args: metadata (dict[str, Any]): AggHandler-specific metadata. + Returns: + dict[str, Any]: Deserialized metadata. + Override this to reverse the serialize_metadata transformation. For example, convert list back to deque, dict back to Counter. """ @@ -124,116 +138,138 @@ def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: class SumAggHandler(AggregationHandler): """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, value=0.0, - agg_type=agg_type + agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"SumAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value += metric.value - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: if not local_agg_metrics: raise ValueError("Cannot aggregate empty list of metrics") - + total = sum(metric.value for metric in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=total, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MaxAggHandler(AggregationHandler): """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, - value=float('-inf'), + value=float("-inf"), agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MaxAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value = max(local_agg_metric.value, metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: max_value = max(r.value for r in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=max_value, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MinAggHandler(AggregationHandler): """AggHandler for MIN aggregation. 
Tracks minimum value across all updates.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, - value=float('inf'), + value=float("inf"), agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MinAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value = min(local_agg_metric.value, metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: min_value = min(r.value for r in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=min_value, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MeanAggHandler(AggregationHandler): """AggHandler for MEAN aggregation. Maintains running sum and count to compute average.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, metadata={"sum": 0.0, "count": 0}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["sum"] += metric.value local_agg_metric.metadata["count"] += 1 - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: count = local_agg_metric.metadata["count"] - local_agg_metric.value = local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + local_agg_metric.value = ( + local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + ) return local_agg_metric - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) - + return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, @@ -244,43 +280,46 @@ def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState class DistributionAggHandler(AggregationHandler): - """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values + """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values and expands into multiple statistical metrics (mean, min, max, percentiles, std). - - Note: Percentiles and standard deviation are approximated in distributed settings by averaging local - percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. 
This is mathematically imprecise but provides a reasonable approximation for monitoring purposes. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + + Raises: + ValueError: If window_size is not positive. """ - + def __init__(self, window_size: int = 1000): - """Initialize handler with specified window size for value retention. - - Args: - window_size (int): Maximum number of recent values to retain for statistics. - """ if window_size <= 0: raise ValueError(f"window_size must be positive, got {window_size}") self.window_size = window_size - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, - metadata={"values": deque(maxlen=self.window_size)} + metadata={"values": deque(maxlen=self.window_size)}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["values"].append(metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: values = list(local_agg_metric.metadata["values"]) if not values: return [] - + return self._compute_distribution_stats(local_agg_metric, values) - + def _compute_distribution_stats( self, local_agg_metric: MetricState, values: list[float] ) -> list[MetricState]: @@ -300,7 +339,9 @@ def _compute_distribution_stats( # Compute all percentiles in one go percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) - p05_val, p50_val, p95_val = torch.quantile(values_tensor, percentile_definitions).tolist() + p05_val, p50_val, p95_val = torch.quantile( + values_tensor, percentile_definitions + ).tolist() # Return multiple MetricStates with proper agg_types for distributed reduction # NOTE: Percentiles use MEAN aggregation which approximates global percentiles @@ -362,7 +403,7 @@ def _compute_distribution_stats( ) ) return metrics - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: raise NotImplementedError( "Metrics with AggregationType.DISTRIBUTION are converted to other " @@ -375,43 +416,49 @@ def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: if "values" in serialized: serialized["values"] = list(serialized["values"]) return serialized - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert list back to deque.""" deserialized = metadata.copy() if "values" in deserialized: - deserialized["values"] = deque(deserialized["values"], maxlen=self.window_size) + deserialized["values"] = deque( + deserialized["values"], maxlen=self.window_size + ) return deserialized class CategoricalCountAggHandler(AggregationHandler): """AggHandler for CATEGORICAL_COUNT aggregation. 
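A DISTRIBUTION metric never surfaces as a single logged value; `finalize_local_agg` expands it into derived statistics, which is why the tests read keys such as `dist_metric_mean` and `dist_metric_p50`. A minimal sketch of that expansion using the public API from this patch:

```python
from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator

aggregator = MetricsAggregator()
aggregator.update(
    [Metric("ds", "seq_len", n, AggregationType.DISTRIBUTION) for n in range(1, 11)]
)
stats = aggregator.get_metrics_for_logging(prefix="train")
# One metric in, several keys out, e.g.:
#   train_ds/seq_len_mean == 5.5, train_ds/seq_len_min == 1,
#   train_ds/seq_len_max  == 10,  train_ds/seq_len_p50 == 5.5
```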
Counts occurrences of categorical values and expands into individual count metrics for each category.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, - metadata={"counts": Counter()} + metadata={"counts": Counter()}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["counts"][metric.value] += 1 - + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: # Expand categorical counts into individual metrics results = [] for category, count in local_agg_metric.metadata["counts"].items(): - results.append(MetricState( - dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_{category}_count", - value=count, - agg_type=AggregationType.SUM - )) + results.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_{category}_count", + value=count, + agg_type=AggregationType.SUM, + ) + ) return results - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: raise NotImplementedError( "Metrics with AggregationType.CATEGORICAL_COUNT are converted to other " @@ -424,10 +471,10 @@ def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: if "counts" in serialized: serialized["counts"] = dict(serialized["counts"]) return serialized - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert dict back to Counter.""" deserialized = metadata.copy() if "counts" in deserialized: deserialized["counts"] = Counter(deserialized["counts"]) - return deserialized \ No newline at end of file + return deserialized diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index c07f0dea36..633d5c6b80 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -6,7 +6,7 @@ import ast from collections import defaultdict -from typing import Any, tuple +from typing import Any import torch.distributed as dist @@ -14,63 +14,69 @@ AggregationHandler, CategoricalCountAggHandler, DistributionAggHandler, - MetricState, MaxAggHandler, MeanAggHandler, + MetricState, MinAggHandler, SumAggHandler, ) -from torchtune.data.metrics._metric_transform import Metric, AggregationType +from torchtune.data.metrics._metric_transform import AggregationType, Metric + class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - - Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) + + Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) has its own handler. Maintains only one state per (dataset, metric) pair. - + When preparing for logging, uses a two-phase approach: 1. Local aggregation: Each rank aggregates its metrics independently 2. Distributed reduction: Results combined across ranks - + The aggregator is checkpointable and restores from state_dict for training resumption. - + Args: dist_window_size (int): Window size for DistributionAggHandler tracking. 
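CATEGORICAL_COUNT works the same way: each observed category expands into its own SUM-aggregated counter, producing the `{metric}_{category}_count` keys asserted in the distributed test above. A small sketch:

```python
from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator

aggregator = MetricsAggregator()
aggregator.update(
    [
        Metric("ds", "sample_source", "wiki", AggregationType.CATEGORICAL_COUNT),
        Metric("ds", "sample_source", "wiki", AggregationType.CATEGORICAL_COUNT),
        Metric("ds", "sample_source", "web", AggregationType.CATEGORICAL_COUNT),
    ]
)
counts = aggregator.get_metrics_for_logging(prefix="train")
# {"train_ds/sample_source_wiki_count": 2, "train_ds/sample_source_web_count": 1}
```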
- + Example: >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType - >>> + >>> >>> # Create aggregator >>> aggregator = MetricsAggregator() - >>> + >>> >>> # Sample metrics from different batches >>> batch1_metrics = [ ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ... ] - >>> + >>> >>> batch2_metrics = [ ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ... ] - >>> + >>> >>> # Update with metrics >>> aggregator.update(batch1_metrics) >>> aggregator.update(batch2_metrics) - >>> + >>> >>> # Get final results >>> results = aggregator.get_metrics_for_logging(prefix="train") >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + + Raises: + ValueError: If dist_window_size is not positive. """ - + def __init__(self, dist_window_size: int = 1000): if dist_window_size <= 0: - raise ValueError(f"dist_window_size must be positive, got {dist_window_size}") - + raise ValueError( + f"dist_window_size must be positive, got {dist_window_size}" + ) + # Storage: {(dataset, metric): MetricState} - O(unique metrics) not O(samples) self._metric_states: dict[tuple[str, str], MetricState] = {} self._dist_window_size = dist_window_size - + # Create handler registry - all handlers initialized upfront self._handlers: dict[AggregationType, AggregationHandler] = { AggregationType.SUM: SumAggHandler(), @@ -80,59 +86,66 @@ def __init__(self, dist_window_size: int = 1000): AggregationType.DISTRIBUTION: DistributionAggHandler(dist_window_size), AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), } - - def register_handler(self, agg_type: AggregationType, handler: AggregationHandler) -> None: + + def register_handler( + self, agg_type: AggregationType, handler: AggregationHandler + ) -> None: """Register custom aggregation handler for specified type. - + Args: agg_type (AggregationType): The aggregation type to handle handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface """ self._handlers[agg_type] = handler - + def update(self, metrics: list[Metric]) -> None: """Update (dataset_name, metric_name) metric state with new values. - + Args: metrics (list[Metric]): List of metrics to update the state with + + Raises: + ValueError: If no handler is registered for a metric's aggregation type. """ for metric in metrics: metric_key = (metric.dataset_name, metric.name) handler = self._handlers.get(metric.agg_type) - + if handler is None: - raise ValueError(f"No handler registered for aggregation type: {metric.agg_type}") - + raise ValueError( + f"No handler registered for aggregation type: {metric.agg_type}" + ) + if metric_key not in self._metric_states: self._metric_states[metric_key] = handler.initialize_metric_state( metric.dataset_name, metric.name, metric.agg_type ) - + local_agg_metric = self._metric_states[metric_key] handler.update(local_agg_metric, metric) # Mutates local_agg_metric - + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: """Get final metrics for logging in standard format. - + Args: prefix (str): Prefix for metric names in the returned dictionary - + Returns: - dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" - and float values. 
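`register_handler` lets a recipe swap the strategy used for a given `AggregationType`. A small sketch, assuming the package exports shown in the `__init__.py` hunk above (in particular that `DistributionAggHandler` is importable from `torchtune.data.metrics`):

```python
from torchtune.data.metrics import (
    AggregationType,
    DistributionAggHandler,
    MetricsAggregator,
)

aggregator = MetricsAggregator()
# Track only the 50 most recent values when computing distribution statistics.
aggregator.register_handler(
    AggregationType.DISTRIBUTION, DistributionAggHandler(window_size=50)
)
```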
For example, with `prefix="train"`, `dataset_name="alpaca"`, + dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" + and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, `metric_name="loss"`, the key would be `train_alpaca/loss`. """ final_results = self._compute_unified_metrics() - + return { - f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value + f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value for result in final_results } - + def _compute_unified_metrics(self) -> list[MetricState]: """ Compute metrics handling both local and distributed cases uniformly. - + Returns: list[MetricState]: Final results ready for logging """ @@ -141,32 +154,36 @@ def _compute_unified_metrics(self) -> list[MetricState]: for local_agg_metric in self._metric_states.values(): handler = self._handlers[local_agg_metric.agg_type] prepared = handler.finalize_local_agg(local_agg_metric) - if isinstance(prepared, list): # Distribution/categorical expands to multiple + if isinstance( + prepared, list + ): # Distribution/categorical expands to multiple prepared_results.extend(prepared) else: prepared_results.append(prepared) - + # Step 2: Apply distributed reduction if needed if dist.is_initialized() and dist.get_world_size() > 1: prepared_results = self._finalize_dist_agg(prepared_results) - + return prepared_results - - def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[MetricState]: + + def _finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> list[MetricState]: """Apply distributed reduction to local results. - + Args: local_agg_metrics (list[MetricState]): (dataset_name, metric_name) metric pairs from this rank - + Returns: list[MetricState]: Reduced results combining all ranks """ world_size = dist.get_world_size() - + # Gather all results from all ranks all_results = [None] * world_size dist.all_gather_object(all_results, local_agg_metrics) - + # Group by (dataset_name, metric_name) for reduction grouped = defaultdict(list) for rank_results in all_results: @@ -174,98 +191,100 @@ def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[Metri for result in rank_results: result_key = (result.dataset_name, result.metric_name) grouped[result_key].append(result) - + # Apply handler-specific distributed reduction reduced_results = [] for result_key, results_list in grouped.items(): if not results_list: continue # Skip empty groups - + # All results for a key should have same agg_type agg_type = results_list[0].agg_type handler = self._handlers[agg_type] reduced_result = handler.finalize_dist_agg(results_list) reduced_results.append(reduced_result) - + return reduced_results - + def state_dict(self) -> dict[str, Any]: """Serialize aggregator state for checkpointing. 
- + Returns: dict[str, Any]: Serializable dictionary containing all aggregator state """ serializable_state = {} required_agg_types = set() # Track aggregation types used in saved states - + for metric_key, local_agg_metric in self._metric_states.items(): # Get handler for this result's aggregation type handler = self._handlers[local_agg_metric.agg_type] required_agg_types.add(local_agg_metric.agg_type) - + # Convert MetricState to serializable dict result_dict = { "dataset_name": local_agg_metric.dataset_name, "metric_name": local_agg_metric.metric_name, "value": local_agg_metric.value, "agg_type": local_agg_metric.agg_type, - "metadata": handler.serialize_metadata(local_agg_metric.metadata) + "metadata": handler.serialize_metadata(local_agg_metric.metadata), } - + # Convert tuple key to string for JSON compatibility serializable_state[str(metric_key)] = result_dict - + return { "state": serializable_state, "dist_window_size": self._dist_window_size, - "required_agg_types": list(required_agg_types) # Save which handlers are needed + "required_agg_types": list( + required_agg_types + ), # Save which handlers are needed } - + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load aggregator state from checkpoint. - + Args: state_dict (dict[str, Any]): Dictionary containing serialized aggregator state - + Raises: ValueError: If required handlers are missing after checkpoint restore """ self._dist_window_size = state_dict.get("dist_window_size", 1000) - + # Sanity check: Ensure all required handlers are available required_agg_types = state_dict.get("required_agg_types", []) missing_handlers = [] for agg_type in required_agg_types: if agg_type not in self._handlers: missing_handlers.append(agg_type) - + if missing_handlers: raise ValueError( f"Missing handlers for aggregation types: {missing_handlers}. " f"Custom handlers must be re-registered before checkpoint restore." 
) - + deserialized_state = {} for key_str, result_dict in state_dict["state"].items(): # Convert string keys back to tuples metric_key = ast.literal_eval(key_str) - + # Get handler for this aggregation type agg_type = result_dict["agg_type"] handler = self._handlers[agg_type] - + # Restore metadata using handler-specific deserialization metadata = handler.deserialize_metadata(result_dict["metadata"]) - + # Create MetricState from dict local_agg_metric = MetricState( dataset_name=result_dict["dataset_name"], metric_name=result_dict["metric_name"], value=result_dict["value"], agg_type=result_dict["agg_type"], - metadata=metadata + metadata=metadata, ) - + deserialized_state[metric_key] = local_agg_metric - + self._metric_states = deserialized_state diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 0e252cb2e6..9ae73488e4 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -7,10 +7,11 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Mapping, Optional, Union from torchtune.modules.transforms import Transform + @dataclass(frozen=True) class Metric: dataset_name: str @@ -18,6 +19,7 @@ class Metric: value: Union[int, float, str] agg_type: "AggregationType" + class AggregationType(Enum): """Defines how a metric's value should be aggregated.""" @@ -28,11 +30,12 @@ class AggregationType(Enum): MAX = "max" MIN = "min" + class MetricTransform(Transform): """Applied to each sample to generate per-sample metrics for training tracking. - - Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation - of concerns ensures metrics are correctly aggregated even with multiple dataloader + + Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader workers and in distributed settings.""" def __init__(self): @@ -42,10 +45,10 @@ def __init__(self): def set_dataset_name(self, dataset_name: str) -> None: """Called by dataset to set the namespace for metrics. - + The dataset name is used to differentiate multiple datasets stats, e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". - + Args: dataset_name (str): Name of the dataset for metric namespacing """ @@ -53,21 +56,22 @@ def set_dataset_name(self, dataset_name: str) -> None: # Create a partial to make it easier to create new metrics self.new_metric = partial(Metric, dataset_name=dataset_name) - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. - + Args: - sample (dict[str, Any]): The sample dictionary to generate metrics from - + sample (Mapping[str, Any]): The sample dictionary to generate metrics from + Returns: list[Metric]: List of metrics generated for this sample + + Raises: + NotImplementedError: If subclass does not implement this method. """ - raise NotImplementedError( - "Subclasses must implement _generate_metrics method" - ) + raise NotImplementedError("Subclasses must implement _generate_metrics method") - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply transform to sample, adding generated metrics. 
""" + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """Apply transform to sample, adding generated metrics.""" if self.dataset_name is None or self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." @@ -85,18 +89,18 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: class DefaultTrainingMetricTransform(MetricTransform): """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. - + For details about MetricTransform base class behavior, see the parent class docstring. - + Tracked metrics: - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) - + Example: >>> transform = DefaultTrainingMetricTransform() >>> transform.set_dataset_name("alpaca") - >>> + >>> >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens >>> metrics = transform._generate_metrics(sample) >>> # Creates: @@ -107,7 +111,12 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> # ] """ - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + if self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + # Determine token key token_key = "tokens" if "tokens" in sample else "input_ids" token_len = len(sample.get(token_key, [])) diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md index 5b79c5e98b..6c6c413246 100644 --- a/torchtune/data/metrics/readme.md +++ b/torchtune/data/metrics/readme.md @@ -102,12 +102,12 @@ aggregator = MetricsAggregator() # Sample metrics from different batches batch1_metrics = [ Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ] batch2_metrics = [ Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ] # Update with metrics @@ -125,7 +125,7 @@ Pluggable strategies for different aggregation patterns. ``` AggregationHandler (ABC) ├── SumAggHandler # value += metric.value -├── MeanAggHandler # tracks sum and count +├── MeanAggHandler # tracks sum and count ├── MaxAggHandler # value = max(value, metric.value) ├── MinAggHandler # value = min(value, metric.value) ├── DistributionAggHandler # maintains value window + stats @@ -138,18 +138,18 @@ class CustomAggHandler(AggregationHandler): def initialize_metric_state(self, dataset_name, metric_name, agg_type): return MetricState( dataset_name=dataset_name, - metric_name=metric_name, + metric_name=metric_name, value=, # should change agg_type=agg_type, metadata={} # may need to change ) - + def update(self, local_agg_metric, metric): ... - + def finalize_local_agg(self, local_agg_metric): ... - + def finalize_dist_agg(self, local_agg_metrics): ... 
@@ -173,4 +173,4 @@ Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, AllGather + Reduce ↓ Final Results [(ds1, metric1), (ds1, metric2)] -``` \ No newline at end of file +``` diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 81710169d0..7aac8adcc4 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -12,7 +12,11 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from torchtune.data.metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.data.metrics import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, +) from torchtune.datasets._iterable_base import TuneIterableDataset logger = logging.getLogger(__name__) @@ -73,7 +77,7 @@ def __init__( self._weight = weight # TODO: make it a property? # Create default transform if not provided - self._metric_transform = metric_transform or StandardMetricTransform() + self._metric_transform = metric_transform or DefaultTrainingMetricTransform() # Auto-generate dataset name if not provided, ensuring it's always a string. if dataset_name is None: diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 72289c14dc..a0aab0b27b 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -12,7 +12,7 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages -from torchtune.data.metrics import StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.modules.transforms import Transform @@ -266,7 +266,7 @@ def sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, output_transform=output_transform, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, weight=weight, seed=seed, From 853147b5a5c2945661b6016da620db46e325dad1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 13:49:09 -0700 Subject: [PATCH 26/48] optimize SFTOutputTransform --- torchtune/data/metrics/_metric_transform.py | 12 ++-- torchtune/datasets/_sft.py | 73 ++++++++++++--------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 9ae73488e4..529fff8c5c 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Callable, Optional, Union from torchtune.modules.transforms import Transform @@ -32,7 +32,7 @@ class AggregationType(Enum): class MetricTransform(Transform): - """Applied to each sample to generate per-sample metrics for training tracking. + """Applied to each dataset sample to generate per-sample metrics for training tracking. Creates Metric objects that are later aggregated by 'MetricsAggregator'. 
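As an aside, a recipe could subclass this transform to track extra per-sample statistics using the `_generate_metrics` hook and the `new_metric` partial defined further down in this class. The sketch below (a hypothetical truncation tracker, with an assumed import path) is illustrative only and is not part of this patch:

```python
from typing import Any

from torchtune.data.metrics import AggregationType, Metric, MetricTransform  # import path assumed


class TruncationRateTransform(MetricTransform):
    """Hypothetical transform tracking how often samples hit the length cap."""

    def __init__(self, max_seq_len: int):
        super().__init__()
        self.max_seq_len = max_seq_len

    def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]:
        seq_len = len(sample.get("tokens", []))
        return [
            # Fraction of samples at the cap; MEAN aggregation turns the 0/1
            # flags into a truncation rate when aggregated.
            self.new_metric(
                name="truncated",
                value=1 if seq_len >= self.max_seq_len else 0,
                agg_type=AggregationType.MEAN,
            ),
        ]
```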
This separation of concerns ensures metrics are correctly aggregated even with multiple dataloader @@ -56,11 +56,11 @@ def set_dataset_name(self, dataset_name: str) -> None: # Create a partial to make it easier to create new metrics self.new_metric = partial(Metric, dataset_name=dataset_name) - def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. Args: - sample (Mapping[str, Any]): The sample dictionary to generate metrics from + sample (dict[str, Any]): The sample dictionary to generate metrics from Returns: list[Metric]: List of metrics generated for this sample @@ -70,7 +70,7 @@ def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: """ raise NotImplementedError("Subclasses must implement _generate_metrics method") - def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transform to sample, adding generated metrics.""" if self.dataset_name is None or self.new_metric is None: raise RuntimeError( @@ -111,7 +111,7 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> # ] """ - def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: if self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index a0aab0b27b..f7638bb609 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Optional import numpy as np +import torch from datasets import load_dataset from torch.utils.data import Dataset @@ -145,7 +146,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform - def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if self._message_transform is not None: transformed_sample = self._message_transform(sample) if "messages" in transformed_sample: @@ -183,40 +184,50 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: class SFTOutputTransform(Transform): - """ - Output transform to be used in SFT recipes as an input to TuneIterableDataset. - It takes tokenized inputs with "tokens" and "mask" keys and - creates the "labels" key for SFT training. - - The labels are created by: - 1. Shifting tokens by 1 position (for autoregressive training) - 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX - 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end + """Applied to each dataset sample to build the `"labels"` tensor for causal-LM SFT training. + + Expects sample to contain 1-D torch tensors + "tokens": token IDs, dtype=torch.long + "mask": bool/int where **True** marks positions to ignore + + If they are not tensors, they are converted to tensors. + + Produces ``"labels"`` of the same shape such that + labels[t] = tokens[t+1] # shift left + labels[t] = IGNORE_IDX if mask[t+1] # respect mask + labels[-1] = IGNORE_IDX # last token has no target + + All ops are vectorised; only one fresh tensor (`labels`) is allocated. 
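As a concrete instance of the shift-and-mask rule described above (standalone sketch; `CROSS_ENTROPY_IGNORE_IDX` is assumed to be -100 here):

```python
import torch

IGNORE_IDX = -100  # assumed value of CROSS_ENTROPY_IGNORE_IDX

tokens = torch.tensor([10, 11, 12, 13, 14])
mask = torch.tensor([True, True, False, False, False])  # prompt tokens are masked

labels = tokens.new_full(tokens.shape, IGNORE_IDX)       # pre-fill with the ignore index
labels[:-1].copy_(tokens[1:])                            # shift left: target is the next token
labels[:-1].masked_fill_(mask[1:].bool(), IGNORE_IDX)    # drop targets at masked positions

print(labels)  # tensor([-100, 12, 13, 14, -100])
```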
""" - def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: - # Create a copy to avoid modifying the original - tokenized_dict = dict(sample) + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): - keys_str = ", ".join(tokenized_dict.keys()) - raise ValueError( - f"SFTOutputTransform expects 'tokens' and 'mask' keys. " - f"Got keys: {keys_str}" - ) + tokens = sample["tokens"] + mask = sample["mask"] - # Create labels for SFT training - tokenized_dict["labels"] = list( - np.where( - tokenized_dict["mask"][1:], - CROSS_ENTROPY_IGNORE_IDX, - tokenized_dict["tokens"][1:], - ) - ) - tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) - assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + # Sanity checks + if not isinstance(tokens, torch.Tensor): + tokens = torch.tensor(tokens) + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask) - return tokenized_dict + if tokens.ndim != 1 or mask.ndim != 1: + raise ValueError("Both 'tokens' and 'mask' must be 1-D tensors.") + + # build labels + # pre-fill with IGNORE so we don’t need extra assignments later + labels = tokens.new_full(tokens.shape, CROSS_ENTROPY_IGNORE_IDX) + + # left-shift via cheap views (no copy) + labels[:-1].copy_(tokens[1:]) + + # apply mask in-place (single fused kernel on GPU/CPU) + labels[:-1].masked_fill_(mask[1:].bool(), CROSS_ENTROPY_IGNORE_IDX) + + # return a shallow-copied mapping so the original sample stays intact + out = dict(sample) + out["labels"] = labels + return out def sft_iterable_dataset( From 96bc3172a0e437e2b574aa79ba4d32a3223b38ae Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 18:04:27 -0400 Subject: [PATCH 27/48] use ds.sampling_weight --- recipes/full_finetune_distributed.py | 3 -- tests/torchtune/datasets/test_hf_iterable.py | 13 +++++- tests/torchtune/datasets/test_interleaved.py | 48 +++++++++++--------- torchtune/datasets/_hf_iterable.py | 4 ++ torchtune/datasets/_interleaved.py | 20 ++++++-- torchtune/datasets/_iterable_base.py | 12 +++++ 6 files changed, 69 insertions(+), 31 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 4c41a81a5b..fadb3f7c23 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -775,7 +775,6 @@ def _setup_data( # 1. Create all datasets iterable_datasets = [] - weights = [] cfg_dataset_list = cfg_dataset if not isinstance(cfg_dataset_list, ListConfig): cfg_dataset_list = [cfg_dataset_list] @@ -783,13 +782,11 @@ def _setup_data( for ds_cfg in cfg_dataset_list: ds = config.instantiate(ds_cfg, model_transform=self._tokenizer) iterable_datasets.append(ds) - weights.append(ds_cfg.get("weight", 1.0)) # 2. 
Interleave datasets if any if len(iterable_datasets) > 1: ds = InterleavedDataset( datasets=iterable_datasets, - weights=weights, seed=self.seed, ) else: diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 901234af6f..94b55b87be 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -105,13 +105,18 @@ def test_default_dataset_name(self, small_dataset_file): # Should generate name from path and split assert dataset.dataset_name == "json_train" + # Test default sampling weight + assert dataset.sampling_weight == 1.0 + assert isinstance(dataset.sampling_weight, float) - # Test giving a name + # Test giving a name and custom weight + custom_weight = 2.5 dataset2 = HfIterableDataset( path="json", data_files=small_dataset_file, split="train", dataset_name="my_dataset", + weight=custom_weight, seed=SEED, metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, @@ -119,6 +124,8 @@ def test_default_dataset_name(self, small_dataset_file): # Should generate name from path and split assert dataset2.dataset_name == "my_dataset" + # Test custom sampling weight + assert dataset2.sampling_weight == custom_weight @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) def test_epoch_boundaries_and_checkpointing( @@ -198,9 +205,11 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] - # Shuffled epochs should have different order + # Extract IDs for comparison first_epoch_ids = [sample["id"] for sample in first_epoch_samples] second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + + # Shuffled epochs should have different order assert first_epoch_ids != list( range(SMALL_DATASET_SIZE) ), f"Shuffled should not be sorted, got {first_epoch_ids}" diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 98e9207047..38d92bc2d3 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -101,30 +101,35 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" # Test duplicate dataset names - ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate") - ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate") + ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) + ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) with pytest.raises(ValueError, match="Duplicate dataset names detected"): - InterleavedDataset(datasets=[ds1, ds2], weights=[0.5, 0.5], seed=SEED) + InterleavedDataset(datasets=[ds1, ds2], seed=SEED) # Test weight normalization (should work with warning) - ds3 = dataset_factory(small_dataset_file, dataset_name="ds3") - ds4 = dataset_factory(small_dataset_file, dataset_name="ds4") + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( datasets=[ds3, ds4], - weights=[0.5, 1.5], seed=SEED, dataset_name="test_interleaved", # Sum = 2.0 != 1.0 ) - # Check that weights were normalized - assert 
torch.allclose(interleaved._weights, torch.tensor([0.25, 0.75])) - mock_warning.assert_called_once() assert interleaved.dataset_name == "test_interleaved" + # Test sampling_weight property returns normalized weights + sampling_weights = interleaved.sampling_weight + assert isinstance(sampling_weights, dict) + assert "ds3" in sampling_weights + assert "ds4" in sampling_weights + assert abs(sampling_weights["ds3"] - 0.25) < 1e-6 + assert abs(sampling_weights["ds4"] - 0.75) < 1e-6 + assert abs(sum(sampling_weights.values()) - 1.0) < 1e-6 + def test_sampling_ratios( self, dataset_factory, small_dataset_file, medium_dataset_file ): @@ -132,12 +137,11 @@ def test_sampling_ratios( # Create two datasets with distinct ID ranges # ds1 has IDs 0-22 (small dataset) # ds2 has IDs 100-134 (medium dataset with offset) - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) # Test with 70/30 weighting - weights = [0.7, 0.3] - interleaved = InterleavedDataset([ds1, ds2], weights, seed=SEED) + interleaved = InterleavedDataset([ds1, ds2], seed=SEED) # Collect 300 samples sample_count = 300 @@ -162,10 +166,10 @@ def test_metrics_aggregation( self, dataset_factory, small_dataset_file, medium_dataset_file ): """Tests that metrics from all child datasets are collected and aggregated.""" - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.8) - interleaved = InterleavedDataset([ds1, ds2], [0.2, 0.8], seed=SEED) + interleaved = InterleavedDataset([ds1, ds2], seed=SEED) aggregator = MetricsAggregator() # Process some samples @@ -203,9 +207,9 @@ def test_checkpointing( """Tests that interleaved dataset checkpointing preserves sampling state.""" def create_interleaved(): - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") - return InterleavedDataset([ds1, ds2], [0.7, 0.3], seed=SEED) + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) + return InterleavedDataset([ds1, ds2], seed=SEED) # Original run interleaved1 = create_interleaved() @@ -291,6 +295,7 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, + weight=0.8, ) ds2 = HfIterableDataset( path="json", @@ -300,10 +305,11 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, + weight=0.2, ) - # Create interleaved dataset with 70/30 weighting - return InterleavedDataset([ds1, ds2], [0.8, 0.2], seed=SEED) + # Create interleaved dataset with 80/20 weighting + return InterleavedDataset([ds1, ds2], seed=SEED) def create_dataloader(dataset): loader = StatefulDataLoader( diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 7aac8adcc4..3949e93508 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -108,6 +108,10 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name + 
@property + def sampling_weight(self) -> float: + return self._weight + def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" if self._message_transform is not None: diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 0245d4e94e..71a8f72674 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -17,14 +17,13 @@ class InterleavedDataset(TuneIterableDataset): - """Infinitely interleaves multiple TuneIterableDatasets according to a list of weights. - - The weights are normalized to sum to 1.0. + """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. + - The weights are extracted from each dataset's sampling_weight property and normalized to sum to 1.0. - This dataset is responsible for managing the state of its child datasets to ensure correct checkpointing and resumption. Args: datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. - weights (list[float]): list of weights for each dataset. Must sum to 1.0. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". @@ -35,7 +34,6 @@ class InterleavedDataset(TuneIterableDataset): def __init__( self, datasets: list[TuneIterableDataset], - weights: list[float], seed: int, dataset_name: str = "interleaved_dataset", ): @@ -62,8 +60,16 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) + # Extract weights from datasets' sampling_weight property + weights = [] + for ds in datasets: + weight = ds.sampling_weight + if isinstance(weight, dict): + # For composite datasets, sum up their weights + weight = sum(weight.values()) + weights.append(weight) + # Normalize weights to sum to 1 - # TODO: make it a property? rely on ds.weight? total_weight = sum(weights) self._weights = torch.tensor( [w / total_weight for w in weights], dtype=torch.float @@ -78,6 +84,10 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name + @property + def sampling_weight(self) -> dict[str, float]: + return {name: weight.item() for name, weight in zip(self._dataset_names, self._weights)} + def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" child_iters = {name: iter(ds) for name, ds in self._datasets.items()} diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index f0821dc3f1..6630761f0d 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -24,6 +24,18 @@ def dataset_name(self) -> str: """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" pass + @property + @abstractmethod + def sampling_weight(self) -> float | dict[str, float]: + """ + Returns the sampling weight for this dataset when used in multi-dataset scenarios. + + For leaf datasets: returns a float representing the relative weight. + For composite datasets: returns a dict mapping child dataset names to their weights. + Used by interleaving logic to determine sampling probabilities. 
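A compact standalone restatement of what `InterleavedDataset.__init__` above does with these weights, shown as a free function so the contract of `sampling_weight` is easy to see in isolation (the function name is illustrative):

```python
import torch


def child_sampling_probs(datasets) -> torch.Tensor:
    """Normalize each child's sampling_weight into a sampling probability (illustrative)."""
    weights = []
    for ds in datasets:
        w = ds.sampling_weight
        if isinstance(w, dict):  # composite child: collapse its children's weights
            w = sum(w.values())
        weights.append(w)
    total = sum(weights)
    return torch.tensor([w / total for w in weights], dtype=torch.float)


# e.g. pick the next child with a seeded generator:
# probs = child_sampling_probs([ds1, ds2])
# gen = torch.Generator().manual_seed(42)
# ds_idx = torch.multinomial(probs, num_samples=1, generator=gen).item()
```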
+ """ + pass + @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: """ From 3c9d161629d55cee1a9e55b49df98734f985db2f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 18:31:02 -0400 Subject: [PATCH 28/48] add sampling log to interlead dataset --- tests/torchtune/datasets/test_interleaved.py | 19 +++++++++++++++++ torchtune/datasets/_interleaved.py | 22 +++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 38d92bc2d3..361714523c 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -243,6 +243,25 @@ def create_interleaved(): result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" + # Test sampling log functionality + # Check that sampling log contains tuples of (iteration_count, dataset_name) + state_dict = interleaved1.state_dict() + sampling_log = state_dict["sampling_log"] + iteration_count = state_dict["iteration_count"] + + assert len(sampling_log) > 0, "Sampling log should not be empty" + assert iteration_count > 0, "Iteration count should be positive" + + # Check sampling ratios match expected weights (70/30) + ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") + ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") + total_samples = ds1_count + ds2_count + + ds1_ratio = ds1_count / total_samples + ds2_ratio = ds2_count / total_samples + + assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" + assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" class TestDistributedInterleavedDataset(FSDPTest): @property diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 71a8f72674..13859dd2d0 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -7,6 +7,7 @@ import collections import logging import math +from collections import deque from typing import Any, Iterator import torch @@ -26,7 +27,8 @@ class InterleavedDataset(TuneIterableDataset): datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". - + sampling_log_maxlen (int): Maximum length of the sampling log. + Raises: ValueError: If duplicate dataset names are detected in the provided datasets. 
""" @@ -36,8 +38,10 @@ def __init__( datasets: list[TuneIterableDataset], seed: int, dataset_name: str = "interleaved_dataset", + sampling_log_maxlen: int = 10000, ): self._dataset_name = dataset_name + self._sampling_log_maxlen = sampling_log_maxlen # Preserve original order for weighted sampling self._dataset_names = [ds.dataset_name for ds in datasets] @@ -60,6 +64,10 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) + # Track sampling decisions for debugging and analysis + self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) + self._iteration_count = 0 + # Extract weights from datasets' sampling_weight property weights = [] for ds in datasets: @@ -101,6 +109,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Sample an index, then get the name for safe lookup ds_name = self._dataset_names[ds_idx] + # Log this sampling decision + self._sampling_log.append((self._iteration_count, ds_name)) + self._iteration_count += 1 + try: sample = next(child_iters[ds_name]) yield sample @@ -123,6 +135,8 @@ def state_dict(self) -> dict[str, Any]: return { "sampling_generator_state": self._sampling_generator.get_state(), "child_states": child_states, + "sampling_log": list(self._sampling_log), + "iteration_count": self._iteration_count, } def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -133,3 +147,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: if name in child_states: # Pass the raw state dict to the child ds.load_state_dict(child_states[name]) + + # Load sampling log and iteration count + self._sampling_log = deque( + state_dict.get("sampling_log", []), maxlen=self._sampling_log_maxlen + ) + self._iteration_count = state_dict.get("iteration_count", 0) From 4804663bf375baa9cd9bfe0eff83fa4831e0f9ac Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 22:32:09 -0400 Subject: [PATCH 29/48] fix nested interleave --- tests/torchtune/datasets/test_interleaved.py | 8 +++++++- torchtune/datasets/_interleaved.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 361714523c..4dedfca5ee 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -107,10 +107,16 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): with pytest.raises(ValueError, match="Duplicate dataset names detected"): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) - # Test weight normalization (should work with warning) + # Test nested interleaved datasets are rejected ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) + nested_interleaved = InterleavedDataset([ds3, ds4], seed=SEED, dataset_name="nested") + + with pytest.raises(ValueError, match="returned a dict for sampling_weight"): + # This should fail because nested_interleaved.sampling_weight returns a dict + InterleavedDataset([nested_interleaved, ds3], seed=SEED) + # Test weight normalization (should work with warning) with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( datasets=[ds3, ds4], diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 13859dd2d0..46c4cfb7dd 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -68,13 +68,16 @@ def __init__( 
self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) self._iteration_count = 0 - # Extract weights from datasets' sampling_weight property + # Extract weights from child datasets weights = [] for ds in datasets: weight = ds.sampling_weight if isinstance(weight, dict): - # For composite datasets, sum up their weights - weight = sum(weight.values()) + raise ValueError( + f"Child dataset '{ds.dataset_name}' returned a dict for sampling_weight, " + f"indicating it's a composite dataset (likely InterleavedDataset). " + f"Nested interleaving is not supported. Please flatten the dataset hierarchy." + ) weights.append(weight) # Normalize weights to sum to 1 From 2fe4b401bf6f98931f9cd5eb60780bfa8503a731 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 22:57:24 -0400 Subject: [PATCH 30/48] changes to TuneIterableDataset --- torchtune/datasets/_iterable_base.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 6630761f0d..dafa8d9fce 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -7,12 +7,8 @@ from abc import ABC, abstractmethod from typing import Any, Iterator -from torch.utils.data import IterableDataset - - -class TuneIterableDataset(IterableDataset, ABC): - """ - Abstract base class for all torchtune iterable datasets. +class TuneIterableDataset(ABC): + """Abstract base class for all torchtune iterable datasets. It defines the minimal, consistent interface required for all dataset implementations to ensure they are compatible with the training loop, checkpointing, and metric logging systems. @@ -25,22 +21,19 @@ def dataset_name(self) -> str: pass @property - @abstractmethod def sampling_weight(self) -> float | dict[str, float]: - """ - Returns the sampling weight for this dataset when used in multi-dataset scenarios. + """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. For leaf datasets: returns a float representing the relative weight. For composite datasets: returns a dict mapping child dataset names to their weights. - Used by interleaving logic to determine sampling probabilities. """ - pass + return 1.0 @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: - """ - Returns an infinite iterator over the dataset. Each implementation is responsible - for its own iteration logic, including shuffling and making it an infinite stream. + """Returns an infinite iterator over the dataset. Each implementation is responsible + for its own iteration logic, including shuffling, distribution of data across ranks, + and making it an infinite stream. """ pass From 72211c99cac9254c785ead30ce7bfddcf0229d60 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 20:01:47 -0700 Subject: [PATCH 31/48] add IterableDataset back --- torchtune/datasets/_iterable_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index dafa8d9fce..92290d658e 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -7,7 +7,10 @@ from abc import ABC, abstractmethod from typing import Any, Iterator -class TuneIterableDataset(ABC): +from torch.utils.data import IterableDataset + + +class TuneIterableDataset(IterableDataset, ABC): """Abstract base class for all torchtune iterable datasets. 
It defines the minimal, consistent interface required for all dataset implementations to ensure they are compatible with the training loop, @@ -23,7 +26,7 @@ def dataset_name(self) -> str: @property def sampling_weight(self) -> float | dict[str, float]: """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. - + For leaf datasets: returns a float representing the relative weight. For composite datasets: returns a dict mapping child dataset names to their weights. """ @@ -33,8 +36,7 @@ def sampling_weight(self) -> float | dict[str, float]: def __iter__(self) -> Iterator[dict[str, Any]]: """Returns an infinite iterator over the dataset. Each implementation is responsible for its own iteration logic, including shuffling, distribution of data across ranks, - and making it an infinite stream. - """ + and making it an infinite stream.""" pass @abstractmethod From b350ac7b5e13f9121bbf4d8397eebebc24cbcd8b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 21:59:44 -0400 Subject: [PATCH 32/48] nested interleaved + dataset.info --- tests/torchtune/datasets/test_hf_iterable.py | 11 +- tests/torchtune/datasets/test_interleaved.py | 294 +++++++++++++------ torchtune/datasets/__init__.py | 4 +- torchtune/datasets/_hf_iterable.py | 33 +-- torchtune/datasets/_interleaved.py | 141 ++++----- torchtune/datasets/_iterable_base.py | 47 ++- 6 files changed, 327 insertions(+), 203 deletions(-) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 94b55b87be..ec3ca26936 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -104,10 +104,9 @@ def test_default_dataset_name(self, small_dataset_file): ) # Should generate name from path and split - assert dataset.dataset_name == "json_train" + assert dataset.info.name == "json_train" # Test default sampling weight - assert dataset.sampling_weight == 1.0 - assert isinstance(dataset.sampling_weight, float) + assert dataset.info.weight == 1.0 # Test giving a name and custom weight custom_weight = 2.5 @@ -122,10 +121,10 @@ def test_default_dataset_name(self, small_dataset_file): num_shards_per_rank=4, ) - # Should generate name from path and split - assert dataset2.dataset_name == "my_dataset" + # Should use provided name and weight + assert dataset2.info.name == "my_dataset" # Test custom sampling weight - assert dataset2.sampling_weight == custom_weight + assert dataset2.info.weight == custom_weight @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) def test_epoch_boundaries_and_checkpointing( diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 4dedfca5ee..c2ea53c970 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -27,6 +27,7 @@ # Test Constants SMALL_DATASET_SIZE = 23 MEDIUM_DATASET_SIZE = 35 +LARGE_DATASET_SIZE = 47 SEED = 42 BATCH_SIZE = 5 @@ -70,6 +71,13 @@ def medium_dataset_file(tmp_data_dir): return str(path) +@pytest.fixture +def large_dataset_file(tmp_data_dir): + path = tmp_data_dir / "large_data.json" + create_test_json_file(path, LARGE_DATASET_SIZE, offset=1000) + return str(path) + + @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" @@ -100,122 +108,201 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors 
for invalid configurations, like duplicate names.""" - # Test duplicate dataset names - ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) - ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) + + # Test 1: Duplicate dataset names should raise an error + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + ds2 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - with pytest.raises(ValueError, match="Duplicate dataset names detected"): + with pytest.raises(ValueError, match="Duplicate dataset names found in hierarchy"): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) - # Test nested interleaved datasets are rejected - ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) - ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) - nested_interleaved = InterleavedDataset([ds3, ds4], seed=SEED, dataset_name="nested") + # Test 2: Nested interleaved datasets should be supported + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=1.5) + interleaved_child = InterleavedDataset([ds1, ds3], seed=SEED, dataset_name="interleaved_child") - with pytest.raises(ValueError, match="returned a dict for sampling_weight"): - # This should fail because nested_interleaved.sampling_weight returns a dict - InterleavedDataset([nested_interleaved, ds3], seed=SEED) - - # Test weight normalization (should work with warning) + # Create a parent interleaved dataset containing the nested one + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=0.5) + + # Test 3: Weight normalization should work with a warning with patch("logging.Logger.warning") as mock_warning: - interleaved = InterleavedDataset( - datasets=[ds3, ds4], - seed=SEED, - dataset_name="test_interleaved", # Sum = 2.0 != 1.0 - ) - - - assert interleaved.dataset_name == "test_interleaved" - - # Test sampling_weight property returns normalized weights - sampling_weights = interleaved.sampling_weight - assert isinstance(sampling_weights, dict) - assert "ds3" in sampling_weights - assert "ds4" in sampling_weights - assert abs(sampling_weights["ds3"] - 0.25) < 1e-6 - assert abs(sampling_weights["ds4"] - 0.75) < 1e-6 - assert abs(sum(sampling_weights.values()) - 1.0) < 1e-6 - + interleaved_parent = InterleavedDataset([interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent") + + # Verify that a warning was logged about weight normalization + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "normalized" in warning_message.lower() + + # Verify the hierarchical structure is correct + assert interleaved_parent.info.name == "interleaved_parent" + assert len(interleaved_parent.info.children) == 2 + assert interleaved_parent.info.children[0].name == "interleaved_child" + assert interleaved_parent.info.children[1].name == "ds4" + + # Verify the nested structure within the nested dataset + nested_info = interleaved_parent.info.children[0] + assert len(nested_info.children) == 2 + assert nested_info.children[0].name == "ds1" + assert nested_info.children[1].name == "ds3" + + # Verify that sampling weights are normalized to sum to 1.0 + # Access the internal normalized weights tensor + normalized_weights = interleaved_parent._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 2 + + # nested: 2.0/2.5 = 0.8, ds4: 0.5/2.5 = 0.2 + assert abs(normalized_weights[0].item() - 0.8) < 1e-6 + assert 
abs(normalized_weights[1].item() - 0.2) < 1e-6 + assert abs(normalized_weights.sum().item() - 1.0) < 1e-6 + + # Verify that original weights in info remain unnormalized + child_weights = [child.weight for child in interleaved_parent.info.children] + assert abs(child_weights[0] - 2.0) < 1e-6 # nested original weight + assert abs(child_weights[1] - 0.5) < 1e-6 # ds4 original weight + + + def test_single_dataset(self, dataset_factory, small_dataset_file): + """Tests that InterleavedDataset works correctly with a single dataset.""" + # Create a single dataset + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + + # Should work without issues + interleaved = InterleavedDataset([ds1], seed=SEED) + + # Verify the hierarchical structure + assert interleaved.info.name == "interleaved_dataset" # default name + assert len(interleaved.info.children) == 1 + assert interleaved.info.children[0].name == "ds1" + assert interleaved.info.children[0].weight == 0.5 + + # Verify normalized weights sum to 1.0 (single dataset gets weight 1.0) + normalized_weights = interleaved._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 1 + assert abs(normalized_weights[0].item() - 1.0) < 1e-6 + + # Test that iteration works correctly + samples = list(islice(iter(interleaved), 10)) + assert len(samples) == 10 + + # All samples should come from the single dataset + sample_ids = {sample["id"] for sample in samples} + expected_ids = set(range(23)) # ds1 has IDs 0-22 + assert sample_ids == expected_ids + def test_sampling_ratios( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that datasets are sampled according to their assigned weights.""" - # Create two datasets with distinct ID ranges - # ds1 has IDs 0-22 (small dataset) - # ds2 has IDs 100-134 (medium dataset with offset) - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) + """Tests that datasets are sampled according to their assigned weights in nested structure.""" + # Create three datasets with distinct ID ranges + # ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - # Test with 70/30 weighting - interleaved = InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") - # Collect 300 samples - sample_count = 300 - samples = list(islice(iter(interleaved), sample_count)) + # Collect 400 samples + sample_count = 400 + samples = list(islice(iter(parent_interleaved), sample_count)) # Count samples by checking ID ranges - # ds1 has IDs < 100, ds2 has IDs >= 100 - ds1_count = sum(1 for s in samples if s["id"] < 100) - ds2_count = sum(1 for s in samples if s["id"] >= 100) + ds1_count = sum(1 for s in samples if 0 <= s["id"] < SMALL_DATASET_SIZE) + ds2_count = sum(1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100)) + ds3_count = sum(1 for s in samples if 1000 <= s["id"] < 
(LARGE_DATASET_SIZE + 1000)) - assert ds1_count + ds2_count == sample_count + assert ds1_count + ds2_count + ds3_count == sample_count - # Check ratios are approximately correct + # Calculate ratios ds1_ratio = ds1_count / sample_count ds2_ratio = ds2_count / sample_count + ds3_ratio = ds3_count / sample_count + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" - assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_metrics_aggregation( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that metrics from all child datasets are collected and aggregated.""" + """Tests that metrics from all child datasets are collected and aggregated in nested structure.""" ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.8) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - interleaved = InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + aggregator = MetricsAggregator() # Process some samples - total_samples = 200 - for sample in islice(iter(interleaved), total_samples): + total_samples = 300 + for sample in islice(iter(parent_interleaved), total_samples): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging(prefix="train") - # Should have metrics from both datasets, with flat keys + # Should have metrics from all three datasets, with flat keys assert "train_ds1/samples_seen" in metrics assert "train_ds2/samples_seen" in metrics + assert "train_ds3/samples_seen" in metrics - # Both datasets should have contributed samples + # All datasets should have contributed samples assert metrics["train_ds1/samples_seen"] > 0 assert metrics["train_ds2/samples_seen"] > 0 + assert metrics["train_ds3/samples_seen"] > 0 # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] + metrics["train_ds2/samples_seen"] + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] ) assert calculated_total_samples == total_samples - # Test that ratio is approximately correct + # Test that ratios are approximately correct based on nested weighting ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples + ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples + + # Expected ratios based on nested weighting: 
+ # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.2=0.1, ds2=0.5*0.8=0.4, ds3=0.5 + expected_ds1_ratio = 0.1 + expected_ds2_ratio = 0.4 + expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" - assert abs(ds2_ratio - 0.8) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.8" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_checkpointing( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that interleaved dataset checkpointing preserves sampling state.""" + """Tests that interleaved dataset checkpointing preserves sampling state in nested structure.""" def create_interleaved(): - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) - return InterleavedDataset([ds1, ds2], seed=SEED) + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") # Original run interleaved1 = create_interleaved() @@ -258,16 +345,27 @@ def create_interleaved(): assert len(sampling_log) > 0, "Sampling log should not be empty" assert iteration_count > 0, "Iteration count should be positive" - # Check sampling ratios match expected weights (70/30) + # Check sampling ratios match expected weights for nested structure ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") - total_samples = ds1_count + ds2_count + ds3_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds3") + total_samples = ds1_count + ds2_count + ds3_count ds1_ratio = ds1_count / total_samples ds2_ratio = ds2_count / total_samples + ds3_ratio = ds3_count / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 - assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" - assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" class TestDistributedInterleavedDataset(FSDPTest): @property @@ -277,10 +375,10 @@ def 
world_size(self) -> int: @gpu_test(gpu_count=2) def test_distributed_interleaved_checkpointing(self): """ - Test interleaved dataset checkpointing with distributed settings. + Test interleaved dataset checkpointing with distributed settings using nested structure. Assertions: - Each rank processes non-overlapping data shards - - Sampling ratios (70/30) are maintained across ranks + - Sampling ratios for nested structure (ds1: 15%, ds2: 35%, ds3: 50%) are maintained across ranks - Checkpoint/resume produces identical batches (deterministic) - Metrics correctly aggregate across ranks """ @@ -303,6 +401,7 @@ def test_distributed_interleaved_checkpointing(self): def create_dataset(): file1 = tmp_path / "ds1.json" file2 = tmp_path / "ds2.json" + file3 = tmp_path / "ds3.json" # Only rank 0 creates the data files if rank == 0: @@ -310,6 +409,9 @@ def create_dataset(): create_test_json_file( file2, MEDIUM_DATASET_SIZE, offset=100 ) # IDs 100-134 + create_test_json_file( + file3, LARGE_DATASET_SIZE, offset=1000 + ) # IDs 1000-1046 dist.barrier() # Wait for file creation ds1 = HfIterableDataset( @@ -320,7 +422,7 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, - weight=0.8, + weight=0.3, ) ds2 = HfIterableDataset( path="json", @@ -330,11 +432,22 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, - weight=0.2, + weight=0.7, + ) + ds3 = HfIterableDataset( + path="json", + data_files=str(file3), + split="train", + dataset_name="ds3", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + weight=1.0, ) - # Create interleaved dataset with 80/20 weighting - return InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") def create_dataloader(dataset): loader = StatefulDataLoader( @@ -371,24 +484,35 @@ def create_dataloader(dataset): result["final_metrics"] == result["resumed_metrics"] ), "Final metrics don't match resumed metrics - aggregator state issue" - # Verify sampling ratio is approximately maintained (80/20 split) + # Verify sampling ratio is approximately maintained for nested structure all_ids = [] for batch in ( result["pre_checkpoint_batches"] + result["post_checkpoint_batches"] ): all_ids.extend(batch["id"].tolist()) - # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 - ds1_samples = sum(1 for id in all_ids if id < 100) - ds2_samples = sum(1 for id in all_ids if id >= 100) - total_samples = ds1_samples + ds2_samples + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_samples = sum(1 for id in all_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_samples = sum(1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100)) + ds3_samples = sum(1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000)) + total_samples = ds1_samples + ds2_samples + ds3_samples if total_samples > 0: ds1_ratio = ds1_samples / total_samples - assert 0.6 < ds1_ratio < 1.0, ( - f"Rank {rank}: Dataset sampling ratio {ds1_ratio:.2f} outside expected " - f"range for 80/20 split. Got {ds1_samples}, {ds2_samples} samples." 
- ) + ds2_ratio = ds2_samples / total_samples + ds3_ratio = ds3_samples / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 + + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" finally: # Clean up temp directory (only rank 0) diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index f5ecbb95ea..e0afccdde4 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -18,7 +18,7 @@ from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._interleaved import InterleavedDataset -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset, TuneIterableDataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset @@ -38,6 +38,7 @@ "chat_dataset", "cnn_dailymail_articles_dataset", "ConcatDataset", + "DatasetInfo", "grammar_dataset", "hh_rlhf_helpful_dataset", "HfIterableDataset", @@ -55,6 +56,7 @@ "stack_exchange_paired_dataset", "text_completion_dataset", "TextCompletionDataset", + "InfiniteTuneIterableDataset", "TuneIterableDataset", "wikitext_dataset", ] diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 3949e93508..8856f04699 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -17,12 +17,12 @@ DefaultTrainingMetricTransform, Metric, ) -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset logger = logging.getLogger(__name__) -class HfIterableDataset(TuneIterableDataset): +class HfIterableDataset(InfiniteTuneIterableDataset): """HuggingFace dataset implementation with composable metrics. This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. @@ -46,6 +46,7 @@ class HfIterableDataset(TuneIterableDataset): of world_size * dataloader_workers. dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated from the path, source, and split. + weight (Optional[float]): Weight for this dataset. Defaults to 1.0. filter_fn (Optional[Callable]): Filter function to apply to the dataset. filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. @@ -74,12 +75,12 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight # TODO: make it a property? 
+ self._weight = weight # Create default transform if not provided self._metric_transform = metric_transform or DefaultTrainingMetricTransform() - # Auto-generate dataset name if not provided, ensuring it's always a string. + # Auto-generate dataset name if not provided, ensuring it's always a string if dataset_name is None: path = load_dataset_kwargs.get("path", None) source = load_dataset_kwargs.get("source", None) @@ -88,13 +89,14 @@ def __init__( for item in [path, source, split]: if item is not None: name_parts.append(str(item).replace("/", "_")) - self._dataset_name: str = "_".join(name_parts) - else: - self._dataset_name: str = dataset_name + dataset_name = "_".join(name_parts) + + # Build the hierarchical info object for this dataset + self._info = DatasetInfo(name=dataset_name, weight=weight) # Set dataset name on the transform if it supports it if hasattr(self._metric_transform, "set_dataset_name"): - self._metric_transform.set_dataset_name(self._dataset_name) + self._metric_transform.set_dataset_name(dataset_name) # Internal state for resumption self._num_epochs = 0 @@ -105,12 +107,9 @@ def __init__( ) @property - def dataset_name(self) -> str: - return self._dataset_name - - @property - def sampling_weight(self) -> float: - return self._weight + def info(self) -> DatasetInfo: + """Returns info for this leaf dataset, which has no children.""" + return self._info def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" @@ -227,7 +226,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # especially useful when interleaving multiple datasets, but # also necessary to track dataset-level metrics. metric_num_epochs = Metric( - dataset_name=self.dataset_name, + dataset_name=self.info.name, name="num_epochs", value=self._num_epochs, agg_type=AggregationType.MAX, @@ -243,14 +242,14 @@ def __iter__(self) -> Iterator[dict[str, Any]]: pass # Iterator is exhausted, which is expected. except Exception as e: logger.warning( - f"Dataset {self.dataset_name} encountered an unexpected error: {e}." + f"Dataset {self.info.name} encountered an unexpected error: {e}." ) raise # Check if we got zero samples - this might indicate an issue if samples_yielded == 0: logger.warning( - f"Dataset {self.dataset_name} epoch {self._num_epochs} yielded 0 samples - potential issue!" + f"Dataset {self.info.name} epoch {self._num_epochs} yielded 0 samples - potential issue!" ) # Epoch complete - increment and continue infinite loop diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 46c4cfb7dd..2dee36e557 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -4,137 +4,111 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections +from collections import deque import logging import math -from collections import deque from typing import Any, Iterator import torch -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import ( + DatasetInfo, + InfiniteTuneIterableDataset, +) logger = logging.getLogger(__name__) -class InterleavedDataset(TuneIterableDataset): +class InterleavedDataset(InfiniteTuneIterableDataset): """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. - - The weights are extracted from each dataset's sampling_weight property and normalized to sum to 1.0. 
+ - The weights are extracted from each dataset's info.weight property and normalized to sum to 1.0. - This dataset is responsible for managing the state of its child datasets to ensure correct checkpointing and resumption. Args: - datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. + datasets (list[InfiniteTuneIterableDataset]): list of datasets to interleave. seed (int): Seed for sampling. - dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + weight (float): Weight for this dataset. Defaults to 1.0. + dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". sampling_log_maxlen (int): Maximum length of the sampling log. Raises: - ValueError: If duplicate dataset names are detected in the provided datasets. + ValueError: If duplicate dataset names are detected in the hierarchy. """ def __init__( self, - datasets: list[TuneIterableDataset], + datasets: list[InfiniteTuneIterableDataset], seed: int, + weight: float = 1.0, dataset_name: str = "interleaved_dataset", sampling_log_maxlen: int = 10000, ): - self._dataset_name = dataset_name + self._datasets = sorted(datasets, key=lambda ds: ds.info.name) self._sampling_log_maxlen = sampling_log_maxlen - # Preserve original order for weighted sampling - self._dataset_names = [ds.dataset_name for ds in datasets] - - # Create a name-to-dataset mapping for robust state management - self._datasets: dict[str, TuneIterableDataset] = { - ds.dataset_name: ds for ds in datasets - } - - # Validate unique dataset names upfront - fail fast with clear error - names = self._dataset_names - if len(names) != len(set(names)): - duplicates = [ - name for name, count in collections.Counter(names).items() if count > 1 - ] - raise ValueError( - f"Duplicate dataset names detected: {duplicates}. All {names=}" - f"Please provide a unique 'dataset_name' for each dataset in the interleaved list." - ) - - self._sampling_generator = torch.Generator().manual_seed(seed) + # Build the hierarchical info object for this dataset + self._info = DatasetInfo( + name=dataset_name, + weight=weight, + children=tuple(ds.info for ds in self._datasets), + ) - # Track sampling decisions for debugging and analysis - self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) - self._iteration_count = 0 + # Validate the entire hierarchy using the base class method + self._validate_unique_dataset_names() - # Extract weights from child datasets - weights = [] - for ds in datasets: - weight = ds.sampling_weight - if isinstance(weight, dict): - raise ValueError( - f"Child dataset '{ds.dataset_name}' returned a dict for sampling_weight, " - f"indicating it's a composite dataset (likely InterleavedDataset). " - f"Nested interleaving is not supported. Please flatten the dataset hierarchy." - ) - weights.append(weight) - - # Normalize weights to sum to 1 - total_weight = sum(weights) - self._weights = torch.tensor( - [w / total_weight for w in weights], dtype=torch.float - ) + # Extract weights from direct children and normalize them + child_weights = [info.weight for info in self._info.children] + total_weight = sum(child_weights) if not math.isclose(total_weight, 1.0, rel_tol=1e-9): logger.warning( f"Interleaved dataset normalized weights to sum to 1.0. " - f"Found {total_weight=}. 
Previous {weights=}, new {self._weights.tolist()}" + f"Previous weights={child_weights}, " + f"new weights={[w / total_weight for w in child_weights]}" ) + self._normalized_weights = torch.tensor( + [w / total_weight for w in child_weights], dtype=torch.float + ) + + # Track sampling decisions for debugging and analysis + self._sampling_log: deque[tuple[int, str]] = deque( + maxlen=self._sampling_log_maxlen + ) + self._iteration_count = 0 + self._sampling_generator = torch.Generator().manual_seed(seed) @property - def dataset_name(self) -> str: - return self._dataset_name - - @property - def sampling_weight(self) -> dict[str, float]: - return {name: weight.item() for name, weight in zip(self._dataset_names, self._weights)} + def info(self) -> DatasetInfo: + return self._info def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" - child_iters = {name: iter(ds) for name, ds in self._datasets.items()} + # Create a dictionary of iterators for each child dataset + child_iters = {ds.info.name: iter(ds) for ds in self._datasets} while True: - # Sample which dataset to use + # Sample a child dataset based on the normalized weights ds_idx = torch.multinomial( - self._weights, 1, replacement=True, generator=self._sampling_generator + self._normalized_weights, + 1, + replacement=True, + generator=self._sampling_generator, ).item() - # Sample an index, then get the name for safe lookup - ds_name = self._dataset_names[ds_idx] + selected_ds = self._datasets[ds_idx] + ds_name = selected_ds.info.name - # Log this sampling decision + # Log self._sampling_log.append((self._iteration_count, ds_name)) self._iteration_count += 1 - try: - sample = next(child_iters[ds_name]) - yield sample - except StopIteration: - # Per the design, child datasets must be infinite. - # We re-initialize to allow for continuous operation but warn loudly - # as this may indicate a design problem in the child dataset. - logger.warning( - f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. " - "This is unexpected for an infinite dataset. Re-initializing its iterator." - ) - child_iters[ds_name] = iter(self._datasets[ds_name]) - sample = next(child_iters[ds_name]) - yield sample + # Yield the next sample from the selected child iterator + yield next(child_iters[ds_name]) def state_dict(self) -> dict[str, Any]: """Save state for the interleaver and its children.""" - # The parent is responsible for namespacing the child states. 
- child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} + # The parent is responsible for namespacing the child states + child_states = {ds.info.name: ds.state_dict() for ds in self._datasets} return { "sampling_generator_state": self._sampling_generator.get_state(), "child_states": child_states, @@ -146,11 +120,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state for the interleaver and its children.""" self._sampling_generator.set_state(state_dict["sampling_generator_state"]) child_states = state_dict["child_states"] - for name, ds in self._datasets.items(): - if name in child_states: - # Pass the raw state dict to the child - ds.load_state_dict(child_states[name]) - + + for ds in self._datasets: + ds.load_state_dict(child_states[ds.info.name]) + # Load sampling log and iteration count self._sampling_log = deque( state_dict.get("sampling_log", []), maxlen=self._sampling_log_maxlen diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 92290d658e..51dec07990 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -5,11 +5,23 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Iterator from torch.utils.data import IterableDataset +@dataclass(frozen=True) +class DatasetInfo: + """Represents hierarchical information about a dataset, including its name, + sampling weight and children. Children is a common case when composing datasets, + e.g. Packed(InterleavedDataset([ds1, ds2])). + """ + name: str + weight: float = 1.0 + children: tuple["DatasetInfo", ...] = field(default_factory=tuple) + + class TuneIterableDataset(IterableDataset, ABC): """Abstract base class for all torchtune iterable datasets. It defines the minimal, consistent interface required for all dataset @@ -19,22 +31,32 @@ class TuneIterableDataset(IterableDataset, ABC): @property @abstractmethod - def dataset_name(self) -> str: - """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" + def info(self) -> DatasetInfo: + """Returns a hierarchical structure of all dataset information, including + this dataset and its children.""" pass - @property - def sampling_weight(self) -> float | dict[str, float]: - """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. + def _validate_unique_dataset_names(self) -> None: + """Traverses the DatasetInfo tree and raises ValueError on duplicate names.""" + root_info = self.info + names = [] + to_process = [root_info] + + while to_process: + node = to_process.pop(0) + names.append(node.name) + to_process.extend(node.children) - For leaf datasets: returns a float representing the relative weight. - For composite datasets: returns a dict mapping child dataset names to their weights. - """ - return 1.0 + # Check for duplicates after traversing the whole tree + duplicates = [name for name in set(names) if names.count(name) > 1] + if duplicates: + raise ValueError( + f"Duplicate dataset names found in hierarchy: {duplicates=}, all names={names}" + ) @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: - """Returns an infinite iterator over the dataset. Each implementation is responsible + """Returns an iterator over the dataset. 
Each implementation is responsible for its own iteration logic, including shuffling, distribution of data across ranks, and making it an infinite stream.""" pass @@ -48,3 +70,8 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" pass + +class InfiniteTuneIterableDataset(TuneIterableDataset): + """Abstract base class for infinite datasets, which yield samples indefinitely. + It only purpose is to make it explicit that the dataset is expected to be infinite.""" + pass \ No newline at end of file From f9a1aecee41abf4e3fe2f7c399442885e30a9920 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:13:40 -0400 Subject: [PATCH 33/48] nits hf_iterable --- torchtune/datasets/_hf_iterable.py | 29 ++++++++++++++++------------ torchtune/datasets/_iterable_base.py | 5 ++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 8856f04699..5c759039d5 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -145,13 +145,22 @@ def _setup_hf_dataset( # Load and shard dataset ds = load_dataset(**load_dataset_kwargs) - # Use to_iterable_dataset for streaming datasets - if not load_dataset_kwargs.get("streaming", False): - + # Use to_iterable_dataset for non-streaming datasets + is_streaming = load_dataset_kwargs.get("streaming", False) + if is_streaming: + logger.warning( + f"Streaming datasets were not yet tested for distributed training. " + f"split_dataset_by_node is applied, but no resharding was done manually. " + f"Dataset '{self.info.name}' has " + f"{getattr(ds, 'num_shards', 'unknown')}, and your training has {world_size} ranks." + f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard" + f"Consider setting streaming=False, which should also be faster." + ) + if not is_streaming: # Define number of shards based on (world_size, num of shards per GPU, dataloader workers) # E.g. world_size=2, num_shards_per_rank=16, dataloader_workers=3 - # we will try 2*16 = 32 shards. Since 32 is not a multiple of 3, we will do 36 shards. - # Each rank gets 16 shards, each dataloader worker in that rankgets 6 shards. + # we will try 2*16 = 32 shards. Since 32 is not a multiple of 6, we will do 36 shards. + # Each rank gets 18 shards, each dataloader worker in that rank gets 6 shards. worker_info = torch.utils.data.get_worker_info() num_dataloader_workers = worker_info.num_workers if worker_info else 1 @@ -171,14 +180,12 @@ def _setup_hf_dataset( # If the dataset is not streaming and has a defined length, # we cannot have num_shards > dataset_size. - if not load_dataset_kwargs.get("streaming", False) and hasattr( - ds, "__len__" - ): + if hasattr(ds, "__len__"): dataset_size = len(ds) if num_shards > dataset_size: raise ValueError( f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." - f"Please decrease num_shards_per_rank." + f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}." 
) ds = ds.to_iterable_dataset(num_shards=num_shards) @@ -210,8 +217,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: """ while True: # Infinite iteration - epoch_seed = self._seed + self._num_epochs - self._ds.set_epoch(epoch_seed) + self._ds.set_epoch(self._num_epochs) epoch_iterator = iter(self._ds) samples_yielded = 0 @@ -262,7 +268,6 @@ def state_dict(self) -> dict[str, Any]: hf_state = self._ds.state_dict() state = { "num_epochs": self._num_epochs, - "seed": self._seed, "hf_dataset_state": hf_state, } return state diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 51dec07990..fee09a5123 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -71,7 +71,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" pass + class InfiniteTuneIterableDataset(TuneIterableDataset): """Abstract base class for infinite datasets, which yield samples indefinitely. - It only purpose is to make it explicit that the dataset is expected to be infinite.""" + It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. + it never exhausts. This is helpful to avoid complexity due to some rank hanging because + of lack of data""" pass \ No newline at end of file From f7a3aa76453af770bbb48c30e7092816a7059054 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:14:40 -0400 Subject: [PATCH 34/48] update readme --- torchtune/data/metrics/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md index 6c6c413246..bf6cb8d27b 100644 --- a/torchtune/data/metrics/readme.md +++ b/torchtune/data/metrics/readme.md @@ -24,7 +24,7 @@ The metrics module provides a robust system for tracking and aggregating trainin │ • Uses pluggable AggregationHandlers │ │ • Handles distributed reduction │ └─────────────────────┬──────────────────────────────┘ - │ {prefix_dataset/metric: value} + │ {prefix}_{dataset_name}/{metric_name} # prefix is "train", "val", etc. ┌─────────────────────▼──────────────────────────────┐ │ Logging System │ │ • W&B, TensorBoard, etc. │ From 17878bf9f498434c9af7469b043a729ef8d9a9af Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:23:31 -0400 Subject: [PATCH 35/48] make metric dataset name explicit --- torchtune/data/metrics/_metric_transform.py | 31 +++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 529fff8c5c..5affaca046 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -6,8 +6,7 @@ from dataclasses import dataclass from enum import Enum -from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union from torchtune.modules.transforms import Transform @@ -41,7 +40,6 @@ class MetricTransform(Transform): def __init__(self): # dataset_name is set by the dataset using set_dataset_name self.dataset_name: Optional[str] = None - self.new_metric: Optional[Callable] = None def set_dataset_name(self, dataset_name: str) -> None: """Called by dataset to set the namespace for metrics. 
@@ -53,8 +51,6 @@ def set_dataset_name(self, dataset_name: str) -> None: dataset_name (str): Name of the dataset for metric namespacing """ self.dataset_name = dataset_name - # Create a partial to make it easier to create new metrics - self.new_metric = partial(Metric, dataset_name=dataset_name) def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. @@ -72,7 +68,7 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transform to sample, adding generated metrics.""" - if self.dataset_name is None or self.new_metric is None: + if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." ) @@ -112,7 +108,7 @@ class DefaultTrainingMetricTransform(MetricTransform): """ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - if self.new_metric is None: + if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." ) @@ -123,11 +119,22 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: # Create metrics for this sample return [ - self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), - self.new_metric( - name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + Metric( + dataset_name=self.dataset_name, + name="samples_seen", + value=1, + agg_type=AggregationType.SUM, ), - self.new_metric( - name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + Metric( + dataset_name=self.dataset_name, + name="tokens_seen", + value=token_len, + agg_type=AggregationType.SUM, + ), + Metric( + dataset_name=self.dataset_name, + name="seq_len", + value=token_len, + agg_type=AggregationType.DISTRIBUTION, ), ] From 101e96e205db8ca56cb72c654a1aef316a2e3758 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:29:07 -0400 Subject: [PATCH 36/48] update recipe to share log freq + validagtion msg --- recipes/full_finetune_distributed.py | 56 +++++++++++++--------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index fadb3f7c23..95d5069eaa 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -277,7 +277,6 @@ def __init__(self, cfg: DictConfig) -> None: # Step-based training support self.num_training_steps = cfg.num_training_steps - self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) self._metrics_aggregator = None # Will be initialized in setup def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: @@ -311,7 +310,7 @@ def setup(self, cfg: DictConfig) -> None: """ if cfg.get("dataset_val") is not None: raise NotImplementedError( - "Validation is not supported yet with iterable datasets." + "Validation is not supported yet with iterable datasets since it currently requiresinfinite datasets." 
) if self.fsdp_cpu_offload: @@ -1045,39 +1044,34 @@ def train(self) -> None: ) # Log per-step metrics - if ( - self.global_step % self._log_every_n_steps == 0 - and self._is_rank_zero - ): - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": get_lr( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - "tokens_per_second_per_gpu": ( - num_tokens / self.parallel_dims.non_data_parallel_size - ) - / (time_per_step * self.world_size), - } - if self._log_peak_memory_stats: - log_dict.update(training.get_memory_stats(device=self._device)) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict(log_dict, step=self.global_step) - - # Log dataset metrics - # #TODO: it requires all_gather. Should we keep a separate log_freq for this? - if self.global_step % self._dataset_metrics_log_freq == 0: + if self.global_step % self._log_every_n_steps == 0: + # Get dataset metrics outside of rank zero check since it involves all_gather dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) + if self._is_rank_zero: - self._metric_logger.log_dict( - dataset_metrics, step=self.global_step - ) + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + "tokens_per_second_per_gpu": ( + num_tokens / self.parallel_dims.non_data_parallel_size + ) + / (time_per_step * self.world_size), + } + if dataset_metrics: + log_dict.update(dataset_metrics) + if self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict(log_dict, step=self.global_step) + # Save checkpoint if specified by user if ( From 1b3f3fcc4f37f0746c68d57cc9a99cd27a64daa1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 07:44:15 -0700 Subject: [PATCH 37/48] update interleaved tests to do nesting --- tests/torchtune/datasets/test_interleaved.py | 229 +++++++++++++------ 1 file changed, 158 insertions(+), 71 deletions(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index c2ea53c970..96c825d2f1 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -108,38 +108,46 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" - + # Test 1: Duplicate dataset names should raise an error ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) ds2 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - with pytest.raises(ValueError, match="Duplicate dataset names found in hierarchy"): + with pytest.raises( + ValueError, match="Duplicate dataset names found in hierarchy" + ): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) # Test 2: Nested interleaved datasets should be supported ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=1.5) - interleaved_child = InterleavedDataset([ds1, ds3], seed=SEED, dataset_name="interleaved_child") - + interleaved_child = InterleavedDataset( + [ds1, ds3], seed=SEED, dataset_name="interleaved_child" + ) + # Create a parent interleaved dataset containing the nested one ds4 = 
dataset_factory(small_dataset_file, dataset_name="ds4", weight=0.5) - + # Test 3: Weight normalization should work with a warning with patch("logging.Logger.warning") as mock_warning: - interleaved_parent = InterleavedDataset([interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent") - + interleaved_parent = InterleavedDataset( + [interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent" + ) + # Verify that a warning was logged about weight normalization mock_warning.assert_called_once() warning_message = mock_warning.call_args[0][0] assert "normalized" in warning_message.lower() - + # Verify the hierarchical structure is correct assert interleaved_parent.info.name == "interleaved_parent" assert len(interleaved_parent.info.children) == 2 - assert interleaved_parent.info.children[0].name == "interleaved_child" - assert interleaved_parent.info.children[1].name == "ds4" - + # Datasets are sorted alphabetically, so ds4 comes before interleaved_child + assert interleaved_parent.info.children[0].name == "ds4" + assert interleaved_parent.info.children[1].name == "interleaved_child" + # Verify the nested structure within the nested dataset - nested_info = interleaved_parent.info.children[0] + # interleaved_child is at index 1 due to alphabetical sorting + nested_info = interleaved_parent.info.children[1] assert len(nested_info.children) == 2 assert nested_info.children[0].name == "ds1" assert nested_info.children[1].name == "ds3" @@ -149,49 +157,54 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): normalized_weights = interleaved_parent._normalized_weights assert isinstance(normalized_weights, torch.Tensor) assert len(normalized_weights) == 2 - - # nested: 2.0/2.5 = 0.8, ds4: 0.5/2.5 = 0.2 - assert abs(normalized_weights[0].item() - 0.8) < 1e-6 - assert abs(normalized_weights[1].item() - 0.2) < 1e-6 + + # ds4: 0.5/(0.5+1.0) = 1/3, interleaved_child: 1.0/(0.5+1.0) = 2/3 + assert abs(normalized_weights[0].item() - 1 / 3) < 1e-3 + assert abs(normalized_weights[1].item() - 2 / 3) < 1e-3 assert abs(normalized_weights.sum().item() - 1.0) < 1e-6 # Verify that original weights in info remain unnormalized child_weights = [child.weight for child in interleaved_parent.info.children] - assert abs(child_weights[0] - 2.0) < 1e-6 # nested original weight - assert abs(child_weights[1] - 0.5) < 1e-6 # ds4 original weight - + assert abs(child_weights[0] - 0.5) < 1e-6 # ds4 original weight + assert ( + abs(child_weights[1] - 1.0) < 1e-6 + ) # interleaved_child original weight def test_single_dataset(self, dataset_factory, small_dataset_file): """Tests that InterleavedDataset works correctly with a single dataset.""" # Create a single dataset ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - + # Should work without issues interleaved = InterleavedDataset([ds1], seed=SEED) - + # Verify the hierarchical structure assert interleaved.info.name == "interleaved_dataset" # default name assert len(interleaved.info.children) == 1 assert interleaved.info.children[0].name == "ds1" assert interleaved.info.children[0].weight == 0.5 - + # Verify normalized weights sum to 1.0 (single dataset gets weight 1.0) normalized_weights = interleaved._normalized_weights assert isinstance(normalized_weights, torch.Tensor) assert len(normalized_weights) == 1 assert abs(normalized_weights[0].item() - 1.0) < 1e-6 - + # Test that iteration works correctly samples = list(islice(iter(interleaved), 10)) assert len(samples) == 10 - - # All samples should come from the 
single dataset + + # All samples should come from the single dataset (ds1 has IDs 0-22) sample_ids = {sample["id"] for sample in samples} - expected_ids = set(range(23)) # ds1 has IDs 0-22 + expected_ids = set(range(10)) # ds1 has IDs 0-22 assert sample_ids == expected_ids - + def test_sampling_ratios( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that datasets are sampled according to their assigned weights in nested structure.""" # Create three datasets with distinct ID ranges @@ -201,8 +214,12 @@ def test_sampling_ratios( ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) # Collect 400 samples sample_count = 400 @@ -210,8 +227,12 @@ def test_sampling_ratios( # Count samples by checking ID ranges ds1_count = sum(1 for s in samples if 0 <= s["id"] < SMALL_DATASET_SIZE) - ds2_count = sum(1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100)) - ds3_count = sum(1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000)) + ds2_count = sum( + 1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000) + ) assert ds1_count + ds2_count + ds3_count == sample_count @@ -219,7 +240,7 @@ def test_sampling_ratios( ds1_ratio = ds1_count / sample_count ds2_ratio = ds2_count / sample_count ds3_ratio = ds3_count / sample_count - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -229,12 +250,22 @@ def test_sampling_ratios( expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_metrics_aggregation( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that metrics from all child datasets are collected and aggregated in nested structure.""" ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) @@ -242,9 +273,13 @@ def test_metrics_aggregation( ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - 
child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") - + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + aggregator = MetricsAggregator() # Process some samples @@ -266,9 +301,9 @@ def test_metrics_aggregation( # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] + - metrics["train_ds2/samples_seen"] + - metrics["train_ds3/samples_seen"] + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] ) assert calculated_total_samples == total_samples @@ -286,12 +321,22 @@ def test_metrics_aggregation( expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_checkpointing( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that interleaved dataset checkpointing preserves sampling state in nested structure.""" @@ -299,10 +344,14 @@ def create_interleaved(): ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) # Original run interleaved1 = create_interleaved() @@ -341,20 +390,36 @@ def create_interleaved(): state_dict = interleaved1.state_dict() sampling_log = state_dict["sampling_log"] iteration_count = state_dict["iteration_count"] - + assert len(sampling_log) > 0, "Sampling log should not be empty" assert iteration_count > 0, "Iteration count should be positive" - - # Check sampling ratios match expected weights for nested structure - ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") - ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") - ds3_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds3") + + # Check sampling ratios by analyzing the actual samples processed during the test + # Since the sampling log only shows immediate children ("child", "ds3"), + # we need to look at the actual sample IDs to determine leaf 
dataset usage + + # Collect all sample IDs from the batches processed during checkpointing + all_sample_ids = [] + for batch_list in [ + result["pre_checkpoint_batches"], + result["post_checkpoint_batches"], + ]: + for batch in batch_list: + all_sample_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_count = sum(1 for id in all_sample_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_count = sum( + 1 for id in all_sample_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for id in all_sample_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) total_samples = ds1_count + ds2_count + ds3_count - ds1_ratio = ds1_count / total_samples ds2_ratio = ds2_count / total_samples ds3_ratio = ds3_count / total_samples - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -362,10 +427,18 @@ def create_interleaved(): expected_ds1_ratio = 0.15 expected_ds2_ratio = 0.35 expected_ds3_ratio = 0.5 - - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + # Allow larger tolerance due to small sample size in checkpointing test + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.2 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.2 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.2 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + class TestDistributedInterleavedDataset(FSDPTest): @property @@ -446,8 +519,12 @@ def create_dataset(): ) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) def create_dataloader(dataset): loader = StatefulDataLoader( @@ -493,15 +570,19 @@ def create_dataloader(dataset): # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 ds1_samples = sum(1 for id in all_ids if 0 <= id < SMALL_DATASET_SIZE) - ds2_samples = sum(1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100)) - ds3_samples = sum(1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000)) + ds2_samples = sum( + 1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_samples = sum( + 1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) total_samples = ds1_samples + ds2_samples + ds3_samples if total_samples > 0: ds1_ratio = ds1_samples / total_samples ds2_ratio = ds2_samples / total_samples ds3_ratio = ds3_samples / total_samples - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -509,10 +590,16 @@ def create_dataloader(dataset): expected_ds1_ratio = 0.15 expected_ds2_ratio = 0.35 expected_ds3_ratio = 0.5 - - assert 
abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" finally: # Clean up temp directory (only rank 0) From fac3fd5596964acf5fce51248166d70f78f739bc Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 11:27:49 -0700 Subject: [PATCH 38/48] lint --- recipes/full_finetune_distributed.py | 7 ++++--- tests/torchtune/datasets/test_interleaved.py | 2 +- torchtune/datasets/__init__.py | 6 +++++- torchtune/datasets/_hf_iterable.py | 9 +++++---- torchtune/datasets/_interleaved.py | 14 ++++---------- torchtune/datasets/_iterable_base.py | 4 +++- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 95d5069eaa..444596619b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -1049,7 +1049,7 @@ def train(self) -> None: dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) - + if self._is_rank_zero: time_per_step = time.perf_counter() - t0 log_dict = { @@ -1067,12 +1067,13 @@ def train(self) -> None: if dataset_metrics: log_dict.update(dataset_metrics) if self._log_peak_memory_stats: - log_dict.update(training.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict(log_dict, step=self.global_step) - # Save checkpoint if specified by user if ( self.save_every_n_steps is not None diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 96c825d2f1..db4ec95035 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -464,7 +464,7 @@ def test_distributed_interleaved_checkpointing(self): temp_dir = None # Broadcast temp directory to all ranks - temp_dir_list = [temp_dir] + temp_dir_list = [temp_dir] if temp_dir is not None else [""] dist.broadcast_object_list(temp_dir_list, src=0) temp_dir = temp_dir_list[0] tmp_path = Path(temp_dir) diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index e0afccdde4..b38663578e 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -18,7 +18,11 @@ from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._interleaved import InterleavedDataset -from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset, TuneIterableDataset +from torchtune.datasets._iterable_base import ( + DatasetInfo, + InfiniteTuneIterableDataset, + TuneIterableDataset, +) from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset 
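The nested weighting checked throughout the interleaved tests above (expected leaf ratios 0.15 / 0.35 / 0.5 for ds1 / ds2 / ds3) can be reproduced with a short standalone sketch. This is illustrative only and not part of the patch: `Node` is a stand-in for `DatasetInfo`, and it assumes sibling weights are renormalized at each level of the hierarchy, which is how the expected ratios are derived from inner weights 0.3 / 0.7 and equal outer weights.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Node:
    """Stand-in for DatasetInfo: name, sampling weight, and child nodes."""

    name: str
    weight: float = 1.0
    children: tuple["Node", ...] = field(default_factory=tuple)


def leaf_ratios(node: Node, scale: float = 1.0) -> dict[str, float]:
    """Effective sampling ratio per leaf, normalizing sibling weights at each level."""
    if not node.children:
        return {node.name: scale}
    total = sum(child.weight for child in node.children)
    ratios: dict[str, float] = {}
    for child in node.children:
        ratios.update(leaf_ratios(child, scale * child.weight / total))
    return ratios


# Same structure as the tests: parent = Interleaved([child, ds3]), child = Interleaved([ds1, ds2])
child = Node("child", weight=1.0, children=(Node("ds1", 0.3), Node("ds2", 0.7)))
parent = Node("parent", children=(child, Node("ds3", 1.0)))
print(leaf_ratios(parent))  # -> {'ds1': 0.15, 'ds2': 0.35, 'ds3': 0.5}
```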
diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 5c759039d5..7b6f790914 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -16,6 +16,7 @@ AggregationType, DefaultTrainingMetricTransform, Metric, + MetricTransform, ) from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset @@ -38,7 +39,7 @@ class HfIterableDataset(InfiniteTuneIterableDataset): model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. - metric_transform (Optional[Callable]): Takes the sample and computes metrics, e.g. token count. + metric_transform (Optional[MetricTransform]): Takes the sample and computes metrics, e.g. token count. If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. seed (int): Seed for shuffling. @@ -59,7 +60,7 @@ def __init__( message_transform: Optional[Callable] = None, model_transform: Optional[Callable] = None, output_transform: Optional[Callable] = None, - metric_transform: Optional[Callable] = None, + metric_transform: Optional[MetricTransform] = None, shuffle_buffer_size: Optional[int] = 1000, weight: Optional[float] = 1.0, seed: int = 42, @@ -75,7 +76,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight + self._weight = weight if weight is not None else 1.0 # Create default transform if not provided self._metric_transform = metric_transform or DefaultTrainingMetricTransform() @@ -92,7 +93,7 @@ def __init__( dataset_name = "_".join(name_parts) # Build the hierarchical info object for this dataset - self._info = DatasetInfo(name=dataset_name, weight=weight) + self._info = DatasetInfo(name=dataset_name, weight=self._weight) # Set dataset name on the transform if it supports it if hasattr(self._metric_transform, "set_dataset_name"): diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 2dee36e557..2267696ef4 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -4,17 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections import deque import logging import math +from collections import deque from typing import Any, Iterator import torch -from torchtune.datasets._iterable_base import ( - DatasetInfo, - InfiniteTuneIterableDataset, -) +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset logger = logging.getLogger(__name__) @@ -31,9 +28,6 @@ class InterleavedDataset(InfiniteTuneIterableDataset): weight (float): Weight for this dataset. Defaults to 1.0. dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". sampling_log_maxlen (int): Maximum length of the sampling log. - - Raises: - ValueError: If duplicate dataset names are detected in the hierarchy. 
""" def __init__( @@ -69,7 +63,7 @@ def __init__( self._normalized_weights = torch.tensor( [w / total_weight for w in child_weights], dtype=torch.float ) - + # Track sampling decisions for debugging and analysis self._sampling_log: deque[tuple[int, str]] = deque( maxlen=self._sampling_log_maxlen @@ -88,7 +82,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: while True: # Sample a child dataset based on the normalized weights - ds_idx = torch.multinomial( + ds_idx: int = torch.multinomial( self._normalized_weights, 1, replacement=True, diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index fee09a5123..a26c22eafa 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -17,6 +17,7 @@ class DatasetInfo: sampling weight and children. Children is a common case when composing datasets, e.g. Packed(InterleavedDataset([ds1, ds2])). """ + name: str weight: float = 1.0 children: tuple["DatasetInfo", ...] = field(default_factory=tuple) @@ -77,4 +78,5 @@ class InfiniteTuneIterableDataset(TuneIterableDataset): It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. it never exhausts. This is helpful to avoid complexity due to some rank hanging because of lack of data""" - pass \ No newline at end of file + + pass From 29ba1cb85e8e399fd79a3f4f4eff95e677d6b1f0 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 17:46:34 -0700 Subject: [PATCH 39/48] error if duplicated metric name --- .../torchtune/data/test_metrics_aggregator.py | 104 ++++++++++++++++-- .../torchtune/data/test_metrics_transform.py | 6 +- tests/torchtune/datasets/test_hf_iterable.py | 8 +- .../data/metrics/_metric_agg_handlers.py | 43 ++++---- torchtune/data/metrics/_metric_aggregator.py | 71 ++++++++++-- torchtune/data/metrics/_metric_transform.py | 14 +-- torchtune/datasets/_hf_iterable.py | 2 +- 7 files changed, 188 insertions(+), 60 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index b65c11f533..c2e8141ff9 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import logging + import pytest import torch.distributed as dist from tests.test_utils import gpu_test @@ -34,7 +36,9 @@ def test_aggregation_types(self, agg_type, test_values, expected): aggregator = MetricsAggregator() metrics = [ - Metric(dataset_name="test", name="metric", value=val, agg_type=agg_type) + Metric( + dataset_name="test", metric_name="metric", value=val, agg_type=agg_type + ) for val in test_values ] aggregator.update(metrics) @@ -43,7 +47,7 @@ def test_aggregation_types(self, agg_type, test_values, expected): if agg_type == AggregationType.CATEGORICAL_COUNT: for category, count in expected.items(): - assert result[f"train_test/metric_{category}_count"] == count + assert result[f"train_test/metric_count_{category}"] == count else: assert result["train_test/metric"] == expected @@ -61,10 +65,10 @@ def test_distribution_metrics(self): result = aggregator.get_metrics_for_logging(prefix="train") # Verify distribution statistics - assert result["train_test/dist_metric_mean"] == 5.5 - assert result["train_test/dist_metric_min"] == 1 - assert result["train_test/dist_metric_max"] == 10 - assert result["train_test/dist_metric_p50"] == 5.5 + assert result["train_test/dist_metric_stat_mean"] == 5.5 + assert result["train_test/dist_metric_stat_min"] == 1 + assert result["train_test/dist_metric_stat_max"] == 10 + assert result["train_test/dist_metric_stat_p50"] == 5.5 def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -149,6 +153,82 @@ def test_prefix_handling(self): assert result_no_prefix["data_test_ds/metric1"] == 42 assert result_no_prefix["data_test_ds/metric2"] == 84 + def test_metric_consistency_validation(self): + """Test that same metric name must use same aggregation type.""" + aggregator = MetricsAggregator() + + # First metric with SUM aggregation + metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)] + aggregator.update(metrics1) + + # Try to use same metric name with different aggregation type - should fail + metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)] + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.update(metrics2) + + # Same metric name with same aggregation type should work + metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)] + aggregator.update(metrics3) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_test/my_metric"] == 30 # 10 + 20 + + def test_metric_consistency_across_datasets(self): + """Test that same metric name can use different aggregation types across different datasets.""" + aggregator = MetricsAggregator() + + # Same metric name but different datasets - should be allowed + metrics = [ + Metric("dataset1", "metric", 10, AggregationType.SUM), + Metric("dataset2", "metric", 5.0, AggregationType.MEAN), + ] + aggregator.update(metrics) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_dataset1/metric"] == 10 + assert result["train_dataset2/metric"] == 5.0 + + def test_handler_generated_metric_validation(self): + """Test that handler-generated metrics are validated for consistency.""" + aggregator = MetricsAggregator() + + # Create a user-defined metric that will conflict with distribution stats + user_metrics = [ + Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM) + ] + aggregator.update(user_metrics) + + # Now try to add a distribution metric that will generate conflicting stat names 
+ dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.DISTRIBUTION)] + aggregator.update(dist_metrics) + + # This should fail when trying to get metrics for logging because the handler + # will try to create "dist_metric_stat_mean" which conflicts with the user metric + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.get_metrics_for_logging(prefix="train") + + def test_handler_replacement_warning(self, caplog): + """Test that replacing handlers in use generates a warning.""" + aggregator = MetricsAggregator() + + # Add a metric that uses SUM aggregation + metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)] + aggregator.update(metrics) + + # Replace the SUM handler - should generate warning + from torchtune.data.metrics._metric_agg_handlers import SumAggHandler + + with caplog.at_level(logging.WARNING): + aggregator.register_handler(AggregationType.SUM, SumAggHandler()) + + # Check that the expected warning was logged + assert len(caplog.records) == 1 + assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message + class TestDistributedMetricsAggregator(FSDPTest): """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" @@ -245,15 +325,15 @@ def test_distributed_all_aggregation_types(self): # DISTRIBUTION: Combined values [0,1,2,3,4,10,11,12,13,14] # Mean should be average of local means: (2 + 12) / 2 = 7 - assert result["train_test/dist_metric_mean"] == 7 - assert result["train_test/dist_metric_min"] == 0 - assert result["train_test/dist_metric_max"] == 14 + assert result["train_test/dist_metric_stat_mean"] == 7 + assert result["train_test/dist_metric_stat_min"] == 0 + assert result["train_test/dist_metric_stat_max"] == 14 # CATEGORICAL_COUNT: Total counts across ranks # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 - assert result["train_test/cat_metric_cat_A_count"] == 4 - assert result["train_test/cat_metric_cat_B_count"] == 2 - assert result["train_test/cat_metric_cat_C_count"] == 4 + assert result["train_test/cat_metric_count_cat_A"] == 4 + assert result["train_test/cat_metric_count_cat_B"] == 2 + assert result["train_test/cat_metric_count_cat_C"] == 4 @gpu_test(gpu_count=2) def test_distributed_state_dict_resumption(self): diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 8a4a86d7dd..eb7a1e951e 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -38,17 +38,17 @@ def test_basic_metrics_generation(self): # Check each metric for metric in metrics: - if metric.name == "samples_seen": + if metric.metric_name == "samples_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 1 assert metric.agg_type == AggregationType.SUM - elif metric.name == "tokens_seen": + elif metric.metric_name == "tokens_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.SUM - elif metric.name == "seq_len": + elif metric.metric_name == "seq_len": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.DISTRIBUTION diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index ec3ca26936..9a54a0cc99 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -239,7 +239,9 @@ def 
test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in first_epoch_samples: first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value for metric in first_epoch_metrics if metric.name == "epoch" + metric.value + for metric in first_epoch_metrics + if metric.metric_name == "epoch" ] assert all( epoch_value == 0 for epoch_value in epoch_values @@ -250,7 +252,9 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in second_epoch_samples: second_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value for metric in second_epoch_metrics if metric.name == "epoch" + metric.value + for metric in second_epoch_metrics + if metric.metric_name == "epoch" ] assert all( epoch_value == 1 for epoch_value in epoch_values diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index c3415aba7d..d5c4122228 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections import Counter, deque from dataclasses import dataclass, field -from typing import Any, Union +from typing import Any import torch @@ -76,21 +76,18 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: pass @abstractmethod - def finalize_local_agg( - self, local_agg_metric: MetricState - ) -> Union[MetricState, list[MetricState]]: + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: """ Computes the final value from the locally aggregated state. - In a distributed setting, this is called before the reduction step. - This method can also expand a single metric into multiple, for instance, + This method may expand a single metric into multiple, for instance, a distribution into mean, min, max, and percentiles. Args: local_agg_metric (MetricState): The locally aggregated metric state to finalize. Returns: - A single `MetricState` or a list of them if the metric expands. + list[MetricState]: List of finalized metric states. 
""" pass @@ -156,8 +153,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value += metric.value - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: if not local_agg_metrics: @@ -193,8 +190,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value = max(local_agg_metric.value, metric.value) - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: max_value = max(r.value for r in local_agg_metrics) @@ -227,8 +224,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value = min(local_agg_metric.value, metric.value) - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: min_value = min(r.value for r in local_agg_metrics) @@ -259,12 +256,12 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["sum"] += metric.value local_agg_metric.metadata["count"] += 1 - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: count = local_agg_metric.metadata["count"] local_agg_metric.value = ( local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 ) - return local_agg_metric + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) @@ -349,42 +346,42 @@ def _compute_distribution_stats( metrics = [ MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_mean", + metric_name=f"{local_agg_metric.metric_name}_stat_mean", value=mean_val, agg_type=AggregationType.MEAN, metadata={"sum": sum_val, "count": n}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_min", + metric_name=f"{local_agg_metric.metric_name}_stat_min", value=min_val, agg_type=AggregationType.MIN, metadata={}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_max", + metric_name=f"{local_agg_metric.metric_name}_stat_max", value=max_val, agg_type=AggregationType.MAX, metadata={}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_p05", + metric_name=f"{local_agg_metric.metric_name}_stat_p05", value=p05_val, agg_type=AggregationType.MEAN, metadata={"sum": p05_val, "count": 1}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_p50", + metric_name=f"{local_agg_metric.metric_name}_stat_p50", value=p50_val, agg_type=AggregationType.MEAN, metadata={"sum": p50_val, "count": 1}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - 
metric_name=f"{local_agg_metric.metric_name}_p95", + metric_name=f"{local_agg_metric.metric_name}_stat_p95", value=p95_val, agg_type=AggregationType.MEAN, metadata={"sum": p95_val, "count": 1}, @@ -396,7 +393,7 @@ def _compute_distribution_stats( metrics.append( MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_std", + metric_name=f"{local_agg_metric.metric_name}_stat_std", value=std_val, agg_type=AggregationType.MEAN, metadata={"sum": std_val, "count": 1}, @@ -452,7 +449,7 @@ def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState] results.append( MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_{category}_count", + metric_name=f"{local_agg_metric.metric_name}_count_{category}", value=count, agg_type=AggregationType.SUM, ) diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index 633d5c6b80..cb4f78abf0 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. import ast +import logging from collections import defaultdict -from typing import Any +from typing import Any, Union import torch.distributed as dist @@ -22,6 +23,8 @@ ) from torchtune.data.metrics._metric_transform import AggregationType, Metric +logger = logging.getLogger(__name__) + class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. @@ -77,6 +80,9 @@ def __init__(self, dist_window_size: int = 1000): self._metric_states: dict[tuple[str, str], MetricState] = {} self._dist_window_size = dist_window_size + # Track aggregation types for validation - prevents same metric name with different agg types + self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} + # Create handler registry - all handlers initialized upfront self._handlers: dict[AggregationType, AggregationHandler] = { AggregationType.SUM: SumAggHandler(), @@ -87,6 +93,24 @@ def __init__(self, dist_window_size: int = 1000): AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), } + def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: + """Validate that metric name uses consistent aggregation type.""" + metric_key = (metric.dataset_name, metric.metric_name) + metric_name = metric.metric_name + + if metric_key in self._metric_agg_types: + existing_agg_type = self._metric_agg_types[metric_key] + if existing_agg_type != metric.agg_type: + raise ValueError( + f"Metric '{metric_name}' in dataset '{metric.dataset_name}' " + f"is already registered with aggregation type {existing_agg_type.value}, " + f"but a handler or user code tried to use it with type {metric.agg_type.value}. " + f"Use different metric names for different aggregation types." 
+ ) + else: + # Track this metric's aggregation type + self._metric_agg_types[metric_key] = metric.agg_type + def register_handler( self, agg_type: AggregationType, handler: AggregationHandler ) -> None: @@ -94,8 +118,17 @@ def register_handler( Args: agg_type (AggregationType): The aggregation type to handle - handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface + handler (AggregationHandler): Handler instance implementing the AggregationHandler interface """ + # Warn if replacing a handler that's already in use + if agg_type in self._handlers and any( + state.agg_type == agg_type for state in self._metric_states.values() + ): + logger.warning( + f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " + f"This may affect existing metric behavior." + ) + self._handlers[agg_type] = handler def update(self, metrics: list[Metric]) -> None: @@ -105,10 +138,14 @@ def update(self, metrics: list[Metric]) -> None: metrics (list[Metric]): List of metrics to update the state with Raises: - ValueError: If no handler is registered for a metric's aggregation type. + ValueError: If no handler is registered for a metric's aggregation type, + or if metric name conflicts with existing aggregation type. """ for metric in metrics: - metric_key = (metric.dataset_name, metric.name) + # Same metric name must use same aggregation type + self._validate_metric_consistency(metric) + + metric_key = (metric.dataset_name, metric.metric_name) handler = self._handlers.get(metric.agg_type) if handler is None: @@ -118,7 +155,7 @@ def update(self, metrics: list[Metric]) -> None: if metric_key not in self._metric_states: self._metric_states[metric_key] = handler.initialize_metric_state( - metric.dataset_name, metric.name, metric.agg_type + metric.dataset_name, metric.metric_name, metric.agg_type ) local_agg_metric = self._metric_states[metric_key] @@ -153,13 +190,13 @@ def _compute_unified_metrics(self) -> list[MetricState]: prepared_results = [] for local_agg_metric in self._metric_states.values(): handler = self._handlers[local_agg_metric.agg_type] - prepared = handler.finalize_local_agg(local_agg_metric) - if isinstance( - prepared, list - ): # Distribution/categorical expands to multiple - prepared_results.extend(prepared) - else: - prepared_results.append(prepared) + generated_metrics = handler.finalize_local_agg(local_agg_metric) + + # Validate each newly generated metric state immediately + for gen_metric in generated_metrics: + self._validate_metric_consistency(gen_metric) + + prepared_results.extend(generated_metrics) # Step 2: Apply distributed reduction if needed if dist.is_initialized() and dist.get_world_size() > 1: @@ -238,6 +275,10 @@ def state_dict(self) -> dict[str, Any]: "required_agg_types": list( required_agg_types ), # Save which handlers are needed + # Save which aggregation types are used for each metric + "metric_agg_types": { + str(k): v.value for k, v in self._metric_agg_types.items() + }, } def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -288,3 +329,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: deserialized_state[metric_key] = local_agg_metric self._metric_states = deserialized_state + + # Restore validation state + self._metric_agg_types = {} + for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): + key = ast.literal_eval(key_str) + self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/torchtune/data/metrics/_metric_transform.py 
b/torchtune/data/metrics/_metric_transform.py index 5affaca046..f6a7fbf7e2 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -14,7 +14,7 @@ @dataclass(frozen=True) class Metric: dataset_name: str - name: str + metric_name: str value: Union[int, float, str] agg_type: "AggregationType" @@ -101,9 +101,9 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> metrics = transform._generate_metrics(sample) >>> # Creates: >>> # [ - >>> # Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), - >>> # Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), - >>> # Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) + >>> # Metric(dataset_name="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) >>> # ] """ @@ -121,19 +121,19 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: return [ Metric( dataset_name=self.dataset_name, - name="samples_seen", + metric_name="samples_seen", value=1, agg_type=AggregationType.SUM, ), Metric( dataset_name=self.dataset_name, - name="tokens_seen", + metric_name="tokens_seen", value=token_len, agg_type=AggregationType.SUM, ), Metric( dataset_name=self.dataset_name, - name="seq_len", + metric_name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION, ), diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 7b6f790914..bb0f647508 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -234,7 +234,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # also necessary to track dataset-level metrics. metric_num_epochs = Metric( dataset_name=self.info.name, - name="num_epochs", + metric_name="num_epochs", value=self._num_epochs, agg_type=AggregationType.MAX, ) From f89eefe90cf6f78559e054847f9ac92fa0c5ce42 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 19:36:34 -0700 Subject: [PATCH 40/48] improve docs --- .../torchtune/data/test_metrics_aggregator.py | 30 +++++- .../torchtune/data/test_metrics_transform.py | 21 ++++- tests/torchtune/datasets/test_hf_iterable.py | 13 +++ tests/torchtune/datasets/test_interleaved.py | 20 +++- .../data/metrics/_metric_agg_handlers.py | 17 ++-- torchtune/data/metrics/_metric_aggregator.py | 21 +++-- torchtune/data/metrics/_metric_transform.py | 56 ++++++++--- torchtune/datasets/_hf_iterable.py | 93 ++++++++++--------- torchtune/datasets/_interleaved.py | 21 +++-- torchtune/datasets/_iterable_base.py | 57 +++++++++--- 10 files changed, 243 insertions(+), 106 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index c2e8141ff9..db2ab3f617 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,6 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for MetricsAggregator functionality. 
+ +This module tests the metrics collection and aggregation system including: +- All aggregation types (SUM, MEAN, MAX, MIN, DISTRIBUTION, CATEGORICAL_COUNT) +- State management and checkpointing +- Multi-dataset metric namespacing +- Distributed metrics aggregation +- Metric consistency validation + +Uses synthetic metrics to verify correct aggregation behavior across scenarios. +""" + import logging import pytest @@ -15,7 +28,7 @@ class TestMetricsAggregator: - """Focused tests for MetricsAggregator functionality.""" + """Tests for MetricsAggregator core functionality and edge cases.""" @pytest.mark.parametrize( "agg_type,test_values,expected", @@ -32,7 +45,14 @@ class TestMetricsAggregator: ], ) def test_aggregation_types(self, agg_type, test_values, expected): - """Tests each `AggregationType` to ensure it computes the correct value.""" + """Tests each AggregationType with representative data to verify correct computation. + + Covers aggregation types: + - SUM: Simple addition across values + - MEAN: Average computation with proper count tracking + - MAX/MIN: Extrema identification + - CATEGORICAL_COUNT: Category frequency counting + """ aggregator = MetricsAggregator() metrics = [ @@ -52,7 +72,7 @@ def test_aggregation_types(self, agg_type, test_values, expected): assert result["train_test/metric"] == expected def test_distribution_metrics(self): - """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" + """Tests that DISTRIBUTION aggregation computes statistics (mean, min, max, percentiles).""" aggregator = MetricsAggregator() values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -71,8 +91,8 @@ def test_distribution_metrics(self): assert result["train_test/dist_metric_stat_p50"] == 5.5 def test_state_management(self): - """Test aggregator checkpointing and restoration.""" - # Create aggregator with some state + """Test metrics aggregator state persistence and restoration for checkpointing scenarios.""" + # Create aggregator with mixed metric types to test state saving aggregator1 = MetricsAggregator() initial_metrics = [ Metric("ds1", "counter", 10, AggregationType.SUM), diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index eb7a1e951e..ebfb1c81a1 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -4,6 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+""" +Tests cover: +- DefaultTrainingMetricTransform +- Basic metric generation (samples_seen, tokens_seen, seq_len) +- Dataset name validation and requirements +- Proper metric type assignment and aggregation configuration +""" + import pytest from torchtune.data.metrics import AggregationType, DefaultTrainingMetricTransform @@ -13,7 +21,9 @@ class TestDefaultTrainingMetricTransform: """Tests for DefaultTrainingMetricTransform functionality.""" def test_dataset_name_not_set_raises_error(self): - """Test that using transform without setting dataset name raises error.""" + """Test that the transform raises a RuntimeError if used before + `set_dataset_name` is called, ensuring that metrics are always + correctly attributed to a dataset.""" transform = DefaultTrainingMetricTransform() sample = {"tokens": [1, 2, 3]} @@ -21,22 +31,23 @@ def test_dataset_name_not_set_raises_error(self): transform(sample) def test_basic_metrics_generation(self): - """Test that transform generates expected metrics for a sample.""" + """Test that transform generates expected training metrics for input samples.""" transform = DefaultTrainingMetricTransform() + # Set dataset name required for metric generation transform.set_dataset_name("test_dataset") sample = {"tokens": [1, 2, 3, 4, 5]} result = transform(sample) - # Should preserve original sample data + # Transform should preserve original sample data unchanged assert result["tokens"] == [1, 2, 3, 4, 5] - # Should add metrics + # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len assert "metrics" in result metrics = result["metrics"] assert len(metrics) == 3 - # Check each metric + # Verify each metric has correct properties and aggregation type for metric in metrics: if metric.metric_name == "samples_seen": assert metric.dataset_name == "test_dataset" diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 9a54a0cc99..adcede297a 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,6 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for HfIterableDataset core functionality. + +This module tests the foundational iterable dataset capabilities including: +- Basic iteration and data loading +- Epoch boundary handling and tracking +- Shuffling behavior across epochs +- Checkpointing and state restoration +- Distributed training scenarios + +Uses synthetic JSON data with predictable patterns to verify correct behavior. +""" + import math import shutil import tempfile diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index db4ec95035..37bda9adcc 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,6 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for InterleavedDataset functionality. + +This module tests the multi-dataset interleaving capabilities, including: +- Dataset composition with weighted sampling +- Nested interleaving structures +- Metrics collection and aggregation across datasets +- Checkpointing and state restoration +- Distributed training scenarios + +The tests use synthetic JSON data with distinct ID ranges per dataset +to verify correct sampling ratios and data isolation. 
+""" + import shutil import tempfile from itertools import islice @@ -53,12 +67,14 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None @pytest.fixture def tmp_data_dir(tmp_path): - """Provide temporary directory for test data files.""" + """Provide temporary directory for test data files. + All test datasets are created in this isolated directory to avoid conflicts.""" return tmp_path @pytest.fixture def small_dataset_file(tmp_data_dir): + """Create small dataset (23 samples) with IDs 0-22 for testing basic functionality.""" path = tmp_data_dir / "small_data.json" create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) return str(path) @@ -66,6 +82,7 @@ def small_dataset_file(tmp_data_dir): @pytest.fixture def medium_dataset_file(tmp_data_dir): + """Create medium dataset (35 samples) with IDs 100-134 for multi-dataset testing.""" path = tmp_data_dir / "medium_data.json" create_test_json_file(path, MEDIUM_DATASET_SIZE, offset=100) return str(path) @@ -73,6 +90,7 @@ def medium_dataset_file(tmp_data_dir): @pytest.fixture def large_dataset_file(tmp_data_dir): + """Create large dataset (47 samples) with IDs 1000-1046 for nested interleaving tests.""" path = tmp_data_dir / "large_data.json" create_test_json_file(path, LARGE_DATASET_SIZE, offset=1000) return str(path) diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index d5c4122228..ac3f9a2fd7 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -38,15 +38,14 @@ class MetricState: class AggregationHandler(ABC): - """Base class for handling metric aggregation in MetricsAggregator. - - This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). - Each handler is responsible for: - - Initializing the state for a new (dataset, metric) pair. - - Updating the state with new values. - - Finalizing the value for local (single-rank) logging. - - Reducing the values from all ranks in a distributed setting. - - Serializing and deserializing the metric state for checkpointing. + """Base class for handling metric aggregation using the Strategy pattern. + + Each handler implements a specific aggregation strategy (SUM, MEAN, DISTRIBUTION, etc.) + and manages the complete lifecycle: initialization, updates, local finalization, + and distributed reduction. Handlers also handle serialization for checkpointing. + + The handler architecture allows pluggable aggregation strategies while maintaining + consistent interfaces for the MetricsAggregator. """ @abstractmethod diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index cb4f78abf0..da6b152350 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -29,14 +29,22 @@ class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) - has its own handler. Maintains only one state per (dataset, metric) pair. + This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) + has a corresponding AggregationHandler. It maintains a single state object for each + (dataset, metric) pair. - When preparing for logging, uses a two-phase approach: - 1. Local aggregation: Each rank aggregates its metrics independently - 2. 
Distributed reduction: Results combined across ranks + Internal State Visualization: + { + ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), + ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), + ("slim_orca", "seq_len"): MetricState(agg_type=DISTRIBUTION, metadata={'values': deque([...])}), + } - The aggregator is checkpointable and restores from state_dict for training resumption. + When preparing metrics for logging, the aggregator follows a two-phase process: + 1. Local Aggregation: Each rank aggregates its metrics independently + 2. Distributed Reduction: If in distributed mode, results are combined across ranks + + The aggregator's state is checkpointable, allowing training resumption. Args: dist_window_size (int): Window size for DistributionAggHandler tracking. @@ -44,7 +52,6 @@ class MetricsAggregator: Example: >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType >>> - >>> # Create aggregator >>> aggregator = MetricsAggregator() >>> >>> # Sample metrics from different batches diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index f6a7fbf7e2..8521f6e6dd 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -20,7 +20,11 @@ class Metric: class AggregationType(Enum): - """Defines how a metric's value should be aggregated.""" + """Defines how a metric's value should be aggregated by the MetricsAggregator. + + Each type corresponds to a specific AggregationHandler that implements the logic + for initialization, updates, and distributed reduction. + """ SUM = "sum" MEAN = "mean" @@ -33,22 +37,33 @@ class AggregationType(Enum): class MetricTransform(Transform): """Applied to each dataset sample to generate per-sample metrics for training tracking. - Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + Creates Metric objects that are later aggregated by MetricsAggregator. This separation of concerns ensures metrics are correctly aggregated even with multiple dataloader - workers and in distributed settings.""" + workers and in distributed settings. + + The transform must be configured with a dataset name via set_dataset_name() before use. + Each call to __call__ adds metrics to the sample's "metrics" key. + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> sample = {"tokens": [1, 2, 3]} + >>> result = transform(sample) + >>> # result["metrics"] contains list of Metric objects + """ def __init__(self): # dataset_name is set by the dataset using set_dataset_name self.dataset_name: Optional[str] = None def set_dataset_name(self, dataset_name: str) -> None: - """Called by dataset to set the namespace for metrics. + """Called by the dataset to set the namespace for metrics. - The dataset name is used to differentiate multiple datasets stats, - e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". + This is used to differentiate metrics from multiple datasets, for example, + "train_alpaca/tokens_seen" vs. "train_slim_orca/tokens_seen". Args: - dataset_name (str): Name of the dataset for metric namespacing + dataset_name (str): Name of the dataset, used for metric namespacing. 
""" self.dataset_name = dataset_name @@ -67,7 +82,17 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: raise NotImplementedError("Subclasses must implement _generate_metrics method") def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply transform to sample, adding generated metrics.""" + """Apply transform to sample, adding generated metrics to the sample. + + Args: + sample (dict[str, Any]): Input sample dictionary + + Returns: + dict[str, Any]: Sample with metrics added to "metrics" key (list[Metric]) + + Raises: + RuntimeError: If set_dataset_name() was not called before transform usage + """ if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." @@ -84,14 +109,17 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: class DefaultTrainingMetricTransform(MetricTransform): - """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. + """Generates common training metrics: samples seen, tokens seen, and sequence length. + + This transform detects the token key in a sample, checking for "tokens" + first and then falling back to "input_ids". - For details about MetricTransform base class behavior, see the parent class docstring. + For details on the base class behavior, see MetricTransform. Tracked metrics: - - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) Example: >>> transform = DefaultTrainingMetricTransform() @@ -99,7 +127,7 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens >>> metrics = transform._generate_metrics(sample) - >>> # Creates: + >>> # This generates the following Metric objects: >>> # [ >>> # Metric(dataset_name="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), >>> # Metric(dataset_name="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index bb0f647508..f517fece31 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -24,34 +24,42 @@ class HfIterableDataset(InfiniteTuneIterableDataset): - """HuggingFace dataset implementation with composable metrics. + """HuggingFace dataset with infinite iteration and composable transforms. - This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. + This is an infinite dataset that wraps a HuggingFace dataset. After exhausting + the dataset, it will restart from the beginning. + + Transform pipeline: raw_data -> message_transform -> model_transform -> output_transform -> metric_transform This dataset is responsible for: - Loading and sharding the dataset - Shuffling at initialization and after each epoch - - Applying transforms + - Applying transforms to the data - Returning an infinite iterator over the dataset - Args: - message_transform (Optional[Callable]): Transforms raw data into Message - model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. 
- output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually - does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. - metric_transform (Optional[MetricTransform]): Takes the sample and computes metrics, e.g. token count. - If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. - shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. - seed (int): Seed for shuffling. - num_shards_per_rank (int): Target number of shards per worker (GPU). It will find a multiple - of world_size * dataloader_workers. - dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated - from the path, source, and split. + Args: + message_transform (Optional[Callable]): Transforms raw data into a `Message`. + model_transform (Optional[Callable]): Prepares messages for the model, + usually by tokenizing them. + output_transform (Optional[Callable]): Prepares tokenized inputs for the + recipe, often by manipulating labels (e.g., setting an ignore index). + This transform is recipe-dependent (e.g., SFT, DPO, etc.). + metric_transform (Optional[MetricTransform]): Computes metrics from a + sample (e.g., token count). If ``None``, a default transform is used. + To disable standard metric tracking, set this to ``lambda x: x``. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. + If ``None`` or 0, no shuffling is performed. weight (Optional[float]): Weight for this dataset. Defaults to 1.0. - filter_fn (Optional[Callable]): Filter function to apply to the dataset. - filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. - load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. - + seed (int): Seed for shuffling. + num_shards_per_rank (int): The target number of shards per worker (GPU). + The actual number of shards will be a multiple of + ``world_size * dataloader_workers``. + dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is + generated from the ``path``, ``source``, and ``split``. + filter_fn (Optional[Callable]): A function to filter the dataset. + filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``. + **load_dataset_kwargs: Keyword arguments for the + :func:`~datasets.load_dataset` function. """ def __init__( @@ -132,9 +140,10 @@ def _setup_hf_dataset( filter_kwargs: Optional[dict[str, Any]] = None, ): """ - Configures the Hugging Face dataset, including sharding, filtering, and - transform mapping. This method is called only once during initialization - to avoid expensive re-computation on each epoch. + One-time setup of HuggingFace dataset that handles Handles distributed sharding, + shuffle configuration, and filtering. + + Called once during __init__ to avoid expensive re-computation. 
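Editor's note: the sharding rule in this setup (implemented in the hunk just below) is to round the desired shard count up to the next multiple of the total worker count, so every dataloader worker on every rank receives the same number of shards. A small sketch of that arithmetic, using the quantities named in `_setup_hf_dataset` (the helper name here is hypothetical):

```python
# Sketch of the shard-count rule: shards must be divisible by
# world_size * num_dataloader_workers for an even distribution.
def compute_num_shards(world_size: int, num_dataloader_workers: int, num_shards_per_rank: int) -> int:
    total_workers = world_size * num_dataloader_workers
    desired_shards = world_size * num_shards_per_rank
    if desired_shards % total_workers == 0:
        return desired_shards
    # Round up to the next multiple of total_workers
    return ((desired_shards // total_workers) + 1) * total_workers

# Example: 2 ranks, 4 dataloader workers, 16 shards per rank -> 32 shards (divisible by 8).
assert compute_num_shards(2, 4, 16) == 32
```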
""" # Distributed setup @@ -165,13 +174,13 @@ def _setup_hf_dataset( worker_info = torch.utils.data.get_worker_info() num_dataloader_workers = worker_info.num_workers if worker_info else 1 - # Calculate total workers + # Calculate total workers across all ranks and dataloader processes total_workers = world_size * num_dataloader_workers - # Calculate desired shards + # Find minimum shards that satisfies our target while being divisible by workers desired_shards = world_size * num_shards_per_rank - # Find the smallest multiple of total_workers that is >= desired_shards + # Round up to next multiple of total_workers for even distribution if desired_shards % total_workers == 0: num_shards = desired_shards else: @@ -207,14 +216,14 @@ def _setup_hf_dataset( self._ds = ds def __iter__(self) -> Iterator[dict[str, Any]]: - """Iterate through the dataset infinitely. - - It will restart from the beginning after exhausting the dataset. - - If shuffle_buffer_size is set, it will shuffle the dataset at the beginning of each epoch - when set_epoch is called. - - An additional metric "num_epochs" is added to the sample. + """Infinite iteration over dataset samples. + + Behavior: + - Restarts from beginning when dataset is exhausted + - Reshuffles at start of each epoch (if enabled) + - Applies full transform pipeline to each sample + - Adds 'num_epochs' metric to track dataset progress + - Yields samples indefinitely for continuous training """ while True: # Infinite iteration @@ -224,9 +233,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: try: for sample in epoch_iterator: - # NOTE: We apply transforms here instead of using .map() call - # to work around https://github.com/huggingface/datasets/issues/7630 - # where .map() can cause incorrect resumption from a checkpoint. + # NOTE: We apply transforms here instead of using .map() to work around + # HuggingFace datasets bug where .map() causes incorrect checkpoint resumption. + # See: https://github.com/huggingface/datasets/issues/7630 + # This ensures transforms are applied fresh on each sample during iteration. sample = self._apply_transforms(sample) # Track the number of epochs completed for each dataset. This is @@ -246,9 +256,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: yield sample except StopIteration: - pass # Iterator is exhausted, which is expected. + # Expected when dataset is exhausted + pass except Exception as e: - logger.warning( + logger.error( f"Dataset {self.info.name} encountered an unexpected error: {e}." ) raise @@ -263,9 +274,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: self._num_epochs += 1 def state_dict(self) -> dict[str, Any]: - """ - The dataset returns its own state directly, without namespacing. - """ + """Returns dataset checkpoint state.""" hf_state = self._ds.state_dict() state = { "num_epochs": self._num_epochs, diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 2267696ef4..fe911aff51 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -17,17 +17,18 @@ class InterleavedDataset(InfiniteTuneIterableDataset): - """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. - - The weights are extracted from each dataset's info.weight property and normalized to sum to 1.0. - - This dataset is responsible for managing the state of its child datasets - to ensure correct checkpointing and resumption. + """Infinitely interleaves multiple datasets according to their sampling weights. 
+ + The weights are extracted from each dataset's ``info.weight`` property and + normalized to sum to 1.0. This dataset manages the state of its child + datasets to ensure correct checkpointing and resumption. Args: - datasets (list[InfiniteTuneIterableDataset]): list of datasets to interleave. - seed (int): Seed for sampling. - weight (float): Weight for this dataset. Defaults to 1.0. - dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". - sampling_log_maxlen (int): Maximum length of the sampling log. + datasets (list[InfiniteTuneIterableDataset]): A list of datasets to interleave. + seed (int): The seed for sampling. + weight (float): The weight for this dataset. Defaults to 1.0. + dataset_name (str): The name of the dataset. Defaults to "interleaved_dataset". + sampling_log_maxlen (int): The maximum length of the sampling log. """ def __init__( @@ -100,7 +101,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: yield next(child_iters[ds_name]) def state_dict(self) -> dict[str, Any]: - """Save state for the interleaver and its children.""" + """Save interleaver state and all child dataset states.""" # The parent is responsible for namespacing the child states child_states = {ds.info.name: ds.state_dict() for ds in self._datasets} return { diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index a26c22eafa..0f412e80dc 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -13,9 +13,33 @@ @dataclass(frozen=True) class DatasetInfo: - """Represents hierarchical information about a dataset, including its name, - sampling weight and children. Children is a common case when composing datasets, - e.g. Packed(InterleavedDataset([ds1, ds2])). + """Hierarchical metadata for datasets, enabling composition and weight tracking. + + Used to build tree structures when composing datasets. For example, a nested + `InterleavedDataset` dataset would have this structure: + + Example: + .. code-block:: python + + DatasetInfo(name='parent_interleaved', + weight=1.0, + children=(DatasetInfo(name='child_interleaved', + weight=0.7, + children=(DatasetInfo(name='dataset_a', + weight=0.6, + children=()), + DatasetInfo(name='dataset_b', + weight=0.4, + children=()))), + DatasetInfo(name='dataset_c', weight=0.3, children=()))) + + This hierarchical structure is used for validation (ensuring unique dataset + names) and for logging metrics. + + Attributes: + name (str): Unique identifier for the dataset + weight (float): Sampling weight for dataset selection (default: 1.0) + children (tuple[DatasetInfo, ...]): Nested datasets for composed structures """ name: str @@ -24,10 +48,16 @@ class DatasetInfo: class TuneIterableDataset(IterableDataset, ABC): - """Abstract base class for all torchtune iterable datasets. - It defines the minimal, consistent interface required for all dataset - implementations to ensure they are compatible with the training loop, - checkpointing, and metric logging systems. + """Base class for all torchtune iterable datasets. 
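Editor's note: the uniqueness validation mentioned in the `DatasetInfo` docstring above can be pictured as a simple tree walk. This is an illustrative sketch only, not the library's actual validation code; it assumes only the `name` and `children` attributes shown in the dataclass.

```python
# Illustrative sketch: collect every name in a DatasetInfo tree and reject duplicates.
def collect_names(info) -> list[str]:
    names = [info.name]
    for child in info.children:
        names.extend(collect_names(child))
    return names

def assert_unique_names(info) -> None:
    names = collect_names(info)
    duplicates = {name for name in names if names.count(name) > 1}
    if duplicates:
        raise ValueError(f"Duplicate dataset names in hierarchy: {sorted(duplicates)}")
```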
+ + Datasets are composable, enabling complex structures such as: + ``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))`` + + Each dataset implementation must: + - Track hierarchical metadata via the ``info`` property + - Ensure unique dataset names across the entire tree + - Handle checkpointing: parents resume children's state + - Provide proper state management for exact resumption """ @property @@ -64,19 +94,20 @@ def __iter__(self) -> Iterator[dict[str, Any]]: @abstractmethod def state_dict(self) -> dict[str, Any]: - """Returns a state dictionary for checkpointing""" + """Returns checkpoint state for dataset resumption.""" pass @abstractmethod def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load state from a state dictionary, used when resuming from a checkpoint.""" + """Restores dataset state from checkpoint.""" pass class InfiniteTuneIterableDataset(TuneIterableDataset): - """Abstract base class for infinite datasets, which yield samples indefinitely. - It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. - it never exhausts. This is helpful to avoid complexity due to some rank hanging because - of lack of data""" + """Base class for infinite datasets that never exhaust. + + Prevents distributed training hangs by ensuring all ranks always + have data available. Datasets restart from beginning when exhausted. + """ pass From d6680b738047a3f4c8c3a08f493bec1ac52ccd8b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 23:09:58 -0400 Subject: [PATCH 41/48] rename from strategy to packer --- recipes/configs/llama3_2/3B_full.yaml | 4 +-- recipes/full_finetune_distributed.py | 6 ++-- torchtune/data/_collate.py | 37 +++++++++++++++++++----- torchtune/datasets/__init__.py | 4 +-- torchtune/datasets/_iterable_packed.py | 40 +++++++++++++------------- 5 files changed, 57 insertions(+), 34 deletions(-) diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 5e9ca999ec..57e1fd45fa 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -42,10 +42,10 @@ dataset: weight: 0.2 split: train[:5%] # simulate 1 epoch quickly -# On-the-fly packing strategy +# On-the-fly packing # Set packing_strategy: null to disable packing packing_strategy: - _component_: torchtune.datasets.TextPackingStrategy + _component_: torchtune.datasets.TextPacker seed: 42 diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index fa4211b7ed..c4ad8886f5 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -802,7 +802,7 @@ def _setup_data( base_collate_fn = _get_component_from_path(collate_fn) if cfg_packing_strategy: - packing_strategy = config.instantiate( + packer = config.instantiate( cfg_packing_strategy, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, @@ -810,13 +810,13 @@ def _setup_data( ds = IterablePackedDataset( dataset=ds, - strategy=packing_strategy, + packer=packer, target_tokens_per_pack=self._tokenizer.max_seq_len, ) base_collate_fn = partial( base_collate_fn, - mask_fn=packing_strategy.create_block_mask, + mask_fn=packer.create_block_mask, device=self._device, ) diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 6faa88774f..7f97cbeb12 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -13,25 +13,48 @@ def collate_packed( - batch: list[dict[str, torch.Tensor]], mask_fn: callable, device: 
str -) -> dict[str, torch.Tensor]: + batch: list[dict[str, Any]], mask_fn: callable, device: str +) -> dict[str, Any]: """ Generic collate function for packed samples from an IterablePackedDataset. - This function handles tensor stacking and delegates attention mask creation - to a provided `mask_fn`. + Stacks tensors from all samples in the batch, while keeping non-tensor values + as lists. Handles metrics by extending them into a single list. Delegates + attention mask creation to a provided `mask_fn` callable that expects + `document_ids` and `device` parameters to generate masks on-the-fly for + packed sequences. + + Args: + batch (list[dict[str, Any]]): A list of dictionaries containing samples. + mask_fn (callable): A function that generates attention masks for packed sequences. + device (str): The device to use for the tensors. + + Returns: + dict[str, Any]: A dictionary containing the collated samples. + + Raises: + ValueError: If all samples do not have the same keys. """ if not batch: return {} - # Assumes all samples in the batch have the same keys, which are all tensors. - keys_to_stack = batch[0].keys() + # Verify all samples have the same keys + first_sample_keys = batch[0].keys() + for sample in batch: + if sample.keys() != first_sample_keys: + raise ValueError(f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}") + + keys_to_stack = first_sample_keys collated = {} + for key in keys_to_stack: if isinstance(batch[0][key], torch.Tensor): collated[key] = torch.stack([sample[key] for sample in batch], dim=0) + elif key == "metrics": + collated[key] = [] + for sample in batch: + collated[key].extend(sample[key]) else: - # TODO: Remove? i dont see a situation where it would not be a tensor. collated[key] = [sample[key] for sample in batch] # Delegate mask creation to the provided specialized function diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 01dce154da..c523e75a8f 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -20,7 +20,7 @@ from torchtune.datasets._interleaved import InterleavedDataset from torchtune.datasets._iterable_packed import ( IterablePackedDataset, - TextPackingStrategy, + TextPacker, ) from torchtune.datasets._iterable_base import ( DatasetInfo, @@ -68,5 +68,5 @@ "TuneIterableDataset", "wikitext_dataset", "IterablePackedDataset", - "TextPackingStrategy", + "TextPacker", ] diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 7687673ad6..85c650453a 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -26,15 +26,15 @@ PackType = dict[str, torch.Tensor | list[Metric]] -class PackingStrategy(ABC, Generic[SampleType]): +class Packer(ABC, Generic[SampleType]): """ - Strategy to be used in IterablePackedDataset and with FlexAttention. + Packer to be used in IterablePackedDataset and with FlexAttention. """ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): if not _SUPPORTS_FLEX_ATTENTION: raise RuntimeError( - "The IterablePackedDataset and its strategies require Flex Attention support, " + "The IterablePackedDataset and its packers require Flex Attention support, " "which is not available in the current environment." 
) self.padding_idx = padding_idx @@ -43,7 +43,7 @@ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX) @abstractmethod def set_dataset_name(self, dataset_name: str) -> None: """ - Sets the dataset name on the strategy. + Sets the dataset name on the packer. Args: dataset_name (str): The name of the dataset. @@ -154,7 +154,7 @@ def _mask_mod( The core logic for the block attention mask, to be passed to `torch.nn.attention.flex_attention.create_block_mask`. - This method is implemented by each strategy to define the specific + This method is implemented by each packer to define the specific attention pattern (e.g., standard causal, DPO, etc.). Args: @@ -191,12 +191,12 @@ class IterablePackedDataset( TuneIterableDataset[PackType], Stateful, Generic[SampleType] ): """ - IterablePackedDataset takes any TuneIterableDataset and a PackingStrategy, packs documents until + IterablePackedDataset takes any TuneIterableDataset and a Packer, packs documents until the 'target_tokens_per_pack' is reached and yields a dictionary of tensors. Args: dataset (TuneIterableDataset[SampleType]): The TuneIterableDataset to pack. - strategy (PackingStrategy[SampleType]): The PackingStrategy to use for packing. + packer (Packer[SampleType]): The Packer to use for packing. target_tokens_per_pack (int): The target number of tokens per pack. buffer_size (int): The size of the buffer to use for packing. dataset_name (str): The name of the dataset. If None, a defaults to IterablePackedDataset. @@ -205,19 +205,19 @@ class IterablePackedDataset( def __init__( self, dataset: TuneIterableDataset[SampleType], - strategy: PackingStrategy[SampleType], + packer: Packer[SampleType], target_tokens_per_pack: int, buffer_size: int = 50, dataset_name: str = "IterablePackedDataset", ): self.dataset = dataset - self.strategy = strategy + self.packer = packer self.target_tokens_per_pack = target_tokens_per_pack self.buffer_size = buffer_size self._dataset_name = dataset_name - # Set dataset name on the strategy - self.strategy.set_dataset_name(dataset_name) + # Set dataset name on the packer + self.packer.set_dataset_name(dataset_name) self._reset_packer_state() @@ -260,7 +260,7 @@ def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: while len(self._buffer) < self.buffer_size and not self._exhausted: try: sample = next(iterator) - sample_size = self.strategy.get_sample_size(sample) + sample_size = self.packer.get_sample_size(sample) # Drop samples that are too large if sample_size > self.target_tokens_per_pack: @@ -314,7 +314,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[PackType]: """ # Start a new pack if necessary if self._current_pack is None: - self._current_pack = self.strategy.create_empty_pack() + self._current_pack = self.packer.create_empty_pack() self._current_pack_size = 0 self._current_doc_id_in_pack = 0 @@ -328,7 +328,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[PackType]: if selected_sample_idx is not None: sample, sample_size = self._buffer[selected_sample_idx] del self._buffer[selected_sample_idx] - docs_consumed = self.strategy.add_sample_to_pack( + docs_consumed = self.packer.add_sample_to_pack( self._current_pack, sample, self._current_doc_id_in_pack ) self._current_doc_id_in_pack += docs_consumed @@ -339,7 +339,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[PackType]: # If the pack has any content, finalize and return it if self._current_pack_size > 0: - final_pack = 
self.strategy.finalize_pack( + final_pack = self.packer.finalize_pack( self._current_pack, self.target_tokens_per_pack, self._current_doc_id_in_pack, @@ -410,9 +410,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._resuming = True -class TextPackingStrategy(PackingStrategy[dict[str, list[int]]]): +class TextPacker(Packer[dict[str, list[int]]]): """ - Strategy for packing standard text samples for causal language modeling. It is designed + Packer for packing standard text samples for causal language modeling. It is designed to be used with the IterablePackedDataset. - Each sample is treated as a separate document. - `input_pos` restarts from 0 for each sample. @@ -508,12 +508,12 @@ def _mask_mod( # NOTE: For demonstration purposes only. -class DPOPackingStrategy(PackingStrategy[dict[str, list[int]]]): +class DPOPacker(Packer[dict[str, list[int]]]): """ - Strategy for packing DPO samples with a shared prompt. It packs a DPO + Packer for packing DPO samples with a shared prompt. It packs a DPO sample as three logical documents: a shared prompt, a chosen response, and a rejected response. This structure is encoded in the `document_ids` - metadata, allowing the strategy to build the correct attention pattern + metadata, allowing the packer to build the correct attention pattern (e.g., both responses can attend to the prompt, but not to each other). ASSUMPTION: The input DPO sample dict contains pre-tokenized: From d3be015e69039deb6fcc7b3db9eb25fabe812760 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 14:15:02 -0400 Subject: [PATCH 42/48] tensors instead of lists --- planning/ontheflypacking.md | 52 ----- torchtune/datasets/_iterable_packed.py | 290 +++++++++++++++---------- 2 files changed, 178 insertions(+), 164 deletions(-) delete mode 100644 planning/ontheflypacking.md diff --git a/planning/ontheflypacking.md b/planning/ontheflypacking.md deleted file mode 100644 index 18279bac47..0000000000 --- a/planning/ontheflypacking.md +++ /dev/null @@ -1,52 +0,0 @@ -### What: -Packing is the process of putting together samples until a certain target size is reached. This is done to reduce the number of padding tokens in a batch. To avoid contamination between samples, we use a document-level causal mask. To make it faster, we use flex attention to handle the special mask. - -Example: -```python -# The current pack with one sample -pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} - -# The next sample to be added -sample = {"tokens": [5, 6], "labels": [7, 8]} - -# After adding the sample -added_docs = add_sample_to_pack(pack, sample, next_doc_id=1) -print(pack) ->>> {"tokens": [1, 2, 5, 6], - "labels": [3, 4, 7, 8], - "document_ids": [0, 0, 1, 1], - "input_pos": [0, 1, 0, 1]} - -create_block_causal_mask(document_ids) ->>> [ - [1, 0, 0, 0], - [1, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 1, 1], - ] -``` - -### Goal: -0) Make packing a first-class citizen in TorchTune, available for all sorts of models and recipes. - -### Context: -1) We currently have map-style packing. We pre-process the dataset before training, which is not scalable. -2) Packing is only present for SFT + text data. There is no contract for how to extend it to multimodal, DPO, etc. -3) Collate function has to be aware of packing logic. This is currently hardcoded in the recipe with if/else. 
- -### Solution: -4) Implement a new on-the-fly packing that takes any iterable dataset as input; -5) Packing contract consists of - i) a `PackingStrategy` that defines a) how to pack and b) the **_mask_mod** used for flex attention; - ii) a `IterablePackedDataset` that takes any a) `PackingStrategy`, b) **iterable dataset** as input and yields packed samples; - iii) a `packed_collate_fn` that takes the batch of packed samples and a **mask_fn** (e.g. `strategy.create_block_mask`) to generate the attention mask on the fly. - To define a new packing strategy, the user only needs to implement the `PackingStrategy` class. - -### Implementation: -6) Updated `full_finetune_distributed.py` to use `IterablePackedDataset` when packing is enabled. There are challenges related to iterable datasets and this will be tackled in a separate iterable dataset PR. Changes made were to enable it to run for this RFC. - -### Not in this PR: -7) **Logging**: Since we cannot do len(iterable_dataset), we need to add proper logging/metadata to assist users in understanding how far along they are on each dataset and metrics regarding the samples (avg num tokens, avg num samples / pack, etc.) -8) **Packing-aware Loss**: For SFT, the same loss works for map-style and packing. This is not the case for DPO/GRPO, which would need different masking. Future work will have to handle how to associate packing with a loss that supports it. -9) **Packing-aware metrics**: Advanced metrics, such as logprob per sample, would require to be aware of packing; -10) **tokenization**: For advanced packing, e.g. shared prompts in GRPO/DPO, we will need extra metadata from upstream datasets, e.g. prompt len. \ No newline at end of file diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 85c650453a..dbbeb3d54c 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -53,7 +53,7 @@ def set_dataset_name(self, dataset_name: str) -> None: @abstractmethod def create_empty_pack(self) -> dict[str, list[Any]]: """ - Creates an empty pack. + Creates an empty pack with lists that will hold tensors. Returns: dict[str, list[Any]]: An empty dictionary with lists as values. @@ -87,7 +87,7 @@ def add_sample_to_pack( self, pack: dict[str, list[Any]], sample: SampleType, next_doc_id: int ) -> int: """ - Adds a sample to the pack dictionary in-place. + Adds a sample to the pack dictionary in-place by appending tensors to lists. Args: pack (dict[str, list[Any]]): The dictionary representing the pack, to be modified in-place. @@ -98,14 +98,12 @@ def add_sample_to_pack( int: The number of new documents that were added to the pack. Example: - pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} - sample = {"tokens": [5, 6], "labels": [7, 8]} + pack = {"tokens": [tensor([1, 2])], "labels": [tensor([3, 4])], ...} + sample = {"tokens": tensor([5, 6]), "labels": tensor([7, 8])} added_docs = self.add_sample_to_pack(pack, sample, next_doc_id=1) print(pack) - >>> {"tokens": [1, 2, 5, 6], - "labels": [3, 4, 7, 8], - "document_ids": [0, 0, 1, 1], - "input_pos": [0, 1, 0, 1]} + >>> {"tokens": [tensor([1, 2]), tensor([5, 6])], + "labels": [tensor([3, 4]), tensor([7, 8])], ...} print(added_docs) >>> 1 """ @@ -116,28 +114,26 @@ def finalize_pack( self, pack: dict[str, list[Any]], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: """ - Finalizes a pack, primarily by padding it to the target length. 
+ Finalizes a pack by padding to target length and concatenating tensor lists. Args: - pack (dict[str, list[Any]]): The pack data. + pack (dict[str, list[Any]]): The pack data containing lists of tensors. target_tokens_per_pack (int): The target length to pad to. next_doc_id (int): The document ID to use for the padding tokens. Returns: - PackType: The finalized pack. + PackType: The finalized pack with concatenated tensors. Example: - pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]} + pack = {"tokens": [tensor([1, 2])], "labels": [tensor([3, 4])], ...} target_tokens_per_pack = 4 next_doc_id = 1 self.padding_idx = 999 self.ignore_idx = -100 self.finalize_pack(pack, target_tokens_per_pack, next_doc_id) - >>> {"tokens": [1, 2, 999, 999], - "labels": [3, 4, -100, -100], - "document_ids": [0, 0, 1, 1], - "input_pos": [0, 1, 0, 1]} + >>> {"tokens": tensor([1, 2, 999, 999]), + "labels": tensor([3, 4, -100, -100]), ...} """ pass @@ -196,7 +192,7 @@ class IterablePackedDataset( Args: dataset (TuneIterableDataset[SampleType]): The TuneIterableDataset to pack. - packer (Packer[SampleType]): The Packer to use for packing. + packer (Packer[SampleType]): The Packer specific to the dataset format (e.g. text, DPO, etc.). target_tokens_per_pack (int): The target number of tokens per pack. buffer_size (int): The size of the buffer to use for packing. dataset_name (str): The name of the dataset. If None, a defaults to IterablePackedDataset. @@ -410,7 +406,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._resuming = True -class TextPacker(Packer[dict[str, list[int]]]): +class TextPacker(Packer[dict[str, torch.Tensor]]): """ Packer for packing standard text samples for causal language modeling. It is designed to be used with the IterablePackedDataset. 
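Editor's note: a worked sketch of the pack layout that the tensor-based `TextPacker` below produces for two short documents, assuming `padding_idx=999` and the default ignore index of `-100`, and omitting the accompanying `metrics` list. The padding slot takes the next document id, and `input_pos` restarts from 0 for each document.

```python
# Worked sketch of a finalized pack with target_tokens_per_pack=6.
import torch

doc0 = {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}
doc1 = {"tokens": torch.tensor([5, 6, 7]), "labels": torch.tensor([8, 9, 10])}

# After add_sample_to_pack for both documents and finalize_pack(..., next_doc_id=2):
expected_pack = {
    "tokens": torch.tensor([1, 2, 5, 6, 7, 999]),      # last slot is padding_idx
    "labels": torch.tensor([3, 4, 8, 9, 10, -100]),    # padding labels use ignore_idx
    "document_ids": torch.tensor([0, 0, 1, 1, 1, 2]),  # padding gets the next doc id
    "input_pos": torch.tensor([0, 1, 0, 1, 2, 0]),     # positions restart per document
}
```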
@@ -428,6 +424,7 @@ def set_dataset_name(self, dataset_name: str) -> None: self.dataset_name = dataset_name def create_empty_pack(self) -> dict[str, list]: + """Creates an empty pack with lists that will hold tensors.""" return { "tokens": [], "labels": [], @@ -436,53 +433,74 @@ def create_empty_pack(self) -> dict[str, list]: "metrics": [], } - def get_sample_size(self, sample: dict[str, list[int]]) -> int: - return len(sample["tokens"]) + def get_sample_size(self, sample: dict[str, torch.Tensor]) -> int: + """Returns the number of tokens in the sample.""" + return sample["tokens"].numel() def add_sample_to_pack( - self, pack: dict[str, list], sample: dict[str, list[int]], next_doc_id: int + self, pack: dict[str, list], sample: dict[str, torch.Tensor], next_doc_id: int ) -> int: - seq_len = len(sample["tokens"]) - - # Append sample data to the pack - pack["tokens"].extend(sample["tokens"]) - pack["labels"].extend(sample["labels"]) - pack["document_ids"].extend([next_doc_id] * seq_len) - pack["input_pos"].extend(range(seq_len)) # input_pos restarts for each doc + """Adds a tensor sample to the pack by appending tensors to lists.""" + seq_len = sample["tokens"].numel() + + # Append tensors directly to pack lists + pack["tokens"].append(sample["tokens"]) + pack["labels"].append(sample["labels"]) + + # Generate metadata as tensors + pack["document_ids"].append( + torch.full((seq_len,), next_doc_id, dtype=torch.long, device="cpu") + ) + # input_pos restarts from 0 for each document + pack["input_pos"].append(torch.arange(seq_len, dtype=torch.long, device="cpu")) # Handle metrics if they exist in the sample if "metrics" in sample: pack["metrics"].extend(sample["metrics"]) - # Increment doc ID for the next sample + # return number of documents added return 1 def finalize_pack( self, pack: dict[str, list], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: - current_size = len(pack["tokens"]) + """Finalizes pack by padding and concatenating tensor lists efficiently.""" + # Calculate current size from tensor list + current_size = sum(t.numel() for t in pack["tokens"]) if pack["tokens"] else 0 num_padding = target_tokens_per_pack - current_size + # Add padding tensors if needed if num_padding > 0: - pack["tokens"].extend([self.padding_idx] * num_padding) - pack["labels"].extend([self.ignore_idx] * num_padding) - pack["input_pos"].extend([0] * num_padding) - pack["document_ids"].extend([next_doc_id] * num_padding) - - # Add pct_of_tokens_padded metric - padding_metric = Metric( - dataset_name=self.dataset_name, - name="pct_of_tokens_padded", - value=round(num_padding * 100 / len(pack["tokens"]), 2), - agg_type=AggregationType.MEAN, - ) - pack["metrics"].append(padding_metric) + pack["tokens"].append( + torch.full((num_padding,), self.padding_idx, dtype=torch.long) + ) + pack["labels"].append( + torch.full((num_padding,), self.ignore_idx, dtype=torch.long) + ) + pack["document_ids"].append( + torch.full((num_padding,), next_doc_id, dtype=torch.long) + ) + pack["input_pos"].append( + torch.zeros(num_padding, dtype=torch.long) + ) + + # Add padding percentage metric + if target_tokens_per_pack > 0: + padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) + padding_metric = Metric( + dataset_name=self.dataset_name, + name="pct_of_tokens_padded", + value=padding_pct, + agg_type=AggregationType.MEAN, + ) + pack["metrics"].append(padding_metric) + # Concatenate all tensor lists efficiently result = { - "tokens": torch.tensor(pack["tokens"], dtype=torch.long), - "labels": 
torch.tensor(pack["labels"], dtype=torch.long), - "document_ids": torch.tensor(pack["document_ids"], dtype=torch.long), - "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), + "tokens": torch.cat(pack["tokens"]) if pack["tokens"] else torch.empty(0, dtype=torch.long), + "labels": torch.cat(pack["labels"]) if pack["labels"] else torch.empty(0, dtype=torch.long), + "document_ids": torch.cat(pack["document_ids"]) if pack["document_ids"] else torch.empty(0, dtype=torch.long), + "input_pos": torch.cat(pack["input_pos"]) if pack["input_pos"] else torch.empty(0, dtype=torch.long), "metrics": pack["metrics"], } @@ -508,7 +526,7 @@ def _mask_mod( # NOTE: For demonstration purposes only. -class DPOPacker(Packer[dict[str, list[int]]]): +class DPOPacker(Packer[dict[str, torch.Tensor]]): """ Packer for packing DPO samples with a shared prompt. It packs a DPO sample as three logical documents: a shared prompt, a chosen response, @@ -516,7 +534,7 @@ class DPOPacker(Packer[dict[str, list[int]]]): metadata, allowing the packer to build the correct attention pattern (e.g., both responses can attend to the prompt, but not to each other). - ASSUMPTION: The input DPO sample dict contains pre-tokenized: + ASSUMPTION: The input DPO sample dict contains pre-tokenized tensors: - "prompt_ids" - "chosen_response_only_ids" - "chosen_response_only_labels" @@ -533,6 +551,7 @@ def set_dataset_name(self, dataset_name: str) -> None: self.dataset_name = dataset_name def create_empty_pack(self) -> dict[str, list]: + """Creates an empty pack with lists that will hold tensors.""" return { "tokens": [], "labels": [], @@ -543,95 +562,142 @@ def create_empty_pack(self) -> dict[str, list]: "metrics": [], } - def get_sample_size(self, sample: dict[str, list[int]]) -> int: - # The total size of one DPO sample is the shared prompt + both responses. + def get_sample_size(self, sample: dict[str, torch.Tensor]) -> int: + """Returns total size of DPO sample: prompt + both responses.""" return ( - len(sample["prompt_ids"]) - + len(sample["chosen_response_only_ids"]) - + len(sample["rejected_response_only_ids"]) + sample["prompt_ids"].numel() + + sample["chosen_response_only_ids"].numel() + + sample["rejected_response_only_ids"].numel() ) def add_sample_to_pack( - self, pack: dict[str, list], sample: dict[str, list[int]], next_doc_id: int + self, pack: dict[str, list], sample: dict[str, torch.Tensor], next_doc_id: int ) -> int: - # Assign a unique doc ID triplet for (prompt, chosen, rejected) - prompt_doc_id = next_doc_id - chosen_doc_id = next_doc_id + 1 - rejected_doc_id = next_doc_id + 2 - + """Adds a DPO sample by building complete tensors and appending to pack lists.""" prompt_ids = sample["prompt_ids"] chosen_ids = sample["chosen_response_only_ids"] rejected_ids = sample["rejected_response_only_ids"] - # Input positions restart from 0 for each new DPO sample in the pack - total_len = len(prompt_ids) + len(chosen_ids) + len(rejected_ids) - pack["input_pos"].extend(range(total_len)) - - # 1. Add Shared Prompt data - pack["tokens"].extend(prompt_ids) - pack["labels"].extend([self.ignore_idx] * len(prompt_ids)) - pack["document_ids"].extend([prompt_doc_id] * len(prompt_ids)) - pack["chosen_response_mask"].extend([False] * len(prompt_ids)) - pack["rejected_response_mask"].extend([False] * len(prompt_ids)) - - # 2. 
Add Chosen Response data - pack["tokens"].extend(chosen_ids) - pack["labels"].extend(sample["chosen_response_only_labels"]) - pack["document_ids"].extend([chosen_doc_id] * len(chosen_ids)) - pack["chosen_response_mask"].extend([True] * len(chosen_ids)) - pack["rejected_response_mask"].extend([False] * len(chosen_ids)) - - # 3. Add Rejected Response data - pack["tokens"].extend(rejected_ids) - pack["labels"].extend(sample["rejected_response_only_labels"]) - pack["document_ids"].extend([rejected_doc_id] * len(rejected_ids)) - pack["chosen_response_mask"].extend([False] * len(rejected_ids)) - pack["rejected_response_mask"].extend([True] * len(rejected_ids)) + prompt_len = prompt_ids.numel() + chosen_len = chosen_ids.numel() + rejected_len = rejected_ids.numel() + + # 1. Add Shared Prompt, Chosen Response, and Rejected Response tokens + complete_tokens = torch.cat([prompt_ids, chosen_ids, rejected_ids]) + # 2. Add labels for all three parts + complete_labels = torch.cat( + [ + torch.full( + (prompt_len,), self.ignore_idx, dtype=torch.long), + sample["chosen_response_only_labels"], + sample["rejected_response_only_labels"], + ] + ) + + # 3. Create document IDs: prompt(next_doc_id), chosen(next_doc_id+1), rejected(next_doc_id+2) + complete_doc_ids = torch.cat( + [ + torch.full( + (prompt_len,), next_doc_id, dtype=torch.long), + torch.full( + (chosen_len,), next_doc_id + 1, dtype=torch.long), + torch.full( + (rejected_len,), next_doc_id + 2, dtype=torch.long), + ] + ) + + # 4. Create input positions (restarts from 0 for each DPO sample) + total_len = complete_tokens.numel() + complete_input_pos = torch.arange(total_len, dtype=torch.long, device="cpu") + + # 5. Create response masks + complete_chosen_mask = torch.cat( + [ + torch.zeros(prompt_len, dtype=torch.bool, device="cpu"), + torch.ones(chosen_len, dtype=torch.bool, device="cpu"), + torch.zeros(rejected_len, dtype=torch.bool, device="cpu"), + ] + ) + complete_rejected_mask = torch.cat( + [ + torch.zeros(prompt_len, dtype=torch.bool, device="cpu"), + torch.zeros(chosen_len, dtype=torch.bool, device="cpu"), + torch.ones(rejected_len, dtype=torch.bool, device="cpu"), + ] + ) + + # Append all complete tensors to the pack + pack["tokens"].append(complete_tokens) + pack["labels"].append(complete_labels) + pack["document_ids"].append(complete_doc_ids) + pack["input_pos"].append(complete_input_pos) + pack["chosen_response_mask"].append(complete_chosen_mask) + pack["rejected_response_mask"].append(complete_rejected_mask) # Handle metrics if they exist in the sample if "metrics" in sample: pack["metrics"].extend(sample["metrics"]) - # Advance the document ID counter by 3 for the next DPO sample. 
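A minimal usage sketch of the behavior this hunk documents: one DPO sample contributes three logical documents, so the caller advances the document ID counter by the returned value (this mirrors the unit tests added later in this series):

```python
import torch
from torchtune.datasets._iterable_packed import DPOPacker

packer = DPOPacker(padding_idx=999)
pack = packer.create_empty_pack()
sample = {
    "prompt_ids": torch.tensor([1, 2]),
    "chosen_response_only_ids": torch.tensor([3, 4]),
    "chosen_response_only_labels": torch.tensor([3, 4]),
    "rejected_response_only_ids": torch.tensor([5, 6]),
    "rejected_response_only_labels": torch.tensor([5, 6]),
}
# prompt -> doc 0, chosen -> doc 1, rejected -> doc 2
added_docs = packer.add_sample_to_pack(pack, sample, next_doc_id=0)
assert added_docs == 3
```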
+ # Each DPO sample consists of 3 documents (prompt, chosen, rejected) return 3 def finalize_pack( self, pack: dict[str, list], target_tokens_per_pack: int, next_doc_id: int ) -> PackType: - current_size = len(pack["tokens"]) + """Finalizes pack by padding and concatenating tensor lists efficiently.""" + # Calculate current size from tensor list + current_size = sum(t.numel() for t in pack["tokens"]) if pack["tokens"] else 0 num_padding = target_tokens_per_pack - current_size + # Add padding tensors if needed if num_padding > 0: - pack["tokens"].extend([self.padding_idx] * num_padding) - pack["labels"].extend([self.ignore_idx] * num_padding) - pack["input_pos"].extend([0] * num_padding) - pack["chosen_response_mask"].extend([False] * num_padding) - pack["rejected_response_mask"].extend([False] * num_padding) - pack["document_ids"].extend([next_doc_id] * num_padding) - - # Add pct_of_tokens_padded metric - padding_metric = Metric( - dataset_name=self.dataset_name, - name="pct_of_tokens_padded", - value=round(num_padding * 100 / len(pack["tokens"]), 2), - agg_type=AggregationType.MEAN, - ) - pack["metrics"].append(padding_metric) + pack["tokens"].append( + torch.full( + (num_padding,), self.padding_idx, dtype=torch.long, device="cpu" + ) + ) + pack["labels"].append( + torch.full( + (num_padding,), self.ignore_idx, dtype=torch.long, device="cpu" + ) + ) + pack["document_ids"].append( + torch.full((num_padding,), next_doc_id, dtype=torch.long, device="cpu") + ) + pack["input_pos"].append( + torch.zeros(num_padding, dtype=torch.long, device="cpu") + ) + pack["chosen_response_mask"].append( + torch.zeros(num_padding, dtype=torch.bool, device="cpu") + ) + pack["rejected_response_mask"].append( + torch.zeros(num_padding, dtype=torch.bool, device="cpu") + ) - return { - "tokens": torch.tensor(pack["tokens"], dtype=torch.long), - "labels": torch.tensor(pack["labels"], dtype=torch.long), - "document_ids": torch.tensor(pack["document_ids"], dtype=torch.long), - "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), - "chosen_response_mask": torch.tensor( - pack["chosen_response_mask"], dtype=torch.bool - ), - "rejected_response_mask": torch.tensor( - pack["rejected_response_mask"], dtype=torch.bool - ), + # Add padding percentage metric + if target_tokens_per_pack > 0: + padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) + padding_metric = Metric( + dataset_name=self.dataset_name, + name="pct_of_tokens_padded", + value=padding_pct, + agg_type=AggregationType.MEAN, + ) + pack["metrics"].append(padding_metric) + + # Concatenate all tensor lists + result = { + "tokens": torch.cat(pack["tokens"]) if pack["tokens"] else torch.empty(0, dtype=torch.long), + "labels": torch.cat(pack["labels"]) if pack["labels"] else torch.empty(0, dtype=torch.long), + "document_ids": torch.cat(pack["document_ids"]) if pack["document_ids"] else torch.empty(0, dtype=torch.long), + "input_pos": torch.cat(pack["input_pos"]) if pack["input_pos"] else torch.empty(0, dtype=torch.long), + "chosen_response_mask": torch.cat(pack["chosen_response_mask"]) if pack["chosen_response_mask"] else torch.empty(0, dtype=torch.bool), + "rejected_response_mask": torch.cat(pack["rejected_response_mask"]) if pack["rejected_response_mask"] else torch.empty(0, dtype=torch.bool), "metrics": pack["metrics"], } + return result + def _mask_mod( self, b: int, From c8bfbb2dd094173206ab2a89bada28a5a52e78af Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 14:15:12 -0400 Subject: [PATCH 43/48] tests --- 
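A tiny standalone illustration of the pad-then-concatenate pattern that `finalize_pack` uses above, assuming `padding_idx=999` and a target pack length of 8:

```python
import torch

tokens = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]  # samples already in the pack
target_tokens_per_pack = 8
current_size = sum(t.numel() for t in tokens)
num_padding = target_tokens_per_pack - current_size
if num_padding > 0:
    tokens.append(torch.full((num_padding,), 999, dtype=torch.long))
packed_tokens = torch.cat(tokens)  # tensor([1, 2, 3, 4, 5, 999, 999, 999])
pct_padded = round(num_padding * 100 / target_tokens_per_pack, 2)  # 37.5
```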
.../datasets/test_iterable_packed_dataset.py | 713 ++++++++++++++++++ 1 file changed, 713 insertions(+) create mode 100644 tests/torchtune/datasets/test_iterable_packed_dataset.py diff --git a/tests/torchtune/datasets/test_iterable_packed_dataset.py b/tests/torchtune/datasets/test_iterable_packed_dataset.py new file mode 100644 index 0000000000..a36aa639aa --- /dev/null +++ b/tests/torchtune/datasets/test_iterable_packed_dataset.py @@ -0,0 +1,713 @@ +import logging +from typing import Any, Dict, Iterator, List, Optional + +import pytest +import torch +from torch.utils.data import IterableDataset +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION + +# Import the classes we're testing +if _SUPPORTS_FLEX_ATTENTION: + from torchtune.datasets._iterable_packed import ( + IterablePackedDataset, + TextPacker, + DPOPacker, + Packer, + ) + from torchdata.stateful_dataloader import Stateful + from torchtune.data._collate import collate_packed + +# --- Test Fixtures --- + + +@pytest.fixture +def device(): + return "cuda" + + +class DummyTextDataset(IterableDataset): + """Dummy dataset that returns tensor-based samples.""" + def __init__(self, sample_sizes): + self._sample_sizes = sample_sizes + self._counter = 0 + + def __iter__(self): + # Reset counter for each new iteration + self._counter = 0 + for size in self._sample_sizes: + yield { + "tokens": torch.full((size,), self._counter, dtype=torch.long), + "labels": torch.full((size,), self._counter, dtype=torch.long), + } + self._counter += 1 + + +class StatefulDummyTextDataset(IterableDataset, Stateful): + """ + A dummy text dataset that is also stateful, allowing its iteration + progress to be saved and loaded. Returns tensor-based samples. + """ + + def __init__(self, sample_sizes: List[int]): + self.sample_sizes = sample_sizes + self._state_to_load: Optional[Dict[str, Any]] = None + # The state is the index of the *next* sample to be processed. + self._active_iterator_state: Dict[str, Any] = {"sample_idx": 0} + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + # This base generator yields all samples from the beginning. + def _base_iterator(): + for i, size in enumerate(self.sample_sizes): + self._active_iterator_state = {"sample_idx": i} + yield { + "tokens": torch.full((size,), i, dtype=torch.long), # Use sample index as token value + "labels": torch.full((size,), i, dtype=torch.long), + } + # After iterating, the next sample index is out of bounds + self._active_iterator_state = {"sample_idx": len(self.sample_sizes)} + + iterator = _base_iterator() + + # If resuming, fast-forward the iterator to the correct position. + if self._state_to_load: + start_idx = self._state_to_load.get("sample_idx", 0) + logging.info( + f"StatefulDummyTextDataset.__iter__(): Resuming. Fast-forwarding iterator to index {start_idx}." + ) + self._state_to_load = None + # Fast-forward the iterator to the sample index from the checkpoint. + for _ in range(start_idx): + next(iterator, None) # Consume and discard samples until the desired start point. 
+ + yield from iterator + + def state_dict(self) -> Dict[str, Any]: + logging.info( + f"StatefulDummyTextDataset.state_dict(): current state is {self._active_iterator_state}" + ) + return self._active_iterator_state + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + logging.info( + f"StatefulDummyTextDataset.load_state_dict(): state to load is {state_dict}" + ) + self._state_to_load = state_dict + + +class DummyDPODataset(IterableDataset): + """Dummy DPO dataset that returns tensor-based samples.""" + def __init__(self, samples): + self._samples = samples + + def __iter__(self): + yield from self._samples + + +# --- Test Utilities --- + + +def create_dense_mask_from_mask_mod( + strategy: Packer, doc_ids: torch.Tensor +) -> torch.Tensor: + """ + Helper utility to generate a dense boolean attention mask from a + strategy's _mask_mod implementation for testing purposes. + """ + batch_size, seq_len = doc_ids.shape + device = doc_ids.device + dense_mask = torch.zeros( + batch_size, seq_len, seq_len, dtype=torch.bool, device=device + ) + for b in range(batch_size): + for q_idx in range(seq_len): + q_tensor = torch.tensor(q_idx, device=device) + for kv_idx in range(seq_len): + kv_tensor = torch.tensor(kv_idx, device=device) + # h (head index) is not used in current implementations + dense_mask[b, q_idx, kv_idx] = strategy._mask_mod( + b, 0, q_tensor, kv_tensor, doc_ids + ) + return dense_mask + +@pytest.fixture +def dpo_packer(): + return DPOPacker(padding_idx=999, ignore_idx=-100) + +def assert_pack_structure(pack, packer_type, target_tokens, expected_docs): + """ + Helper to validate pack structure consistently across tests. + + Args: + pack: The pack dictionary to validate + packer_type: Type of packer ("text" or "dpo") + target_tokens: Expected sequence length + expected_docs: Expected number of unique document IDs + """ + assert pack["tokens"].shape == (target_tokens,) + assert pack["labels"].shape == (target_tokens,) + assert pack["document_ids"].shape == (target_tokens,) + + if packer_type == "dpo": + assert "chosen_response_mask" in pack + assert "rejected_response_mask" in pack + # Verify that a token cannot be part of both the chosen and rejected response. + assert not (pack["chosen_response_mask"] & pack["rejected_response_mask"]).any() + + # Verify document boundaries + assert torch.unique(pack["document_ids"]).numel() == expected_docs + +def assert_attention_mask_properties(mask, doc_ids): + """ + Helper to validate attention mask properties for packed sequences. + + Verifies causal attention within documents and proper masking boundaries. + + Args: + mask: Attention mask tensor of shape (batch_size, seq_len, seq_len) + doc_ids: Document ID tensor of shape (batch_size, seq_len) + """ + batch_size, seq_len, _ = mask.shape + + # Verify causal property within documents + for b in range(batch_size): + for doc_id in torch.unique(doc_ids[b]): + doc_indices = (doc_ids[b] == doc_id).nonzero(as_tuple=True)[0] + if not doc_indices.numel(): + continue + # The mask for tokens within a document should be lower-triangular (causal). + doc_mask = mask[b][doc_indices, :][:, doc_indices] + is_causal = torch.all(doc_mask.tril() == doc_mask) + assert is_causal, f"Mask for doc {doc_id} in batch {b} is not causal." 
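For reference, the block-causal predicate that `create_dense_mask_from_mask_mod` evaluates pairwise can be sketched standalone (this is the rule the text packer's `_mask_mod` encodes; `doc_ids` has shape `(batch, seq_len)`):

```python
import torch


def block_causal_mask_mod(b, h, q_idx, kv_idx, doc_ids):
    # Attend only within the same document, and only causally.
    same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
    return same_doc & (q_idx >= kv_idx)


doc_ids = torch.tensor([[0, 0, 1, 1]])
# Token 3 (doc 1) may attend to token 2 (doc 1) but not to token 1 (doc 0).
assert block_causal_mask_mod(0, 0, torch.tensor(3), torch.tensor(2), doc_ids)
assert not block_causal_mask_mod(0, 0, torch.tensor(3), torch.tensor(1), doc_ids)
```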
+ + +# --- Test Classes --- + +@pytest.fixture +def text_packer(): + return TextPacker(padding_idx=999, ignore_idx=-100) + +class TestTextPacker: + """Test TextPacker methods, attention masks, and integration workflow""" + + def test_create_empty_pack(self, text_packer): + """Test empty pack creation for TextPacker""" + pack = text_packer.create_empty_pack() + expected = { + "tokens": [], + "labels": [], + "document_ids": [], + "input_pos": [], + "metrics": [], + } + assert pack == expected + + def test_get_sample_size(self, text_packer): + """Test sample size calculation for multiple TextPacker samples""" + samples = [ + {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])}, + {"tokens": torch.tensor([7]), "labels": torch.tensor([8])}, + {"tokens": torch.tensor([9, 10, 11, 12]), "labels": torch.tensor([13, 14, 15, 16])}, + ] + + expected_sizes = [3, 1, 4] + for sample, expected_size in zip(samples, expected_sizes): + assert text_packer.get_sample_size(sample) == expected_size + + def test_add_multiple_samples_to_pack(self, text_packer): + """Test adding multiple samples to same pack""" + pack = text_packer.create_empty_pack() + + samples = [ + {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}, + {"tokens": torch.tensor([5, 6, 7]), "labels": torch.tensor([8, 9, 10])}, + {"tokens": torch.tensor([11]), "labels": torch.tensor([12])}, + ] + + # Add all samples + for i, sample in enumerate(samples): + docs_consumed = text_packer.add_sample_to_pack(pack, sample, next_doc_id=i) + assert docs_consumed == 1 + + # Verify pack contents + assert len(pack["tokens"]) == 3 + torch.testing.assert_close(pack["tokens"][0], torch.tensor([1, 2])) + torch.testing.assert_close(pack["tokens"][1], torch.tensor([5, 6, 7])) + torch.testing.assert_close(pack["tokens"][2], torch.tensor([11])) + torch.testing.assert_close(pack["document_ids"][0], torch.tensor([0, 0])) + torch.testing.assert_close(pack["document_ids"][1], torch.tensor([1, 1, 1])) + torch.testing.assert_close(pack["document_ids"][2], torch.tensor([2])) + + def test_finalize_pack_multiple_samples(self, text_packer): + """Test pack finalization with multiple samples and padding""" + pack = { + "tokens": [torch.tensor([1, 2]), torch.tensor([3, 4, 5])], + "labels": [torch.tensor([6, 7]), torch.tensor([8, 9, 10])], + "document_ids": [torch.tensor([0, 0]), torch.tensor([1, 1, 1])], + "input_pos": [torch.tensor([0, 1]), torch.tensor([0, 1, 2])], + "metrics": [], + } + + result = text_packer.finalize_pack(pack, target_tokens_per_pack=8, next_doc_id=2) + + expected_tokens = torch.tensor([1, 2, 3, 4, 5, 999, 999, 999]) + expected_labels = torch.tensor([6, 7, 8, 9, 10, -100, -100, -100]) + expected_doc_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2]) + expected_input_pos = torch.tensor([0, 1, 0, 1, 2, 0, 0, 0]) + + torch.testing.assert_close(result["tokens"], expected_tokens) + torch.testing.assert_close(result["labels"], expected_labels) + torch.testing.assert_close(result["document_ids"], expected_doc_ids) + torch.testing.assert_close(result["input_pos"], expected_input_pos) + + def test_text_causal_mask(self, device): + """Test standard causal masking for TextPacker""" + text_packer = TextPacker(padding_idx=0) + # One sample in batch: doc 0 (len 2), doc 1 (len 2), doc 2 (len 3) + doc_ids = torch.tensor([[0, 0, 1, 1, 2, 2, 2]], device=device) + + expected_mask = torch.tensor([ + # q k-> 0 1 2 3 4 5 6 + [1, 0, 0, 0, 0, 0, 0], # 0 + [1, 1, 0, 0, 0, 0, 0], # 1 + [0, 0, 1, 0, 0, 0, 0], # 2 + [0, 0, 1, 1, 0, 0, 0], # 3 + [0, 0, 0, 0, 1, 0, 
0], # 4 + [0, 0, 0, 0, 1, 1, 0], # 5 + [0, 0, 0, 0, 1, 1, 1], # 6 + ], dtype=torch.bool, device=device).unsqueeze(0) + + actual_mask = create_dense_mask_from_mask_mod(text_packer, doc_ids) + torch.testing.assert_close(actual_mask, expected_mask) + + def test_text_packing_workflow_two_packs(self): + """Test complete text workflow that creates exactly 2 packs with multiple samples""" + # Design: Pack1=[3,2], Pack2=[4] to create 2 packs + sample_sizes = [3, 2, 4] + target_tokens = 6 + + dataset = DummyTextDataset(sample_sizes) + text_packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, packer=text_packer, target_tokens_per_pack=target_tokens + ) + + packs = list(packed_dataset) + assert len(packs) == 2 + + # Pack 1: samples 0(size 3) + 1(size 2) + padding(1) + pack1 = packs[0] + assert pack1["tokens"].shape == (target_tokens,) + assert (pack1["labels"] != -100).sum() == 5 # 3 + 2 real tokens + expected_tokens_1 = torch.tensor([0, 0, 0, 1, 1, 999]) + expected_doc_ids_1 = torch.tensor([0, 0, 0, 1, 1, 2]) + expected_input_pos_1 = torch.tensor([0, 1, 2, 0, 1, 0]) + torch.testing.assert_close(pack1["tokens"], expected_tokens_1) + torch.testing.assert_close(pack1["document_ids"], expected_doc_ids_1) + torch.testing.assert_close(pack1["input_pos"], expected_input_pos_1) + + # Pack 2: sample 2(size 4) + padding(2) - single sample pack + pack2 = packs[1] + assert (pack2["labels"] != -100).sum() == 4 # 4 real tokens + expected_tokens_2 = torch.tensor([2, 2, 2, 2, 999, 999]) + expected_doc_ids_2 = torch.tensor([0, 0, 0, 0, 1, 1]) + expected_input_pos_2 = torch.tensor([0, 1, 2, 3, 0, 0]) + torch.testing.assert_close(pack2["tokens"], expected_tokens_2) + torch.testing.assert_close(pack2["document_ids"], expected_doc_ids_2) + torch.testing.assert_close(pack2["input_pos"], expected_input_pos_2) + +class TestDPOPacker: + """Test DPOPacker methods, attention masks, and integration workflow""" + + def test_create_empty_pack(self, dpo_packer): + """Test empty pack creation for DPOPacker""" + pack = dpo_packer.create_empty_pack() + expected = { + "tokens": [], + "labels": [], + "document_ids": [], + "input_pos": [], + "chosen_response_mask": [], + "rejected_response_mask": [], + "metrics": [], + } + assert pack == expected + + def test_get_sample_size(self, dpo_packer): + """Test sample size calculation for multiple DPOPacker samples""" + samples = [ + { + "prompt_ids": torch.tensor([1, 2]), + "chosen_response_only_ids": torch.tensor([3, 4]), + "chosen_response_only_labels": torch.tensor([3, 4]), + "rejected_response_only_ids": torch.tensor([5, 6]), + "rejected_response_only_labels": torch.tensor([5, 6]), + }, + { + "prompt_ids": torch.tensor([7, 8, 9]), + "chosen_response_only_ids": torch.tensor([10, 11]), + "chosen_response_only_labels": torch.tensor([10, 11]), + "rejected_response_only_ids": torch.tensor([12, 13, 14, 15]), + "rejected_response_only_labels": torch.tensor([12, 13, 14, 15]), + }, + { + "prompt_ids": torch.tensor([16]), + "chosen_response_only_ids": torch.tensor([17, 18, 19]), + "chosen_response_only_labels": torch.tensor([17, 18, 19]), + "rejected_response_only_ids": torch.tensor([20, 21]), + "rejected_response_only_labels": torch.tensor([20, 21]), + } + ] + + expected_sizes = [6, 9, 6] # [2+2+2, 3+2+4, 1+3+2] + for sample, expected_size in zip(samples, expected_sizes): + assert dpo_packer.get_sample_size(sample) == expected_size + + def test_add_multiple_samples_to_pack(self, dpo_packer): + """Test adding multiple DPO samples to 
pack""" + pack = dpo_packer.create_empty_pack() + samples = [ + { + "prompt_ids": torch.tensor([1, 2]), + "chosen_response_only_ids": torch.tensor([3, 4]), + "chosen_response_only_labels": torch.tensor([3, 4]), + "rejected_response_only_ids": torch.tensor([5, 6]), + "rejected_response_only_labels": torch.tensor([5, 6]), + }, + { + "prompt_ids": torch.tensor([7, 8]), + "chosen_response_only_ids": torch.tensor([9]), + "chosen_response_only_labels": torch.tensor([9]), + "rejected_response_only_ids": torch.tensor([10, 11]), + "rejected_response_only_labels": torch.tensor([10, 11]), + } + ] + + # Add all samples + for i, sample in enumerate(samples): + docs_consumed = dpo_packer.add_sample_to_pack(pack, sample, next_doc_id=i*3) + assert docs_consumed == 3 # prompt + chosen + rejected + + # Verify pack contents + assert len(pack["tokens"]) == 2 + # First sample: [1,2,3,4,5,6] + torch.testing.assert_close(pack["tokens"][0], torch.tensor([1, 2, 3, 4, 5, 6])) + torch.testing.assert_close( + pack["labels"][0], torch.tensor([-100, -100, 3, 4, 5, 6]) + ) + torch.testing.assert_close( + pack["document_ids"][0], torch.tensor([0, 0, 1, 1, 2, 2]) + ) + torch.testing.assert_close( + pack["chosen_response_mask"][0], + torch.tensor([False, False, True, True, False, False]) + ) + torch.testing.assert_close( + pack["rejected_response_mask"][0], + torch.tensor([False, False, False, False, True, True]) + ) + + # Second sample: [7,8,9,10,11] + torch.testing.assert_close(pack["tokens"][1], torch.tensor([7, 8, 9, 10, 11])) + torch.testing.assert_close( + pack["labels"][1], torch.tensor([-100, -100, 9, 10, 11]) + ) + torch.testing.assert_close( + pack["document_ids"][1], torch.tensor([3, 3, 4, 5, 5]) + ) + torch.testing.assert_close( + pack["chosen_response_mask"][1], + torch.tensor([False, False, True, False, False]) + ) + torch.testing.assert_close( + pack["rejected_response_mask"][1], + torch.tensor([False, False, False, True, True]) + ) + + def test_finalize_pack_multiple_dpo_samples(self, dpo_packer): + """Test DPO pack finalization with multiple samples and padding.""" + pack = dpo_packer.create_empty_pack() + + sample1 = { + "prompt_ids": torch.tensor([1, 2]), + "chosen_response_only_ids": torch.tensor([3, 4]), + "chosen_response_only_labels": torch.tensor([3, 4]), + "rejected_response_only_ids": torch.tensor([5, 6]), + "rejected_response_only_labels": torch.tensor([5, 6]), + } + dpo_packer.add_sample_to_pack(pack, sample1, next_doc_id=0) # docs 0, 1, 2 + + sample2 = { + "prompt_ids": torch.tensor([7]), + "chosen_response_only_ids": torch.tensor([8]), + "chosen_response_only_labels": torch.tensor([8]), + "rejected_response_only_ids": torch.tensor([9, 10]), + "rejected_response_only_labels": torch.tensor([9, 10]), + } + dpo_packer.add_sample_to_pack(pack, sample2, next_doc_id=3) # docs 3, 4, 5 + + # Total tokens = 6 (sample1) + 4 (sample2) = 10 + result = dpo_packer.finalize_pack(pack, target_tokens_per_pack=12, next_doc_id=6) + + expected_tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 999, 999]) + expected_labels = torch.tensor([-100, -100, 3, 4, 5, 6, -100, 8, 9, 10, -100, -100]) + expected_doc_ids = torch.tensor([0, 0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6]) + expected_chosen_mask = torch.tensor([False, False, True, True, False, False, False, True, False, False, False, False]) + expected_rejected_mask = torch.tensor([False, False, False, False, True, True, False, False, True, True, False, False]) + + torch.testing.assert_close(result["tokens"], expected_tokens) + torch.testing.assert_close(result["labels"], 
expected_labels) + torch.testing.assert_close(result["document_ids"], expected_doc_ids) + torch.testing.assert_close(result["chosen_response_mask"], expected_chosen_mask) + torch.testing.assert_close(result["rejected_response_mask"], expected_rejected_mask) + + def test_dpo_specialized_mask(self, device): + """ + Verify the correctness of the DPO attention mask by manually constructing + the expected mask for a batch containing multiple, different samples. + """ + dpo_packer = DPOPacker(padding_idx=0) + + # Batch contains two packs of different token layouts. + # Pack 1: Two DPO samples. P(1), C(2), R(1) | P(1), C(1), R(1) + # Doc IDs: P_A=0, C_A=1, R_A=2 | P_B=3, C_B=4, R_B=5 + doc_ids_1 = torch.tensor([0, 1, 1, 2, 3, 4, 5], device=device) + + # Pack 2: One DPO sample, then padding. P(2), C(2), R(1), Pad(2) + # Doc IDs: P_A=0, C_A=1, R_A=2 | Padding=3 + doc_ids_2 = torch.tensor([0, 0, 1, 1, 2, 3, 3], device=device) + batch_doc_ids = torch.stack([doc_ids_1, doc_ids_2]) + + # --- Manually create the expected mask for Pack 1 --- + mask1 = torch.tensor([ + # k_idx -> P C C R P C R (k_idx) + [1, 0, 0, 0, 0, 0, 0], # q=0 (P_A) can see self + [1, 1, 0, 0, 0, 0, 0], # q=1 (C_A) can see P_A and self (causal) + [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) can see P_A and C_A (causal) + [1, 0, 0, 1, 0, 0, 0], # q=3 (R_A) can see P_A and self + [0, 0, 0, 0, 1, 0, 0], # q=4 (P_B) can see self + [0, 0, 0, 0, 1, 1, 0], # q=5 (C_B) can see P_B and self + [0, 0, 0, 0, 1, 0, 1], # q=6 (R_B) can see P_B and self + ], dtype=torch.bool, device=device) + + # --- Manually create the expected mask for Pack 2 --- + mask2 = torch.tensor([ + # q_idx, P P C C R Pad Pad(k_idx) + [1, 0, 0, 0, 0, 0, 0], # q=0 (P_A) + [1, 1, 0, 0, 0, 0, 0], # q=1 (P_A) + [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) + [1, 1, 1, 1, 0, 0, 0], # q=3 (C_A) + [1, 1, 0, 0, 1, 0, 0], # q=4 (R_A) + [0, 0, 0, 0, 0, 1, 0], # q=5 (Pad) + [0, 0, 0, 0, 0, 1, 1], # q=6 (Pad) + ], dtype=torch.bool, device=device) + + expected_mask = torch.stack([mask1, mask2]) + + actual_mask = create_dense_mask_from_mask_mod(dpo_packer, batch_doc_ids) + torch.testing.assert_close(actual_mask, expected_mask) + + def test_dpo_packing_workflow_two_packs(self): + """Test complete DPO workflow that creates exactly 2 packs with multiple samples""" + samples = [ + { # Sample 0: total 4 tokens (1+1+2) + "prompt_ids": torch.tensor([1]), + "chosen_response_only_ids": torch.tensor([2]), + "chosen_response_only_labels": torch.tensor([2]), + "rejected_response_only_ids": torch.tensor([3, 4]), + "rejected_response_only_labels": torch.tensor([3, 4]), + }, + { # Sample 1: total 5 tokens (2+1+2) + "prompt_ids": torch.tensor([5, 6]), + "chosen_response_only_ids": torch.tensor([7]), + "chosen_response_only_labels": torch.tensor([7]), + "rejected_response_only_ids": torch.tensor([8, 9]), + "rejected_response_only_labels": torch.tensor([8, 9]), + }, + { # Sample 2: total 6 tokens (2+2+2) + "prompt_ids": torch.tensor([10, 11]), + "chosen_response_only_ids": torch.tensor([12, 13]), + "chosen_response_only_labels": torch.tensor([12, 13]), + "rejected_response_only_ids": torch.tensor([14, 15]), + "rejected_response_only_labels": torch.tensor([14, 15]), + } + ] + + dataset = DummyDPODataset(samples) + dpo_packer = DPOPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, packer=dpo_packer, target_tokens_per_pack=10 + ) + + packs = list(packed_dataset) + assert len(packs) == 2 # Pack1: samples 0+1 (4+5=9), Pack2: sample 2 (6) + + # Pack 1: samples 0+1 (9 tokens) + 
padding (1) + pack1 = packs[0] + assert pack1["tokens"].shape == (10,) + assert "chosen_response_mask" in pack1 + assert "rejected_response_mask" in pack1 + non_padding_1 = (pack1["tokens"] != 999).sum() + assert non_padding_1 == 9 + + # Pack 2: sample 2 (6 tokens) + padding (4) + pack2 = packs[1] + non_padding_2 = (pack2["tokens"] != 999).sum() + assert non_padding_2 == 6 + + # Verify masks are mutually exclusive + chosen_and_rejected_1 = pack1["chosen_response_mask"] & pack1["rejected_response_mask"] + chosen_and_rejected_2 = pack2["chosen_response_mask"] & pack2["rejected_response_mask"] + assert not chosen_and_rejected_1.any() + assert not chosen_and_rejected_2.any() + +class TestIterablePackedDataset: + """Test IterablePackedDataset functionality - buffer efficiency, checkpointing, edge cases""" + + def test_buffer_efficiency(self): + """Test buffer improves packing efficiency""" + # Test case where buffer helps vs hurts - order matters for first-fit + sample_sizes = [3, 4, 1, 2] # Total 10 tokens + target_tokens = 6 + + # With large buffer: can see all samples and pick best fit [3,1,2], [4] + dataset1 = DummyTextDataset(sample_sizes) + packer1 = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset1 = IterablePackedDataset( + dataset=dataset1, packer=packer1, + target_tokens_per_pack=target_tokens, buffer_size=10 + ) + packs_buffered = list(packed_dataset1) + + # With small buffer: greedy first-fit [3], [4,1], [2] + dataset2 = DummyTextDataset(sample_sizes) + packer2 = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset2 = IterablePackedDataset( + dataset=dataset2, packer=packer2, + target_tokens_per_pack=target_tokens, buffer_size=1 + ) + packs_unbuffered = list(packed_dataset2) + + # Buffer should create fewer packs (more efficient) + assert len(packs_buffered) < len(packs_unbuffered) + assert len(packs_buffered) == 2 # [3,1,2], [4] + assert len(packs_unbuffered) == 3 # [3], [4,1], [2] + + # Verify both preserve all tokens + total_buffered = sum((p["labels"] != -100).sum().item() for p in packs_buffered) + total_unbuffered = sum((p["labels"] != -100).sum().item() for p in packs_unbuffered) + assert total_buffered == total_unbuffered == sum(sample_sizes) + + def test_oversized_sample_dropping(self): + """Test that oversized samples are dropped""" + sample_sizes = [3, 10, 2, 8, 1] # 10 and 8 are oversized for target=6 + target_tokens = 6 + + dataset = DummyTextDataset(sample_sizes) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, packer=packer, target_tokens_per_pack=target_tokens + ) + + packs = list(packed_dataset) + + # Only samples 3, 2, 1 should be packed (oversized 10, 8 dropped) + total_packed_tokens = sum((p["labels"] != -100).sum().item() for p in packs) + expected_tokens = 3 + 2 + 1 # Only non-oversized samples + assert total_packed_tokens == expected_tokens + + # Should create 2 packs: [3, 2], [1] + assert len(packs) == 2 + + def test_checkpoint_and_resume(self): + """Test checkpointing and resumption functionality""" + sample_sizes = [3, 2, 5, 4, 1, 6] # Total 21 tokens + target_tokens_per_pack = 6 + + # First run: iterate partially + dataset1 = StatefulDummyTextDataset(sample_sizes) + packer1 = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset1 = IterablePackedDataset( + dataset=dataset1, packer=packer1, + target_tokens_per_pack=target_tokens_per_pack, buffer_size=4 + ) + + # Get first pack and save state + packed_iterator1 = iter(packed_dataset1) + pack1_partial = 
next(packed_iterator1) + state = packed_dataset1.state_dict() + + # Second run: resume from checkpoint + dataset2 = StatefulDummyTextDataset(sample_sizes) + packer2 = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset2 = IterablePackedDataset( + dataset=dataset2, packer=packer2, + target_tokens_per_pack=target_tokens_per_pack, buffer_size=4 + ) + packed_dataset2.load_state_dict(state) + + resumed_packs = list(packed_dataset2) + + # Verify resumption worked (buffer contents are lost, so some samples skipped) + assert len(resumed_packs) >= 1 + total_resumed_tokens = sum((p["labels"] != -100).sum().item() for p in resumed_packs) + assert total_resumed_tokens > 0 + + # Verify that together, first pack + resumed packs contain reasonable amount of data + # (not all data since buffer loss causes some samples to be skipped) + total_first_tokens = (pack1_partial["labels"] != -100).sum().item() + total_all_tokens = total_first_tokens + total_resumed_tokens + assert total_all_tokens < sum(sample_sizes) # Some data lost due to buffer + + def test_multiple_iterations_same_dataset(self): + """Test that multiple iterations over same packed dataset work correctly""" + sample_sizes = [2, 3, 1] + dataset = DummyTextDataset(sample_sizes) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, packer=packer, target_tokens_per_pack=4 + ) + + # First iteration + packs1 = list(packed_dataset) + # Second iteration should produce same result + packs2 = list(packed_dataset) + + assert len(packs1) == len(packs2) + for p1, p2 in zip(packs1, packs2): + torch.testing.assert_close(p1["tokens"], p2["tokens"]) + torch.testing.assert_close(p1["document_ids"], p2["document_ids"]) + + @pytest.mark.parametrize( + "sample_sizes,target_tokens,buffer_size,expected_packs,scenario", [ + ([3, 2, 4], 8, 10, 2, "basic_packing"), # Pack1: [3,2]+pad, Pack2: [4]+pad + ([4, 3], 8, 10, 1, "partial_final_pack"), # Pack1: [4,3]+pad + ([], 8, 10, 0, "empty_dataset"), + ([5], 10, 10, 1, "single_sample"), + ([5, 5, 5], 5, 10, 3, "exact_fit"), + ([2, 3, 1], 5, 1, 2, "small_target_and_buffer"), # Pack1: [2,3], Pack2: [1] + ] + ) + def test_scenarios(self, sample_sizes, target_tokens, buffer_size, expected_packs, scenario): + """Parametrized edge case testing""" + dataset = DummyTextDataset(sample_sizes) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, packer=packer, + target_tokens_per_pack=target_tokens, buffer_size=buffer_size + ) + + packs = list(packed_dataset) + assert len(packs) == expected_packs, f"Failed scenario: {scenario}" + + # Verify output format consistency for all scenarios + for pack in packs: + assert pack["tokens"].shape == (target_tokens,) + assert pack["labels"].shape == (target_tokens,) + assert pack["document_ids"].shape == (target_tokens,) + assert pack["input_pos"].shape == (target_tokens,) + + # Verify no token loss + if sample_sizes: # Skip for empty dataset + total_packed = sum((p["labels"] != -100).sum().item() for p in packs) + assert total_packed == sum(sample_sizes) \ No newline at end of file From fd41842849e8574742808bda8499084381573239 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 14:15:16 -0400 Subject: [PATCH 44/48] docs --- torchtune/datasets/_iterable_packed.py | 244 +++++++++++++++++-------- 1 file changed, 167 insertions(+), 77 deletions(-) diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 
dbbeb3d54c..5a03c8569c 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -22,13 +22,42 @@ logger = logging.getLogger(__name__) + SampleType = TypeVar("SampleType") PackType = dict[str, torch.Tensor | list[Metric]] class Packer(ABC, Generic[SampleType]): """ - Packer to be used in IterablePackedDataset and with FlexAttention. + An abstract base class that defines the logic for packing samples into a + fixed-size sequence. It is used by `IterablePackedDataset` to handle + different data formats (e.g., standard text, DPO pairs). + + A `Packer` is responsible for: + 1. Defining how to extract the token count from a raw sample. + 2. Specifying how a raw sample is deconstructed into tensors and added + to a "pack". + 3. Finalizing a pack by padding it to the target sequence length. + 4. Generating the appropriate attention mask for the packed format. + + This modular design allows `IterablePackedDataset` to remain agnostic to + the data format and packing strategy. + + Args: + padding_idx (int): The index of the padding token. + ignore_idx (int): The index to use for labels that should be + ignored in the loss calculation (e.g., padding tokens). + + Example: + >>> packer = TextPacker(padding_idx=0, ignore_idx=-100) + >>> pack = packer.create_empty_pack() + >>> sample = {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])} + >>> packer.add_sample_to_pack(pack, sample, next_doc_id=0) + >>> final_pack = packer.finalize_pack(pack, target_tokens_per_pack=5, next_doc_id=1) + >>> mask = packer.create_block_mask(final_pack["document_ids"].unsqueeze(0), device="cpu") + + Raises: + RuntimeError: If FlexAttention is not supported in the current environment. """ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): @@ -53,14 +82,15 @@ def set_dataset_name(self, dataset_name: str) -> None: @abstractmethod def create_empty_pack(self) -> dict[str, list[Any]]: """ - Creates an empty pack with lists that will hold tensors. - + Creates an empty pack structure for accumulating samples. + Returns: - dict[str, list[Any]]: An empty dictionary with lists as values. - + dict[str, list[Any]]: An empty structure that can accumulate sample data + and be converted to tensors by finalize_pack(). + Example: - self.create_empty_pack() - >>> {"tokens": [], "labels": [], "document_ids": [], "input_pos": []} + >>> packer.create_empty_pack() + {"tokens": [], "labels": []} """ pass @@ -98,14 +128,22 @@ def add_sample_to_pack( int: The number of new documents that were added to the pack. Example: - pack = {"tokens": [tensor([1, 2])], "labels": [tensor([3, 4])], ...} - sample = {"tokens": tensor([5, 6]), "labels": tensor([7, 8])} - added_docs = self.add_sample_to_pack(pack, sample, next_doc_id=1) - print(pack) - >>> {"tokens": [tensor([1, 2]), tensor([5, 6])], - "labels": [tensor([3, 4]), tensor([7, 8])], ...} - print(added_docs) - >>> 1 + >>> packer = TextPacker(padding_idx=0, ignore_idx=-100) + >>> pack = {"tokens": [torch.tensor([1, 2])], + ... "labels": [torch.tensor([3, 4])], + ... "document_ids": [torch.tensor([0, 0])], + ... "input_pos": [torch.tensor([0, 1])], + ... "metrics": []} + >>> sample = {"tokens": torch.tensor([5, 6]), + ... 
"labels": torch.tensor([7, 8])} + >>> added_docs = packer.add_sample_to_pack(pack, sample, next_doc_id=1) + >>> print(pack) + {"tokens": [torch.tensor([1, 2]), torch.tensor([5, 6])], + "labels": [torch.tensor([3, 4]), torch.tensor([7, 8])], + "document_ids": [torch.tensor([0, 0]), torch.tensor([1, 1])], + "input_pos": [torch.tensor([0, 1]), torch.tensor([0, 1])], "metrics": []} + >>> print(added_docs) + 1 """ pass @@ -125,15 +163,19 @@ def finalize_pack( PackType: The finalized pack with concatenated tensors. Example: - pack = {"tokens": [tensor([1, 2])], "labels": [tensor([3, 4])], ...} - target_tokens_per_pack = 4 - next_doc_id = 1 - self.padding_idx = 999 - self.ignore_idx = -100 - - self.finalize_pack(pack, target_tokens_per_pack, next_doc_id) - >>> {"tokens": tensor([1, 2, 999, 999]), - "labels": tensor([3, 4, -100, -100]), ...} + >>> packer = TextPacker(padding_idx=999, ignore_idx=-100) + >>> pack = {"tokens": [torch.tensor([1, 2])], + ... "labels": [torch.tensor([3, 4])], + ... "document_ids": [torch.tensor([0, 0])], + ... "input_pos": [torch.tensor([0, 1])], "metrics": []} + >>> target_tokens_per_pack = 4 + >>> next_doc_id = 1 + >>> result = packer.finalize_pack(pack, target_tokens_per_pack, next_doc_id) + >>> print(result) + {"tokens": torch.tensor([1, 2, 999, 999]), + "labels": torch.tensor([3, 4, -100, -100]), + "document_ids": torch.tensor([0, 0, 1, 1]), + "input_pos": torch.tensor([0, 1, 0, 0]), "metrics": [...]} """ pass @@ -166,9 +208,22 @@ def _mask_mod( """ pass - def create_block_mask(self, batch_document_ids, device): + def create_block_mask( + self, batch_document_ids: torch.Tensor, device: torch.device + ) -> torch.Tensor: """ - Creates a block-causal attention mask using FlexAttention. + Creates a block-causal attention mask for packed sequences using FlexAttention. + + The mask ensures tokens only attend to appropriate positions based on the + packer's specific attention pattern (e.g., within same document for TextPacker, + cross-attention for DPOPacker). + + Args: + batch_document_ids (torch.Tensor): Document IDs of shape (batch_size, seq_len) + device (torch.device): Device to create the mask on + + Returns: + torch.Tensor: Block mask for FlexAttention """ batch_size, seq_len = batch_document_ids.shape doc_ids = batch_document_ids.to(device) @@ -187,15 +242,32 @@ class IterablePackedDataset( TuneIterableDataset[PackType], Stateful, Generic[SampleType] ): """ - IterablePackedDataset takes any TuneIterableDataset and a Packer, packs documents until - the 'target_tokens_per_pack' is reached and yields a dictionary of tensors. + Wraps a `TuneIterableDataset` to combine multiple samples into a single, + fixed-size "pack". This is highly efficient for training as it minimizes + padding and ensures consistent batch shapes. + + The packing process works as follows: + 1. It fetches samples from the underlying `dataset` and stores them in + an internal `buffer`. + 2. It uses a "best-fit" approach to select samples from the buffer that + can fill a pack up to `target_tokens_per_pack`. + 3. The `packer` handles the logic for deconstructing samples, creating + metadata (like document IDs and attention masks), and padding the + final pack. + + This dataset is stateful and supports checkpointing (relies on child dataset to be stateful), + allowing training to be resumed seamlessly. Args: - dataset (TuneIterableDataset[SampleType]): The TuneIterableDataset to pack. - packer (Packer[SampleType]): The Packer specific to the dataset format (e.g. text, DPO, etc.). 
- target_tokens_per_pack (int): The target number of tokens per pack. - buffer_size (int): The size of the buffer to use for packing. - dataset_name (str): The name of the dataset. If None, a defaults to IterablePackedDataset. + dataset (TuneIterableDataset[SampleType]): The `TuneIterableDataset` to pack. + packer (Packer[SampleType]): The `Packer` that defines the packing + strategy for the dataset format (e.g. `TextPacker`). + target_tokens_per_pack (int): The target number of tokens for each pack. + buffer_size (int): The number of samples to buffer for finding the + best fit. A larger buffer may improve packing efficiency at the + cost of memory. Buffer samples are discarded if resuming from a checkpoint. + Default is 100. + dataset_name (str): The name of the dataset, used for metrics. """ def __init__( @@ -203,7 +275,7 @@ def __init__( dataset: TuneIterableDataset[SampleType], packer: Packer[SampleType], target_tokens_per_pack: int, - buffer_size: int = 50, + buffer_size: int = 100, dataset_name: str = "IterablePackedDataset", ): self.dataset = dataset @@ -219,6 +291,7 @@ def __init__( @property def dataset_name(self) -> str: + """Returns the dataset name, used for metrics tracking.""" return self._dataset_name def _reset_packer_state(self) -> None: @@ -351,7 +424,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[PackType]: def __iter__(self) -> Iterator[PackType]: if not isinstance(self.dataset, Iterable): - raise TypeError("Dataset is not iterable.") + raise TypeError("Dataset is not an iterable") if not self._resuming: self._reset_packer_state() @@ -413,14 +486,20 @@ class TextPacker(Packer[dict[str, torch.Tensor]]): - Each sample is treated as a separate document. - `input_pos` restarts from 0 for each sample. - `document_ids` assigns a unique ID to each sample for masking. + + Args: + padding_idx (int): The index of the padding token. + ignore_idx (int): The index for ignored labels. """ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): - super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + super().__init__(padding_idx, ignore_idx) self.dataset_name = "packed_dataset" # Default name def set_dataset_name(self, dataset_name: str) -> None: - """Set the dataset name for metrics.""" + """ + Sets the dataset name on the packer. This is used for logging metrics. + """ self.dataset_name = dataset_name def create_empty_pack(self) -> dict[str, list]: @@ -523,16 +602,12 @@ def _mask_mod( return causal_mask & document_mask -# NOTE: For demonstration purposes only. - - class DPOPacker(Packer[dict[str, torch.Tensor]]): """ - Packer for packing DPO samples with a shared prompt. It packs a DPO - sample as three logical documents: a shared prompt, a chosen response, - and a rejected response. This structure is encoded in the `document_ids` - metadata, allowing the packer to build the correct attention pattern - (e.g., both responses can attend to the prompt, but not to each other). + Packer for Direct Preference Optimization (DPO). It packs a DPO sample + as three logical documents: a shared prompt, a chosen response, and a rejected response. + It encodes the attention mask with shared prompt, so that both responses can attend to the same prompt, + without repetition, but not to each other. 
ASSUMPTION: The input DPO sample dict contains pre-tokenized tensors: - "prompt_ids" @@ -540,14 +615,20 @@ class DPOPacker(Packer[dict[str, torch.Tensor]]): - "chosen_response_only_labels" - "rejected_response_only_ids" - "rejected_response_only_labels" + + Args: + padding_idx (int): The index of the padding token. + ignore_idx (int): The index for ignored labels. """ def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): - super().__init__(padding_idx=padding_idx, ignore_idx=ignore_idx) + super().__init__(padding_idx, ignore_idx) self.dataset_name = "packed_dataset" # Default name def set_dataset_name(self, dataset_name: str) -> None: - """Set the dataset name for metrics.""" + """ + Sets the dataset name on the packer. This is used for logging metrics. + """ self.dataset_name = dataset_name def create_empty_pack(self) -> dict[str, list]: @@ -573,19 +654,25 @@ def get_sample_size(self, sample: dict[str, torch.Tensor]) -> int: def add_sample_to_pack( self, pack: dict[str, list], sample: dict[str, torch.Tensor], next_doc_id: int ) -> int: - """Adds a DPO sample by building complete tensors and appending to pack lists.""" - prompt_ids = sample["prompt_ids"] - chosen_ids = sample["chosen_response_only_ids"] - rejected_ids = sample["rejected_response_only_ids"] - - prompt_len = prompt_ids.numel() - chosen_len = chosen_ids.numel() - rejected_len = rejected_ids.numel() - - # 1. Add Shared Prompt, Chosen Response, and Rejected Response tokens - complete_tokens = torch.cat([prompt_ids, chosen_ids, rejected_ids]) - # 2. Add labels for all three parts - complete_labels = torch.cat( + """ + Adds a DPO sample to a pack. Each DPO sample consists of three parts + (prompt, chosen, rejected), and each part is assigned its own document ID. + """ + prompt_len = sample["prompt_ids"].numel() + chosen_len = sample["chosen_response_only_ids"].numel() + rejected_len = sample["rejected_response_only_ids"].numel() + + # 1. Concatenate tokens: [prompt, chosen_response, rejected_response] + tokens = torch.cat( + [ + sample["prompt_ids"], + sample["chosen_response_only_ids"], + sample["rejected_response_only_ids"], + ] + ) + + # 2. Create labels: [ignore_idx for prompt, chosen_labels, rejected_labels] + labels = torch.cat( [ torch.full( (prompt_len,), self.ignore_idx, dtype=torch.long), @@ -595,7 +682,7 @@ def add_sample_to_pack( ) # 3. Create document IDs: prompt(next_doc_id), chosen(next_doc_id+1), rejected(next_doc_id+2) - complete_doc_ids = torch.cat( + document_ids = torch.cat( [ torch.full( (prompt_len,), next_doc_id, dtype=torch.long), @@ -607,18 +694,18 @@ def add_sample_to_pack( ) # 4. Create input positions (restarts from 0 for each DPO sample) - total_len = complete_tokens.numel() - complete_input_pos = torch.arange(total_len, dtype=torch.long, device="cpu") + total_len = tokens.numel() + input_pos = torch.arange(total_len, dtype=torch.long, device="cpu") # 5. 
Create response masks - complete_chosen_mask = torch.cat( + chosen_response_mask = torch.cat( [ torch.zeros(prompt_len, dtype=torch.bool, device="cpu"), torch.ones(chosen_len, dtype=torch.bool, device="cpu"), torch.zeros(rejected_len, dtype=torch.bool, device="cpu"), ] ) - complete_rejected_mask = torch.cat( + rejected_response_mask = torch.cat( [ torch.zeros(prompt_len, dtype=torch.bool, device="cpu"), torch.zeros(chosen_len, dtype=torch.bool, device="cpu"), @@ -627,12 +714,12 @@ def add_sample_to_pack( ) # Append all complete tensors to the pack - pack["tokens"].append(complete_tokens) - pack["labels"].append(complete_labels) - pack["document_ids"].append(complete_doc_ids) - pack["input_pos"].append(complete_input_pos) - pack["chosen_response_mask"].append(complete_chosen_mask) - pack["rejected_response_mask"].append(complete_rejected_mask) + pack["tokens"].append(tokens) + pack["labels"].append(labels) + pack["document_ids"].append(document_ids) + pack["input_pos"].append(input_pos) + pack["chosen_response_mask"].append(chosen_response_mask) + pack["rejected_response_mask"].append(rejected_response_mask) # Handle metrics if they exist in the sample if "metrics" in sample: @@ -708,21 +795,24 @@ def _mask_mod( ) -> torch.Tensor: """ Mask logic for DPO. - - Causal self-attention within the same document. - - Cross-attention from response tokens (chosen/rejected) to their - corresponding prompt tokens. + - Causal self-attention within each document (prompt, chosen, rejected) + - Cross-attention: response tokens can attend to their prompt (shared for both responses) """ + # (batch_size, seq_len) q_doc = doc_ids[b, q_idx] kv_doc = doc_ids[b, kv_idx] # 1. Document-level Causal self-attention is_same_doc = q_doc == kv_doc - self_attention_mask = is_same_doc & (q_idx >= kv_idx) + is_causal = is_same_doc & (q_idx >= kv_idx) # 2. Cross-attention from response to prompt + # For a given query token, find the document ID of its corresponding prompt. + # Since each DPO sample consists of 3 documents (prompt, chosen, rejected), + # this maps q_doc to the base ID of its group (e.g., 4 -> 3, 5 -> 3). 
q_prompt_doc_id = (q_doc // 3) * 3 kv_is_part_of_q_prompt = kv_doc == q_prompt_doc_id q_is_response = (q_doc % 3) > 0 - cross_attention_mask = q_is_response & kv_is_part_of_q_prompt + is_cross_attention = q_is_response & kv_is_part_of_q_prompt - return self_attention_mask | cross_attention_mask + return is_causal | is_cross_attention From 734128ef93a9fda49553aacf43c212fe0bc817ae Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 12:27:28 -0700 Subject: [PATCH 45/48] tests + lint pass --- recipes/full_finetune_distributed.py | 2 +- .../datasets/test_iterable_packed_dataset.py | 424 ++++++++++++------ torchtune/data/_collate.py | 12 +- torchtune/datasets/__init__.py | 5 +- torchtune/datasets/_iterable_packed.py | 141 +++--- 5 files changed, 372 insertions(+), 212 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index c4ad8886f5..e79baf26fd 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -24,8 +24,8 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.datasets import InterleavedDataset, IterablePackedDataset from torchtune.data.metrics import MetricsAggregator +from torchtune.datasets import InterleavedDataset, IterablePackedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils diff --git a/tests/torchtune/datasets/test_iterable_packed_dataset.py b/tests/torchtune/datasets/test_iterable_packed_dataset.py index a36aa639aa..565cef18af 100644 --- a/tests/torchtune/datasets/test_iterable_packed_dataset.py +++ b/tests/torchtune/datasets/test_iterable_packed_dataset.py @@ -1,23 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
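A small worked example of the DPO attention rule implemented in `_mask_mod` above, assuming document IDs are assigned in (prompt, chosen, rejected) triplets as done by `DPOPacker.add_sample_to_pack`:

```python
def dpo_mask_rule(q_doc: int, kv_doc: int, q_idx: int, kv_idx: int) -> bool:
    # Causal self-attention within the same logical document.
    is_causal = (q_doc == kv_doc) and (q_idx >= kv_idx)
    # Responses (doc % 3 != 0) may also attend to their shared prompt,
    # whose ID is the base of the triplet: (doc // 3) * 3.
    is_cross = (q_doc % 3 > 0) and (kv_doc == (q_doc // 3) * 3)
    return is_causal or is_cross


# With two DPO samples packed, doc IDs run 0..5 (prompt/chosen/rejected twice):
assert dpo_mask_rule(q_doc=4, kv_doc=3, q_idx=10, kv_idx=2)       # chosen -> its prompt
assert not dpo_mask_rule(q_doc=4, kv_doc=5, q_idx=10, kv_idx=12)  # chosen never sees rejected
```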
+ import logging -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Iterator, Optional import pytest import torch from torch.utils.data import IterableDataset -from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchdata.stateful_dataloader import Stateful + +# from torchtune.data._collate import collate_packed +from torchtune.datasets._iterable_base import DatasetInfo + +from torchtune.datasets._iterable_packed import ( + DPOPacker, + IterablePackedDataset, + Packer, + PackType, + TextPacker, +) from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION -# Import the classes we're testing -if _SUPPORTS_FLEX_ATTENTION: - from torchtune.datasets._iterable_packed import ( - IterablePackedDataset, - TextPacker, - DPOPacker, - Packer, - ) - from torchdata.stateful_dataloader import Stateful - from torchtune.data._collate import collate_packed - # --- Test Fixtures --- @@ -28,10 +34,16 @@ def device(): class DummyTextDataset(IterableDataset): """Dummy dataset that returns tensor-based samples.""" + def __init__(self, sample_sizes): self._sample_sizes = sample_sizes self._counter = 0 + @property + def info(self) -> DatasetInfo: + """Returns dataset information.""" + return DatasetInfo(name="DummyTextDataset", weight=1.0, children=()) + def __iter__(self): # Reset counter for each new iteration self._counter = 0 @@ -49,19 +61,26 @@ class StatefulDummyTextDataset(IterableDataset, Stateful): progress to be saved and loaded. Returns tensor-based samples. """ - def __init__(self, sample_sizes: List[int]): + def __init__(self, sample_sizes: list[int]): self.sample_sizes = sample_sizes - self._state_to_load: Optional[Dict[str, Any]] = None + self._state_to_load: Optional[dict[str, Any]] = None # The state is the index of the *next* sample to be processed. - self._active_iterator_state: Dict[str, Any] = {"sample_idx": 0} + self._active_iterator_state: dict[str, Any] = {"sample_idx": 0} + + @property + def info(self) -> DatasetInfo: + """Returns dataset information.""" + return DatasetInfo(name="StatefulDummyTextDataset", weight=1.0, children=()) - def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: # This base generator yields all samples from the beginning. def _base_iterator(): for i, size in enumerate(self.sample_sizes): self._active_iterator_state = {"sample_idx": i} yield { - "tokens": torch.full((size,), i, dtype=torch.long), # Use sample index as token value + "tokens": torch.full( + (size,), i, dtype=torch.long + ), # Use sample index as token value "labels": torch.full((size,), i, dtype=torch.long), } # After iterating, the next sample index is out of bounds @@ -78,17 +97,19 @@ def _base_iterator(): self._state_to_load = None # Fast-forward the iterator to the sample index from the checkpoint. for _ in range(start_idx): - next(iterator, None) # Consume and discard samples until the desired start point. + next( + iterator, None + ) # Consume and discard samples until the desired start point. 
yield from iterator - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: logging.info( f"StatefulDummyTextDataset.state_dict(): current state is {self._active_iterator_state}" ) return self._active_iterator_state - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: logging.info( f"StatefulDummyTextDataset.load_state_dict(): state to load is {state_dict}" ) @@ -97,9 +118,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class DummyDPODataset(IterableDataset): """Dummy DPO dataset that returns tensor-based samples.""" + def __init__(self, samples): self._samples = samples + @property + def info(self) -> DatasetInfo: + """Returns dataset information.""" + return DatasetInfo(name="DummyDPODataset", weight=1.0, children=()) + def __iter__(self): yield from self._samples @@ -130,42 +157,56 @@ def create_dense_mask_from_mask_mod( ) return dense_mask + @pytest.fixture def dpo_packer(): - return DPOPacker(padding_idx=999, ignore_idx=-100) + packer = DPOPacker(padding_idx=999, ignore_idx=-100) + packer.set_dataset_name("TestDPODataset") + return packer -def assert_pack_structure(pack, packer_type, target_tokens, expected_docs): + +def assert_pack_structure( + pack: PackType, packer_type: str, target_tokens: int, expected_docs: int +) -> None: """ Helper to validate pack structure consistently across tests. Args: - pack: The pack dictionary to validate - packer_type: Type of packer ("text" or "dpo") - target_tokens: Expected sequence length - expected_docs: Expected number of unique document IDs + pack (PackType): The pack dictionary to validate + packer_type (str): Type of packer ("text" or "dpo") + target_tokens (int): Expected sequence length + expected_docs (int): Expected number of unique document IDs """ - assert pack["tokens"].shape == (target_tokens,) - assert pack["labels"].shape == (target_tokens,) - assert pack["document_ids"].shape == (target_tokens,) + # Cast to tensor since we know these keys contain tensors + tokens: torch.Tensor = pack["tokens"] + labels: torch.Tensor = pack["labels"] + document_ids: torch.Tensor = pack["document_ids"] + + assert tokens.shape == (target_tokens,) + assert labels.shape == (target_tokens,) + assert document_ids.shape == (target_tokens,) if packer_type == "dpo": assert "chosen_response_mask" in pack assert "rejected_response_mask" in pack # Verify that a token cannot be part of both the chosen and rejected response. - assert not (pack["chosen_response_mask"] & pack["rejected_response_mask"]).any() + chosen_mask: torch.Tensor = pack["chosen_response_mask"] + rejected_mask: torch.Tensor = pack["rejected_response_mask"] + assert not (chosen_mask & rejected_mask).any() # Verify document boundaries - assert torch.unique(pack["document_ids"]).numel() == expected_docs + assert torch.unique(document_ids).numel() == expected_docs -def assert_attention_mask_properties(mask, doc_ids): + +def assert_attention_mask_properties(mask: torch.Tensor, doc_ids: torch.Tensor) -> None: """ Helper to validate attention mask properties for packed sequences. Verifies causal attention within documents and proper masking boundaries. 
Args: - mask: Attention mask tensor of shape (batch_size, seq_len, seq_len) - doc_ids: Document ID tensor of shape (batch_size, seq_len) + mask (torch.Tensor): Attention mask tensor of shape (batch_size, seq_len, seq_len) + doc_ids (torch.Tensor): Document ID tensor of shape (batch_size, seq_len) """ batch_size, seq_len, _ = mask.shape @@ -183,13 +224,18 @@ def assert_attention_mask_properties(mask, doc_ids): # --- Test Classes --- + @pytest.fixture def text_packer(): - return TextPacker(padding_idx=999, ignore_idx=-100) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packer.set_dataset_name("TestTextDataset") + return packer + +@pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") class TestTextPacker: """Test TextPacker methods, attention masks, and integration workflow""" - + def test_create_empty_pack(self, text_packer): """Test empty pack creation for TextPacker""" pack = text_packer.create_empty_pack() @@ -207,9 +253,12 @@ def test_get_sample_size(self, text_packer): samples = [ {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])}, {"tokens": torch.tensor([7]), "labels": torch.tensor([8])}, - {"tokens": torch.tensor([9, 10, 11, 12]), "labels": torch.tensor([13, 14, 15, 16])}, + { + "tokens": torch.tensor([9, 10, 11, 12]), + "labels": torch.tensor([13, 14, 15, 16]), + }, ] - + expected_sizes = [3, 1, 4] for sample, expected_size in zip(samples, expected_sizes): assert text_packer.get_sample_size(sample) == expected_size @@ -217,18 +266,18 @@ def test_get_sample_size(self, text_packer): def test_add_multiple_samples_to_pack(self, text_packer): """Test adding multiple samples to same pack""" pack = text_packer.create_empty_pack() - + samples = [ {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}, {"tokens": torch.tensor([5, 6, 7]), "labels": torch.tensor([8, 9, 10])}, {"tokens": torch.tensor([11]), "labels": torch.tensor([12])}, ] - + # Add all samples for i, sample in enumerate(samples): docs_consumed = text_packer.add_sample_to_pack(pack, sample, next_doc_id=i) assert docs_consumed == 1 - + # Verify pack contents assert len(pack["tokens"]) == 3 torch.testing.assert_close(pack["tokens"][0], torch.tensor([1, 2])) @@ -247,14 +296,16 @@ def test_finalize_pack_multiple_samples(self, text_packer): "input_pos": [torch.tensor([0, 1]), torch.tensor([0, 1, 2])], "metrics": [], } - - result = text_packer.finalize_pack(pack, target_tokens_per_pack=8, next_doc_id=2) - + + result = text_packer.finalize_pack( + pack, target_tokens_per_pack=8, next_doc_id=2 + ) + expected_tokens = torch.tensor([1, 2, 3, 4, 5, 999, 999, 999]) expected_labels = torch.tensor([6, 7, 8, 9, 10, -100, -100, -100]) expected_doc_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2]) expected_input_pos = torch.tensor([0, 1, 0, 1, 2, 0, 0, 0]) - + torch.testing.assert_close(result["tokens"], expected_tokens) torch.testing.assert_close(result["labels"], expected_labels) torch.testing.assert_close(result["document_ids"], expected_doc_ids) @@ -265,36 +316,40 @@ def test_text_causal_mask(self, device): text_packer = TextPacker(padding_idx=0) # One sample in batch: doc 0 (len 2), doc 1 (len 2), doc 2 (len 3) doc_ids = torch.tensor([[0, 0, 1, 1, 2, 2, 2]], device=device) - - expected_mask = torch.tensor([ - # q k-> 0 1 2 3 4 5 6 - [1, 0, 0, 0, 0, 0, 0], # 0 - [1, 1, 0, 0, 0, 0, 0], # 1 - [0, 0, 1, 0, 0, 0, 0], # 2 - [0, 0, 1, 1, 0, 0, 0], # 3 - [0, 0, 0, 0, 1, 0, 0], # 4 - [0, 0, 0, 0, 1, 1, 0], # 5 - [0, 0, 0, 0, 1, 1, 1], # 6 - ], dtype=torch.bool, 
device=device).unsqueeze(0) - + + expected_mask = torch.tensor( + [ + # q k-> 0 1 2 3 4 5 6 + [1, 0, 0, 0, 0, 0, 0], # 0 + [1, 1, 0, 0, 0, 0, 0], # 1 + [0, 0, 1, 0, 0, 0, 0], # 2 + [0, 0, 1, 1, 0, 0, 0], # 3 + [0, 0, 0, 0, 1, 0, 0], # 4 + [0, 0, 0, 0, 1, 1, 0], # 5 + [0, 0, 0, 0, 1, 1, 1], # 6 + ], + dtype=torch.bool, + device=device, + ).unsqueeze(0) + actual_mask = create_dense_mask_from_mask_mod(text_packer, doc_ids) torch.testing.assert_close(actual_mask, expected_mask) def test_text_packing_workflow_two_packs(self): """Test complete text workflow that creates exactly 2 packs with multiple samples""" # Design: Pack1=[3,2], Pack2=[4] to create 2 packs - sample_sizes = [3, 2, 4] + sample_sizes = [3, 2, 4] target_tokens = 6 - + dataset = DummyTextDataset(sample_sizes) text_packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=text_packer, target_tokens_per_pack=target_tokens ) - + packs = list(packed_dataset) assert len(packs) == 2 - + # Pack 1: samples 0(size 3) + 1(size 2) + padding(1) pack1 = packs[0] assert pack1["tokens"].shape == (target_tokens,) @@ -305,9 +360,9 @@ def test_text_packing_workflow_two_packs(self): torch.testing.assert_close(pack1["tokens"], expected_tokens_1) torch.testing.assert_close(pack1["document_ids"], expected_doc_ids_1) torch.testing.assert_close(pack1["input_pos"], expected_input_pos_1) - + # Pack 2: sample 2(size 4) + padding(2) - single sample pack - pack2 = packs[1] + pack2 = packs[1] assert (pack2["labels"] != -100).sum() == 4 # 4 real tokens expected_tokens_2 = torch.tensor([2, 2, 2, 2, 999, 999]) expected_doc_ids_2 = torch.tensor([0, 0, 0, 0, 1, 1]) @@ -316,9 +371,11 @@ def test_text_packing_workflow_two_packs(self): torch.testing.assert_close(pack2["document_ids"], expected_doc_ids_2) torch.testing.assert_close(pack2["input_pos"], expected_input_pos_2) + +@pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") class TestDPOPacker: """Test DPOPacker methods, attention masks, and integration workflow""" - + def test_create_empty_pack(self, dpo_packer): """Test empty pack creation for DPOPacker""" pack = dpo_packer.create_empty_pack() @@ -356,9 +413,9 @@ def test_get_sample_size(self, dpo_packer): "chosen_response_only_labels": torch.tensor([17, 18, 19]), "rejected_response_only_ids": torch.tensor([20, 21]), "rejected_response_only_labels": torch.tensor([20, 21]), - } + }, ] - + expected_sizes = [6, 9, 6] # [2+2+2, 3+2+4, 1+3+2] for sample, expected_size in zip(samples, expected_sizes): assert dpo_packer.get_sample_size(sample) == expected_size @@ -380,14 +437,16 @@ def test_add_multiple_samples_to_pack(self, dpo_packer): "chosen_response_only_labels": torch.tensor([9]), "rejected_response_only_ids": torch.tensor([10, 11]), "rejected_response_only_labels": torch.tensor([10, 11]), - } + }, ] - + # Add all samples for i, sample in enumerate(samples): - docs_consumed = dpo_packer.add_sample_to_pack(pack, sample, next_doc_id=i*3) + docs_consumed = dpo_packer.add_sample_to_pack( + pack, sample, next_doc_id=i * 3 + ) assert docs_consumed == 3 # prompt + chosen + rejected - + # Verify pack contents assert len(pack["tokens"]) == 2 # First sample: [1,2,3,4,5,6] @@ -399,14 +458,14 @@ def test_add_multiple_samples_to_pack(self, dpo_packer): pack["document_ids"][0], torch.tensor([0, 0, 1, 1, 2, 2]) ) torch.testing.assert_close( - pack["chosen_response_mask"][0], - torch.tensor([False, False, True, True, False, False]) + pack["chosen_response_mask"][0], + 
torch.tensor([False, False, True, True, False, False]), ) torch.testing.assert_close( - pack["rejected_response_mask"][0], - torch.tensor([False, False, False, False, True, True]) + pack["rejected_response_mask"][0], + torch.tensor([False, False, False, False, True, True]), ) - + # Second sample: [7,8,9,10,11] torch.testing.assert_close(pack["tokens"][1], torch.tensor([7, 8, 9, 10, 11])) torch.testing.assert_close( @@ -416,12 +475,12 @@ def test_add_multiple_samples_to_pack(self, dpo_packer): pack["document_ids"][1], torch.tensor([3, 3, 4, 5, 5]) ) torch.testing.assert_close( - pack["chosen_response_mask"][1], - torch.tensor([False, False, True, False, False]) + pack["chosen_response_mask"][1], + torch.tensor([False, False, True, False, False]), ) torch.testing.assert_close( - pack["rejected_response_mask"][1], - torch.tensor([False, False, False, True, True]) + pack["rejected_response_mask"][1], + torch.tensor([False, False, False, True, True]), ) def test_finalize_pack_multiple_dpo_samples(self, dpo_packer): @@ -447,19 +506,55 @@ def test_finalize_pack_multiple_dpo_samples(self, dpo_packer): dpo_packer.add_sample_to_pack(pack, sample2, next_doc_id=3) # docs 3, 4, 5 # Total tokens = 6 (sample1) + 4 (sample2) = 10 - result = dpo_packer.finalize_pack(pack, target_tokens_per_pack=12, next_doc_id=6) + result = dpo_packer.finalize_pack( + pack, target_tokens_per_pack=12, next_doc_id=6 + ) expected_tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 999, 999]) - expected_labels = torch.tensor([-100, -100, 3, 4, 5, 6, -100, 8, 9, 10, -100, -100]) + expected_labels = torch.tensor( + [-100, -100, 3, 4, 5, 6, -100, 8, 9, 10, -100, -100] + ) expected_doc_ids = torch.tensor([0, 0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6]) - expected_chosen_mask = torch.tensor([False, False, True, True, False, False, False, True, False, False, False, False]) - expected_rejected_mask = torch.tensor([False, False, False, False, True, True, False, False, True, True, False, False]) + expected_chosen_mask = torch.tensor( + [ + False, + False, + True, + True, + False, + False, + False, + True, + False, + False, + False, + False, + ] + ) + expected_rejected_mask = torch.tensor( + [ + False, + False, + False, + False, + True, + True, + False, + False, + True, + True, + False, + False, + ] + ) torch.testing.assert_close(result["tokens"], expected_tokens) torch.testing.assert_close(result["labels"], expected_labels) torch.testing.assert_close(result["document_ids"], expected_doc_ids) torch.testing.assert_close(result["chosen_response_mask"], expected_chosen_mask) - torch.testing.assert_close(result["rejected_response_mask"], expected_rejected_mask) + torch.testing.assert_close( + result["rejected_response_mask"], expected_rejected_mask + ) def test_dpo_specialized_mask(self, device): """ @@ -479,28 +574,36 @@ def test_dpo_specialized_mask(self, device): batch_doc_ids = torch.stack([doc_ids_1, doc_ids_2]) # --- Manually create the expected mask for Pack 1 --- - mask1 = torch.tensor([ - # k_idx -> P C C R P C R (k_idx) - [1, 0, 0, 0, 0, 0, 0], # q=0 (P_A) can see self - [1, 1, 0, 0, 0, 0, 0], # q=1 (C_A) can see P_A and self (causal) - [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) can see P_A and C_A (causal) - [1, 0, 0, 1, 0, 0, 0], # q=3 (R_A) can see P_A and self - [0, 0, 0, 0, 1, 0, 0], # q=4 (P_B) can see self - [0, 0, 0, 0, 1, 1, 0], # q=5 (C_B) can see P_B and self - [0, 0, 0, 0, 1, 0, 1], # q=6 (R_B) can see P_B and self - ], dtype=torch.bool, device=device) + mask1 = torch.tensor( + [ + # k_idx -> P C C R P C R (k_idx) + [1, 0, 0, 0, 0, 0, 
0], # q=0 (P_A) can see self + [1, 1, 0, 0, 0, 0, 0], # q=1 (C_A) can see P_A and self (causal) + [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) can see P_A and C_A (causal) + [1, 0, 0, 1, 0, 0, 0], # q=3 (R_A) can see P_A and self + [0, 0, 0, 0, 1, 0, 0], # q=4 (P_B) can see self + [0, 0, 0, 0, 1, 1, 0], # q=5 (C_B) can see P_B and self + [0, 0, 0, 0, 1, 0, 1], # q=6 (R_B) can see P_B and self + ], + dtype=torch.bool, + device=device, + ) # --- Manually create the expected mask for Pack 2 --- - mask2 = torch.tensor([ - # q_idx, P P C C R Pad Pad(k_idx) - [1, 0, 0, 0, 0, 0, 0], # q=0 (P_A) - [1, 1, 0, 0, 0, 0, 0], # q=1 (P_A) - [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) - [1, 1, 1, 1, 0, 0, 0], # q=3 (C_A) - [1, 1, 0, 0, 1, 0, 0], # q=4 (R_A) - [0, 0, 0, 0, 0, 1, 0], # q=5 (Pad) - [0, 0, 0, 0, 0, 1, 1], # q=6 (Pad) - ], dtype=torch.bool, device=device) + mask2 = torch.tensor( + [ + # q_idx, P P C C R Pad Pad(k_idx) + [1, 0, 0, 0, 0, 0, 0], # q=0 (P_A) + [1, 1, 0, 0, 0, 0, 0], # q=1 (P_A) + [1, 1, 1, 0, 0, 0, 0], # q=2 (C_A) + [1, 1, 1, 1, 0, 0, 0], # q=3 (C_A) + [1, 1, 0, 0, 1, 0, 0], # q=4 (R_A) + [0, 0, 0, 0, 0, 1, 0], # q=5 (Pad) + [0, 0, 0, 0, 0, 1, 1], # q=6 (Pad) + ], + dtype=torch.bool, + device=device, + ) expected_mask = torch.stack([mask1, mask2]) @@ -530,18 +633,18 @@ def test_dpo_packing_workflow_two_packs(self): "chosen_response_only_labels": torch.tensor([12, 13]), "rejected_response_only_ids": torch.tensor([14, 15]), "rejected_response_only_labels": torch.tensor([14, 15]), - } + }, ] - + dataset = DummyDPODataset(samples) dpo_packer = DPOPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=dpo_packer, target_tokens_per_pack=10 ) - + packs = list(packed_dataset) assert len(packs) == 2 # Pack1: samples 0+1 (4+5=9), Pack2: sample 2 (6) - + # Pack 1: samples 0+1 (9 tokens) + padding (1) pack1 = packs[0] assert pack1["tokens"].shape == (10,) @@ -549,18 +652,24 @@ def test_dpo_packing_workflow_two_packs(self): assert "rejected_response_mask" in pack1 non_padding_1 = (pack1["tokens"] != 999).sum() assert non_padding_1 == 9 - + # Pack 2: sample 2 (6 tokens) + padding (4) pack2 = packs[1] non_padding_2 = (pack2["tokens"] != 999).sum() assert non_padding_2 == 6 - + # Verify masks are mutually exclusive - chosen_and_rejected_1 = pack1["chosen_response_mask"] & pack1["rejected_response_mask"] - chosen_and_rejected_2 = pack2["chosen_response_mask"] & pack2["rejected_response_mask"] + chosen_and_rejected_1 = ( + pack1["chosen_response_mask"] & pack1["rejected_response_mask"] + ) + chosen_and_rejected_2 = ( + pack2["chosen_response_mask"] & pack2["rejected_response_mask"] + ) assert not chosen_and_rejected_1.any() assert not chosen_and_rejected_2.any() + +@pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") class TestIterablePackedDataset: """Test IterablePackedDataset functionality - buffer efficiency, checkpointing, edge cases""" @@ -569,53 +678,59 @@ def test_buffer_efficiency(self): # Test case where buffer helps vs hurts - order matters for first-fit sample_sizes = [3, 4, 1, 2] # Total 10 tokens target_tokens = 6 - + # With large buffer: can see all samples and pick best fit [3,1,2], [4] dataset1 = DummyTextDataset(sample_sizes) packer1 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset1 = IterablePackedDataset( - dataset=dataset1, packer=packer1, - target_tokens_per_pack=target_tokens, buffer_size=10 + dataset=dataset1, + packer=packer1, + target_tokens_per_pack=target_tokens, + buffer_size=10, ) 
packs_buffered = list(packed_dataset1) - + # With small buffer: greedy first-fit [3], [4,1], [2] dataset2 = DummyTextDataset(sample_sizes) packer2 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset2 = IterablePackedDataset( - dataset=dataset2, packer=packer2, - target_tokens_per_pack=target_tokens, buffer_size=1 + dataset=dataset2, + packer=packer2, + target_tokens_per_pack=target_tokens, + buffer_size=1, ) packs_unbuffered = list(packed_dataset2) - + # Buffer should create fewer packs (more efficient) assert len(packs_buffered) < len(packs_unbuffered) assert len(packs_buffered) == 2 # [3,1,2], [4] assert len(packs_unbuffered) == 3 # [3], [4,1], [2] - + # Verify both preserve all tokens total_buffered = sum((p["labels"] != -100).sum().item() for p in packs_buffered) - total_unbuffered = sum((p["labels"] != -100).sum().item() for p in packs_unbuffered) + total_unbuffered = sum( + (p["labels"] != -100).sum().item() for p in packs_unbuffered + ) assert total_buffered == total_unbuffered == sum(sample_sizes) def test_oversized_sample_dropping(self): """Test that oversized samples are dropped""" sample_sizes = [3, 10, 2, 8, 1] # 10 and 8 are oversized for target=6 - target_tokens = 6 - + target_tokens = 5 + dataset = DummyTextDataset(sample_sizes) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=target_tokens ) - + packs = list(packed_dataset) - + # Only samples 3, 2, 1 should be packed (oversized 10, 8 dropped) total_packed_tokens = sum((p["labels"] != -100).sum().item() for p in packs) expected_tokens = 3 + 2 + 1 # Only non-oversized samples assert total_packed_tokens == expected_tokens - + # Should create 2 packs: [3, 2], [1] assert len(packs) == 2 @@ -623,36 +738,42 @@ def test_checkpoint_and_resume(self): """Test checkpointing and resumption functionality""" sample_sizes = [3, 2, 5, 4, 1, 6] # Total 21 tokens target_tokens_per_pack = 6 - + # First run: iterate partially dataset1 = StatefulDummyTextDataset(sample_sizes) packer1 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset1 = IterablePackedDataset( - dataset=dataset1, packer=packer1, - target_tokens_per_pack=target_tokens_per_pack, buffer_size=4 + dataset=dataset1, + packer=packer1, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=4, ) - + # Get first pack and save state packed_iterator1 = iter(packed_dataset1) pack1_partial = next(packed_iterator1) state = packed_dataset1.state_dict() - + # Second run: resume from checkpoint dataset2 = StatefulDummyTextDataset(sample_sizes) packer2 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset2 = IterablePackedDataset( - dataset=dataset2, packer=packer2, - target_tokens_per_pack=target_tokens_per_pack, buffer_size=4 + dataset=dataset2, + packer=packer2, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=4, ) packed_dataset2.load_state_dict(state) - + resumed_packs = list(packed_dataset2) - + # Verify resumption worked (buffer contents are lost, so some samples skipped) assert len(resumed_packs) >= 1 - total_resumed_tokens = sum((p["labels"] != -100).sum().item() for p in resumed_packs) + total_resumed_tokens = sum( + (p["labels"] != -100).sum().item() for p in resumed_packs + ) assert total_resumed_tokens > 0 - + # Verify that together, first pack + resumed packs contain reasonable amount of data # (not all data since buffer loss causes some samples to be skipped) total_first_tokens = (pack1_partial["labels"] != -100).sum().item() @@ 
-667,47 +788,52 @@ def test_multiple_iterations_same_dataset(self): packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=4 ) - + # First iteration packs1 = list(packed_dataset) # Second iteration should produce same result packs2 = list(packed_dataset) - + assert len(packs1) == len(packs2) for p1, p2 in zip(packs1, packs2): torch.testing.assert_close(p1["tokens"], p2["tokens"]) torch.testing.assert_close(p1["document_ids"], p2["document_ids"]) @pytest.mark.parametrize( - "sample_sizes,target_tokens,buffer_size,expected_packs,scenario", [ + "sample_sizes,target_tokens,buffer_size,expected_packs,scenario", + [ ([3, 2, 4], 8, 10, 2, "basic_packing"), # Pack1: [3,2]+pad, Pack2: [4]+pad ([4, 3], 8, 10, 1, "partial_final_pack"), # Pack1: [4,3]+pad ([], 8, 10, 0, "empty_dataset"), ([5], 10, 10, 1, "single_sample"), ([5, 5, 5], 5, 10, 3, "exact_fit"), ([2, 3, 1], 5, 1, 2, "small_target_and_buffer"), # Pack1: [2,3], Pack2: [1] - ] + ], ) - def test_scenarios(self, sample_sizes, target_tokens, buffer_size, expected_packs, scenario): + def test_scenarios( + self, sample_sizes, target_tokens, buffer_size, expected_packs, scenario + ): """Parametrized edge case testing""" dataset = DummyTextDataset(sample_sizes) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( - dataset=dataset, packer=packer, - target_tokens_per_pack=target_tokens, buffer_size=buffer_size + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens, + buffer_size=buffer_size, ) - + packs = list(packed_dataset) assert len(packs) == expected_packs, f"Failed scenario: {scenario}" - + # Verify output format consistency for all scenarios for pack in packs: assert pack["tokens"].shape == (target_tokens,) assert pack["labels"].shape == (target_tokens,) assert pack["document_ids"].shape == (target_tokens,) assert pack["input_pos"].shape == (target_tokens,) - + # Verify no token loss if sample_sizes: # Skip for empty dataset total_packed = sum((p["labels"] != -100).sum().item() for p in packs) - assert total_packed == sum(sample_sizes) \ No newline at end of file + assert total_packed == sum(sample_sizes) diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 7f97cbeb12..36a8afb5a5 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -19,9 +19,9 @@ def collate_packed( Generic collate function for packed samples from an IterablePackedDataset. Stacks tensors from all samples in the batch, while keeping non-tensor values - as lists. Handles metrics by extending them into a single list. Delegates - attention mask creation to a provided `mask_fn` callable that expects - `document_ids` and `device` parameters to generate masks on-the-fly for + as lists. Handles metrics by extending them into a single list. Delegates + attention mask creation to a provided `mask_fn` callable that expects + `document_ids` and `device` parameters to generate masks on-the-fly for packed sequences. Args: @@ -42,8 +42,10 @@ def collate_packed( first_sample_keys = batch[0].keys() for sample in batch: if sample.keys() != first_sample_keys: - raise ValueError(f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}") - + raise ValueError( + f"All samples must have the same keys. 
Expected {first_sample_keys}, got {sample.keys()}" + ) + keys_to_stack = first_sample_keys collated = {} diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index c523e75a8f..8e731e2188 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -18,15 +18,12 @@ from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._interleaved import InterleavedDataset -from torchtune.datasets._iterable_packed import ( - IterablePackedDataset, - TextPacker, -) from torchtune.datasets._iterable_base import ( DatasetInfo, InfiniteTuneIterableDataset, TuneIterableDataset, ) +from torchtune.datasets._iterable_packed import IterablePackedDataset, TextPacker from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 5a03c8569c..a050abcfd5 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -15,9 +15,9 @@ ) from torchdata.stateful_dataloader import Stateful from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX -from torchtune.data._metrics import AggregationType, Metric +from torchtune.data.metrics import AggregationType, Metric -from torchtune.datasets import TuneIterableDataset +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class Packer(ABC, Generic[SampleType]): >>> packer.add_sample_to_pack(pack, sample, next_doc_id=0) >>> final_pack = packer.finalize_pack(pack, target_tokens_per_pack=5, next_doc_id=1) >>> mask = packer.create_block_mask(final_pack["document_ids"].unsqueeze(0), device="cpu") - + Raises: RuntimeError: If FlexAttention is not supported in the current environment. """ @@ -83,11 +83,11 @@ def set_dataset_name(self, dataset_name: str) -> None: def create_empty_pack(self) -> dict[str, list[Any]]: """ Creates an empty pack structure for accumulating samples. - + Returns: dict[str, list[Any]]: An empty structure that can accumulate sample data and be converted to tensors by finalize_pack(). - + Example: >>> packer.create_empty_pack() {"tokens": [], "labels": []} @@ -129,18 +129,18 @@ def add_sample_to_pack( Example: >>> packer = TextPacker(padding_idx=0, ignore_idx=-100) - >>> pack = {"tokens": [torch.tensor([1, 2])], - ... "labels": [torch.tensor([3, 4])], - ... "document_ids": [torch.tensor([0, 0])], - ... "input_pos": [torch.tensor([0, 1])], + >>> pack = {"tokens": [torch.tensor([1, 2])], + ... "labels": [torch.tensor([3, 4])], + ... "document_ids": [torch.tensor([0, 0])], + ... "input_pos": [torch.tensor([0, 1])], ... "metrics": []} - >>> sample = {"tokens": torch.tensor([5, 6]), + >>> sample = {"tokens": torch.tensor([5, 6]), ... 
"labels": torch.tensor([7, 8])} >>> added_docs = packer.add_sample_to_pack(pack, sample, next_doc_id=1) >>> print(pack) {"tokens": [torch.tensor([1, 2]), torch.tensor([5, 6])], - "labels": [torch.tensor([3, 4]), torch.tensor([7, 8])], - "document_ids": [torch.tensor([0, 0]), torch.tensor([1, 1])], + "labels": [torch.tensor([3, 4]), torch.tensor([7, 8])], + "document_ids": [torch.tensor([0, 0]), torch.tensor([1, 1])], "input_pos": [torch.tensor([0, 1]), torch.tensor([0, 1])], "metrics": []} >>> print(added_docs) 1 @@ -164,17 +164,17 @@ def finalize_pack( Example: >>> packer = TextPacker(padding_idx=999, ignore_idx=-100) - >>> pack = {"tokens": [torch.tensor([1, 2])], - ... "labels": [torch.tensor([3, 4])], - ... "document_ids": [torch.tensor([0, 0])], + >>> pack = {"tokens": [torch.tensor([1, 2])], + ... "labels": [torch.tensor([3, 4])], + ... "document_ids": [torch.tensor([0, 0])], ... "input_pos": [torch.tensor([0, 1])], "metrics": []} >>> target_tokens_per_pack = 4 >>> next_doc_id = 1 >>> result = packer.finalize_pack(pack, target_tokens_per_pack, next_doc_id) >>> print(result) {"tokens": torch.tensor([1, 2, 999, 999]), - "labels": torch.tensor([3, 4, -100, -100]), - "document_ids": torch.tensor([0, 0, 1, 1]), + "labels": torch.tensor([3, 4, -100, -100]), + "document_ids": torch.tensor([0, 0, 1, 1]), "input_pos": torch.tensor([0, 1, 0, 0]), "metrics": [...]} """ pass @@ -238,11 +238,9 @@ def _mask_mod_for_flex(b, h, q_idx, kv_idx): ) -class IterablePackedDataset( - TuneIterableDataset[PackType], Stateful, Generic[SampleType] -): +class IterablePackedDataset(InfiniteTuneIterableDataset, Stateful, Generic[SampleType]): """ - Wraps a `TuneIterableDataset` to combine multiple samples into a single, + Wraps a `InfiniteTuneIterableDataset` to combine multiple samples into a single, fixed-size "pack". This is highly efficient for training as it minimizes padding and ensures consistent batch shapes. @@ -259,7 +257,7 @@ class IterablePackedDataset( allowing training to be resumed seamlessly. Args: - dataset (TuneIterableDataset[SampleType]): The `TuneIterableDataset` to pack. + dataset (InfiniteTuneIterableDataset): The `InfiniteTuneIterableDataset` to pack. packer (Packer[SampleType]): The `Packer` that defines the packing strategy for the dataset format (e.g. `TextPacker`). target_tokens_per_pack (int): The target number of tokens for each pack. @@ -267,12 +265,12 @@ class IterablePackedDataset( best fit. A larger buffer may improve packing efficiency at the cost of memory. Buffer samples are discarded if resuming from a checkpoint. Default is 100. - dataset_name (str): The name of the dataset, used for metrics. + dataset_name (str): The name of this packed dataset, used for metrics. Defaults to "IterablePackedDataset". 
""" def __init__( self, - dataset: TuneIterableDataset[SampleType], + dataset: InfiniteTuneIterableDataset, packer: Packer[SampleType], target_tokens_per_pack: int, buffer_size: int = 100, @@ -289,10 +287,15 @@ def __init__( self._reset_packer_state() + # Validate that the dataset names are unique + self._validate_unique_dataset_names() + @property - def dataset_name(self) -> str: - """Returns the dataset name, used for metrics tracking.""" - return self._dataset_name + def info(self) -> DatasetInfo: + """Returns hierarchical dataset information including child dataset info.""" + return DatasetInfo( + name=self._dataset_name, weight=1.0, children=(self.dataset.info,) + ) def _reset_packer_state(self) -> None: """Resets the packer's internal state for a new or resumed iteration.""" @@ -494,7 +497,6 @@ class TextPacker(Packer[dict[str, torch.Tensor]]): def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): super().__init__(padding_idx, ignore_idx) - self.dataset_name = "packed_dataset" # Default name def set_dataset_name(self, dataset_name: str) -> None: """ @@ -525,7 +527,7 @@ def add_sample_to_pack( # Append tensors directly to pack lists pack["tokens"].append(sample["tokens"]) pack["labels"].append(sample["labels"]) - + # Generate metadata as tensors pack["document_ids"].append( torch.full((seq_len,), next_doc_id, dtype=torch.long, device="cpu") @@ -559,16 +561,14 @@ def finalize_pack( pack["document_ids"].append( torch.full((num_padding,), next_doc_id, dtype=torch.long) ) - pack["input_pos"].append( - torch.zeros(num_padding, dtype=torch.long) - ) + pack["input_pos"].append(torch.zeros(num_padding, dtype=torch.long)) # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) padding_metric = Metric( dataset_name=self.dataset_name, - name="pct_of_tokens_padded", + metric_name="pct_of_tokens_padded", value=padding_pct, agg_type=AggregationType.MEAN, ) @@ -576,10 +576,26 @@ def finalize_pack( # Concatenate all tensor lists efficiently result = { - "tokens": torch.cat(pack["tokens"]) if pack["tokens"] else torch.empty(0, dtype=torch.long), - "labels": torch.cat(pack["labels"]) if pack["labels"] else torch.empty(0, dtype=torch.long), - "document_ids": torch.cat(pack["document_ids"]) if pack["document_ids"] else torch.empty(0, dtype=torch.long), - "input_pos": torch.cat(pack["input_pos"]) if pack["input_pos"] else torch.empty(0, dtype=torch.long), + "tokens": ( + torch.cat(pack["tokens"]) + if pack["tokens"] + else torch.empty(0, dtype=torch.long) + ), + "labels": ( + torch.cat(pack["labels"]) + if pack["labels"] + else torch.empty(0, dtype=torch.long) + ), + "document_ids": ( + torch.cat(pack["document_ids"]) + if pack["document_ids"] + else torch.empty(0, dtype=torch.long) + ), + "input_pos": ( + torch.cat(pack["input_pos"]) + if pack["input_pos"] + else torch.empty(0, dtype=torch.long) + ), "metrics": pack["metrics"], } @@ -623,7 +639,6 @@ class DPOPacker(Packer[dict[str, torch.Tensor]]): def __init__(self, padding_idx: int, ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX): super().__init__(padding_idx, ignore_idx) - self.dataset_name = "packed_dataset" # Default name def set_dataset_name(self, dataset_name: str) -> None: """ @@ -674,8 +689,7 @@ def add_sample_to_pack( # 2. 
Create labels: [ignore_idx for prompt, chosen_labels, rejected_labels] labels = torch.cat( [ - torch.full( - (prompt_len,), self.ignore_idx, dtype=torch.long), + torch.full((prompt_len,), self.ignore_idx, dtype=torch.long), sample["chosen_response_only_labels"], sample["rejected_response_only_labels"], ] @@ -684,12 +698,9 @@ def add_sample_to_pack( # 3. Create document IDs: prompt(next_doc_id), chosen(next_doc_id+1), rejected(next_doc_id+2) document_ids = torch.cat( [ - torch.full( - (prompt_len,), next_doc_id, dtype=torch.long), - torch.full( - (chosen_len,), next_doc_id + 1, dtype=torch.long), - torch.full( - (rejected_len,), next_doc_id + 2, dtype=torch.long), + torch.full((prompt_len,), next_doc_id, dtype=torch.long), + torch.full((chosen_len,), next_doc_id + 1, dtype=torch.long), + torch.full((rejected_len,), next_doc_id + 2, dtype=torch.long), ] ) @@ -766,7 +777,7 @@ def finalize_pack( padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) padding_metric = Metric( dataset_name=self.dataset_name, - name="pct_of_tokens_padded", + metric_name="pct_of_tokens_padded", value=padding_pct, agg_type=AggregationType.MEAN, ) @@ -774,12 +785,36 @@ def finalize_pack( # Concatenate all tensor lists result = { - "tokens": torch.cat(pack["tokens"]) if pack["tokens"] else torch.empty(0, dtype=torch.long), - "labels": torch.cat(pack["labels"]) if pack["labels"] else torch.empty(0, dtype=torch.long), - "document_ids": torch.cat(pack["document_ids"]) if pack["document_ids"] else torch.empty(0, dtype=torch.long), - "input_pos": torch.cat(pack["input_pos"]) if pack["input_pos"] else torch.empty(0, dtype=torch.long), - "chosen_response_mask": torch.cat(pack["chosen_response_mask"]) if pack["chosen_response_mask"] else torch.empty(0, dtype=torch.bool), - "rejected_response_mask": torch.cat(pack["rejected_response_mask"]) if pack["rejected_response_mask"] else torch.empty(0, dtype=torch.bool), + "tokens": ( + torch.cat(pack["tokens"]) + if pack["tokens"] + else torch.empty(0, dtype=torch.long) + ), + "labels": ( + torch.cat(pack["labels"]) + if pack["labels"] + else torch.empty(0, dtype=torch.long) + ), + "document_ids": ( + torch.cat(pack["document_ids"]) + if pack["document_ids"] + else torch.empty(0, dtype=torch.long) + ), + "input_pos": ( + torch.cat(pack["input_pos"]) + if pack["input_pos"] + else torch.empty(0, dtype=torch.long) + ), + "chosen_response_mask": ( + torch.cat(pack["chosen_response_mask"]) + if pack["chosen_response_mask"] + else torch.empty(0, dtype=torch.bool) + ), + "rejected_response_mask": ( + torch.cat(pack["rejected_response_mask"]) + if pack["rejected_response_mask"] + else torch.empty(0, dtype=torch.bool) + ), "metrics": pack["metrics"], } From 23bd9fb58e37c432c82aed5c3f3967e1d6766a44 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 16:17:39 -0400 Subject: [PATCH 46/48] test collate + dataloader --- .../datasets/test_iterable_packed_dataset.py | 319 +++++++++++------- torchtune/datasets/_iterable_packed.py | 2 +- 2 files changed, 207 insertions(+), 114 deletions(-) diff --git a/tests/torchtune/datasets/test_iterable_packed_dataset.py b/tests/torchtune/datasets/test_iterable_packed_dataset.py index 565cef18af..91666b7e46 100644 --- a/tests/torchtune/datasets/test_iterable_packed_dataset.py +++ b/tests/torchtune/datasets/test_iterable_packed_dataset.py @@ -5,16 +5,16 @@ # LICENSE file in the root directory of this source tree. 
import logging +from functools import partial from typing import Any, Iterator, Optional import pytest import torch from torch.utils.data import IterableDataset -from torchdata.stateful_dataloader import Stateful +from torchdata.stateful_dataloader import Stateful, StatefulDataLoader -# from torchtune.data._collate import collate_packed +from torchtune.data._collate import collate_packed from torchtune.datasets._iterable_base import DatasetInfo - from torchtune.datasets._iterable_packed import ( DPOPacker, IterablePackedDataset, @@ -22,7 +22,9 @@ PackType, TextPacker, ) +from torchtune.data.metrics import MetricsAggregator from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +from .test_iterable_utils import generate_ckpt # --- Test Fixtures --- @@ -165,63 +167,6 @@ def dpo_packer(): return packer -def assert_pack_structure( - pack: PackType, packer_type: str, target_tokens: int, expected_docs: int -) -> None: - """ - Helper to validate pack structure consistently across tests. - - Args: - pack (PackType): The pack dictionary to validate - packer_type (str): Type of packer ("text" or "dpo") - target_tokens (int): Expected sequence length - expected_docs (int): Expected number of unique document IDs - """ - # Cast to tensor since we know these keys contain tensors - tokens: torch.Tensor = pack["tokens"] - labels: torch.Tensor = pack["labels"] - document_ids: torch.Tensor = pack["document_ids"] - - assert tokens.shape == (target_tokens,) - assert labels.shape == (target_tokens,) - assert document_ids.shape == (target_tokens,) - - if packer_type == "dpo": - assert "chosen_response_mask" in pack - assert "rejected_response_mask" in pack - # Verify that a token cannot be part of both the chosen and rejected response. - chosen_mask: torch.Tensor = pack["chosen_response_mask"] - rejected_mask: torch.Tensor = pack["rejected_response_mask"] - assert not (chosen_mask & rejected_mask).any() - - # Verify document boundaries - assert torch.unique(document_ids).numel() == expected_docs - - -def assert_attention_mask_properties(mask: torch.Tensor, doc_ids: torch.Tensor) -> None: - """ - Helper to validate attention mask properties for packed sequences. - - Verifies causal attention within documents and proper masking boundaries. - - Args: - mask (torch.Tensor): Attention mask tensor of shape (batch_size, seq_len, seq_len) - doc_ids (torch.Tensor): Document ID tensor of shape (batch_size, seq_len) - """ - batch_size, seq_len, _ = mask.shape - - # Verify causal property within documents - for b in range(batch_size): - for doc_id in torch.unique(doc_ids[b]): - doc_indices = (doc_ids[b] == doc_id).nonzero(as_tuple=True)[0] - if not doc_indices.numel(): - continue - # The mask for tokens within a document should be lower-triangular (causal). - doc_mask = mask[b][doc_indices, :][:, doc_indices] - is_causal = torch.all(doc_mask.tril() == doc_mask) - assert is_causal, f"Mask for doc {doc_id} in batch {b} is not causal." - - # --- Test Classes --- @@ -312,27 +257,52 @@ def test_finalize_pack_multiple_samples(self, text_packer): torch.testing.assert_close(result["input_pos"], expected_input_pos) def test_text_causal_mask(self, device): - """Test standard causal masking for TextPacker""" + """ + Verify the correctness of the causal attention mask by manually constructing + the expected mask for a batch containing multiple documents. 
+ """ text_packer = TextPacker(padding_idx=0) - # One sample in batch: doc 0 (len 2), doc 1 (len 2), doc 2 (len 3) - doc_ids = torch.tensor([[0, 0, 1, 1, 2, 2, 2]], device=device) - expected_mask = torch.tensor( + # Batch contains two packs of different layouts. + # Pack 1: docs [A(2), B(3), C(2)] + doc_ids_1 = torch.tensor([0, 0, 1, 1, 1, 2, 2], device=device) + # Pack 2: docs [D(4), E(1), F(2)] + doc_ids_2 = torch.tensor([0, 0, 0, 0, 1, 2, 2], device=device) + batch_doc_ids = torch.stack([doc_ids_1, doc_ids_2]) + + # Manually create the expected mask for the batch + mask1 = torch.tensor( + [ + # k_idx -> A A B B B C C + [1, 0, 0, 0, 0, 0, 0], # q=0 (A) + [1, 1, 0, 0, 0, 0, 0], # q=1 (A) + [0, 0, 1, 0, 0, 0, 0], # q=2 (B) + [0, 0, 1, 1, 0, 0, 0], # q=3 (B) + [0, 0, 1, 1, 1, 0, 0], # q=4 (B) + [0, 0, 0, 0, 0, 1, 0], # q=5 (C) + [0, 0, 0, 0, 0, 1, 1], # q=6 (C) + ], + dtype=torch.bool, + device=device, + ) + mask2 = torch.tensor( [ - # q k-> 0 1 2 3 4 5 6 - [1, 0, 0, 0, 0, 0, 0], # 0 - [1, 1, 0, 0, 0, 0, 0], # 1 - [0, 0, 1, 0, 0, 0, 0], # 2 - [0, 0, 1, 1, 0, 0, 0], # 3 - [0, 0, 0, 0, 1, 0, 0], # 4 - [0, 0, 0, 0, 1, 1, 0], # 5 - [0, 0, 0, 0, 1, 1, 1], # 6 + # k_idx -> D D D D E F F + [1, 0, 0, 0, 0, 0, 0], # q=0 (D) + [1, 1, 0, 0, 0, 0, 0], # q=1 (D) + [1, 1, 1, 0, 0, 0, 0], # q=2 (D) + [1, 1, 1, 1, 0, 0, 0], # q=3 (D) + [0, 0, 0, 0, 1, 0, 0], # q=4 (E) + [0, 0, 0, 0, 0, 1, 0], # q=5 (F) + [0, 0, 0, 0, 0, 1, 1], # q=6 (F) ], dtype=torch.bool, device=device, - ).unsqueeze(0) + ) + expected_mask = torch.stack([mask1, mask2]) - actual_mask = create_dense_mask_from_mask_mod(text_packer, doc_ids) + # Create mask using the strategy and verify + actual_mask = create_dense_mask_from_mask_mod(text_packer, batch_doc_ids) torch.testing.assert_close(actual_mask, expected_mask) def test_text_packing_workflow_two_packs(self): @@ -669,6 +639,109 @@ def test_dpo_packing_workflow_two_packs(self): assert not chosen_and_rejected_2.any() +@pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") +class TestCollatedPacked: + """Test collate_packed function""" + + def test_collate_empty_batch(self): + """Test collating an empty batch""" + result = collate_packed(batch=[], mask_fn=lambda doc_ids, device: None, device="cpu") + assert result == {} + + def test_collate_basic_batch(self): + """Test basic collation functionality""" + # Create mock packed samples + batch = [ + { + "tokens": torch.tensor([1, 2, 3]), + "labels": torch.tensor([4, 5, 6]), + "document_ids": torch.tensor([0, 0, 1]), + "input_pos": torch.tensor([0, 1, 0]), + "metrics": [ + type('Metric', (), {'metric_name': 'test', 'value': 1.0})(), + type('Metric', (), {'metric_name': 'test2', 'value': 2.0})() + ] + }, + { + "tokens": torch.tensor([7, 8]), + "labels": torch.tensor([9, 10]), + "document_ids": torch.tensor([2, 2]), + "input_pos": torch.tensor([0, 1]), + "metrics": [ + type('Metric', (), {'metric_name': 'test3', 'value': 3.0})() + ] + } + ] + + # Mock mask function + def mock_mask_fn(doc_ids, device): + batch_size, seq_len = doc_ids.shape + return torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool) + + result = collate_packed(batch, mock_mask_fn, "cpu") + + # Check tensor stacking + expected_tokens = torch.stack([torch.tensor([1, 2, 3]), torch.tensor([7, 8])]) + expected_labels = torch.stack([torch.tensor([4, 5, 6]), torch.tensor([9, 10])]) + expected_doc_ids = torch.stack([torch.tensor([0, 0, 1]), torch.tensor([2, 2])]) + + torch.testing.assert_close(result["tokens"], expected_tokens) + 
torch.testing.assert_close(result["labels"], expected_labels) + torch.testing.assert_close(result["document_ids"], expected_doc_ids) + + # Check metrics flattening + assert len(result["metrics"]) == 3 # All metrics from both samples + + # Check mask creation + assert "mask" in result + assert result["mask"].shape == (2, 3, 3) # batch_size=2, seq_len=3 + + def test_collate_different_keys_error(self): + """Test that different keys across samples raises ValueError""" + batch = [ + {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}, + {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])} + ] + + def mock_mask_fn(doc_ids, device): + return torch.ones(1, 1, 1) + + with pytest.raises(ValueError, match="All samples must have the same keys"): + collate_packed(batch, mock_mask_fn, "cpu") + + def test_collate_mixed_tensor_non_tensor(self): + """Test collation with mixed tensor and non-tensor data""" + batch = [ + { + "tokens": torch.tensor([1, 2]), + "document_ids": torch.tensor([0, 0]), + "text_data": "sample1", + "metrics": ["DummyMetric1"] + }, + { + "tokens": torch.tensor([3, 4]), + "document_ids": torch.tensor([1, 1]), + "text_data": "sample2", + "metrics": ["DummyMetric2"] + } + ] + + def mock_mask_fn(doc_ids, device): + return torch.ones(2, 2, 2) + + result = collate_packed(batch, mock_mask_fn, "cpu") + + # Tensors should be stacked + expected_tokens = torch.stack([torch.tensor([1, 2]), torch.tensor([3, 4])]) + torch.testing.assert_close(result["tokens"], expected_tokens) + + # Non-tensors should be kept as lists + assert result["text_data"] == ["sample1", "sample2"] + + # Metrics should be flattened + assert result["metrics"] == ["DummyMetric1", "DummyMetric2"] + + @pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") class TestIterablePackedDataset: """Test IterablePackedDataset functionality - buffer efficiency, checkpointing, edge cases""" @@ -735,50 +808,71 @@ def test_oversized_sample_dropping(self): assert len(packs) == 2 def test_checkpoint_and_resume(self): - """Test checkpointing and resumption functionality""" + """Test checkpointing and resumption functionality using StatefulDataLoader""" sample_sizes = [3, 2, 5, 4, 1, 6] # Total 21 tokens target_tokens_per_pack = 6 + batch_size = 2 + + # Setup dataset factory + def create_loader_and_aggregator(): + dataset = StatefulDummyTextDataset(sample_sizes) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = IterablePackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=0, # No buffer for deterministic checkpointing + ) - # First run: iterate partially - dataset1 = StatefulDummyTextDataset(sample_sizes) - packer1 = TextPacker(padding_idx=999, ignore_idx=-100) - packed_dataset1 = IterablePackedDataset( - dataset=dataset1, - packer=packer1, - target_tokens_per_pack=target_tokens_per_pack, - buffer_size=4, - ) - - # Get first pack and save state - packed_iterator1 = iter(packed_dataset1) - pack1_partial = next(packed_iterator1) - state = packed_dataset1.state_dict() - - # Second run: resume from checkpoint - dataset2 = StatefulDummyTextDataset(sample_sizes) - packer2 = TextPacker(padding_idx=999, ignore_idx=-100) - packed_dataset2 = IterablePackedDataset( - dataset=dataset2, - packer=packer2, - target_tokens_per_pack=target_tokens_per_pack, - buffer_size=4, - ) - packed_dataset2.load_state_dict(state) - - resumed_packs = list(packed_dataset2) + collate_fn = partial( + collate_packed, 
mask_fn=packer.create_block_mask, device="cpu" + ) - # Verify resumption worked (buffer contents are lost, so some samples skipped) - assert len(resumed_packs) >= 1 - total_resumed_tokens = sum( - (p["labels"] != -100).sum().item() for p in resumed_packs + loader = StatefulDataLoader( + packed_dataset, batch_size=batch_size, collate_fn=collate_fn + ) + aggregator = MetricsAggregator() + return loader, aggregator + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + steps_before_checkpoint = 3 + steps_after_checkpoint = 3 + + # Generate checkpoint and resume + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=steps_before_checkpoint, + steps_after_checkpoint=steps_after_checkpoint, + resume_dataloader=loader2, + resume_aggregator=aggregator2, ) - assert total_resumed_tokens > 0 - # Verify that together, first pack + resumed packs contain reasonable amount of data - # (not all data since buffer loss causes some samples to be skipped) - total_first_tokens = (pack1_partial["labels"] != -100).sum().item() - total_all_tokens = total_first_tokens + total_resumed_tokens - assert total_all_tokens < sum(sample_sizes) # Some data lost due to buffer + # Verify that checkpointing and resumption are identical + assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint + assert len(result["resumed_batches"]) == steps_after_checkpoint + + for orig_batch, resumed_batch in zip( + result["post_checkpoint_batches"], result["resumed_batches"] + ): + assert orig_batch.keys() == resumed_batch.keys() + for key in orig_batch: + if isinstance(orig_batch[key], torch.Tensor): + torch.testing.assert_close( + orig_batch[key], + resumed_batch[key], + msg=f"Mismatch in batch key: {key}", + ) + else: + assert ( + orig_batch[key] == resumed_batch[key] + ), f"Mismatch in batch key: {key}" + + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" def test_multiple_iterations_same_dataset(self): """Test that multiple iterations over same packed dataset work correctly""" @@ -788,7 +882,6 @@ def test_multiple_iterations_same_dataset(self): packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=4 ) - # First iteration packs1 = list(packed_dataset) # Second iteration should produce same result diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index a050abcfd5..3271f53afa 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -238,7 +238,7 @@ def _mask_mod_for_flex(b, h, q_idx, kv_idx): ) -class IterablePackedDataset(InfiniteTuneIterableDataset, Stateful, Generic[SampleType]): +class IterablePackedDataset(InfiniteTuneIterableDataset, Generic[SampleType]): """ Wraps a `InfiniteTuneIterableDataset` to combine multiple samples into a single, fixed-size "pack". 
This is highly efficient for training as it minimizes From fb7b9aabf2235b78c1f1ad855cb6b248bac2da8b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 13:40:00 -0700 Subject: [PATCH 47/48] clean up --- .../datasets/test_iterable_packed_dataset.py | 107 ++++++------------ 1 file changed, 36 insertions(+), 71 deletions(-) diff --git a/tests/torchtune/datasets/test_iterable_packed_dataset.py b/tests/torchtune/datasets/test_iterable_packed_dataset.py index 91666b7e46..71ed1efb30 100644 --- a/tests/torchtune/datasets/test_iterable_packed_dataset.py +++ b/tests/torchtune/datasets/test_iterable_packed_dataset.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging from functools import partial from typing import Any, Iterator, Optional @@ -14,19 +13,17 @@ from torchdata.stateful_dataloader import Stateful, StatefulDataLoader from torchtune.data._collate import collate_packed +from torchtune.data.metrics import MetricsAggregator from torchtune.datasets._iterable_base import DatasetInfo from torchtune.datasets._iterable_packed import ( DPOPacker, IterablePackedDataset, Packer, - PackType, TextPacker, ) -from torchtune.data.metrics import MetricsAggregator from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION -from .test_iterable_utils import generate_ckpt -# --- Test Fixtures --- +from .test_iterable_utils import generate_ckpt @pytest.fixture @@ -34,29 +31,6 @@ def device(): return "cuda" -class DummyTextDataset(IterableDataset): - """Dummy dataset that returns tensor-based samples.""" - - def __init__(self, sample_sizes): - self._sample_sizes = sample_sizes - self._counter = 0 - - @property - def info(self) -> DatasetInfo: - """Returns dataset information.""" - return DatasetInfo(name="DummyTextDataset", weight=1.0, children=()) - - def __iter__(self): - # Reset counter for each new iteration - self._counter = 0 - for size in self._sample_sizes: - yield { - "tokens": torch.full((size,), self._counter, dtype=torch.long), - "labels": torch.full((size,), self._counter, dtype=torch.long), - } - self._counter += 1 - - class StatefulDummyTextDataset(IterableDataset, Stateful): """ A dummy text dataset that is also stateful, allowing its iteration @@ -93,28 +67,17 @@ def _base_iterator(): # If resuming, fast-forward the iterator to the correct position. if self._state_to_load: start_idx = self._state_to_load.get("sample_idx", 0) - logging.info( - f"StatefulDummyTextDataset.__iter__(): Resuming. Fast-forwarding iterator to index {start_idx}." - ) self._state_to_load = None - # Fast-forward the iterator to the sample index from the checkpoint. + # Consume and discard samples until the desired start point. for _ in range(start_idx): - next( - iterator, None - ) # Consume and discard samples until the desired start point. 
+ next(iterator, None) yield from iterator def state_dict(self) -> dict[str, Any]: - logging.info( - f"StatefulDummyTextDataset.state_dict(): current state is {self._active_iterator_state}" - ) return self._active_iterator_state def load_state_dict(self, state_dict: dict[str, Any]) -> None: - logging.info( - f"StatefulDummyTextDataset.load_state_dict(): state to load is {state_dict}" - ) self._state_to_load = state_dict @@ -311,7 +274,7 @@ def test_text_packing_workflow_two_packs(self): sample_sizes = [3, 2, 4] target_tokens = 6 - dataset = DummyTextDataset(sample_sizes) + dataset = StatefulDummyTextDataset(sample_sizes) text_packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=text_packer, target_tokens_per_pack=target_tokens @@ -645,7 +608,9 @@ class TestCollatedPacked: def test_collate_empty_batch(self): """Test collating an empty batch""" - result = collate_packed(batch=[], mask_fn=lambda doc_ids, device: None, device="cpu") + result = collate_packed( + batch=[], mask_fn=lambda doc_ids, device: None, device="cpu" + ) assert result == {} def test_collate_basic_batch(self): @@ -658,9 +623,9 @@ def test_collate_basic_batch(self): "document_ids": torch.tensor([0, 0, 1]), "input_pos": torch.tensor([0, 1, 0]), "metrics": [ - type('Metric', (), {'metric_name': 'test', 'value': 1.0})(), - type('Metric', (), {'metric_name': 'test2', 'value': 2.0})() - ] + type("Metric", (), {"metric_name": "test", "value": 1.0})(), + type("Metric", (), {"metric_name": "test2", "value": 2.0})(), + ], }, { "tokens": torch.tensor([7, 8]), @@ -668,30 +633,30 @@ def test_collate_basic_batch(self): "document_ids": torch.tensor([2, 2]), "input_pos": torch.tensor([0, 1]), "metrics": [ - type('Metric', (), {'metric_name': 'test3', 'value': 3.0})() - ] - } + type("Metric", (), {"metric_name": "test3", "value": 3.0})() + ], + }, ] - + # Mock mask function def mock_mask_fn(doc_ids, device): batch_size, seq_len = doc_ids.shape return torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool) - + result = collate_packed(batch, mock_mask_fn, "cpu") - + # Check tensor stacking expected_tokens = torch.stack([torch.tensor([1, 2, 3]), torch.tensor([7, 8])]) expected_labels = torch.stack([torch.tensor([4, 5, 6]), torch.tensor([9, 10])]) expected_doc_ids = torch.stack([torch.tensor([0, 0, 1]), torch.tensor([2, 2])]) - + torch.testing.assert_close(result["tokens"], expected_tokens) torch.testing.assert_close(result["labels"], expected_labels) torch.testing.assert_close(result["document_ids"], expected_doc_ids) - + # Check metrics flattening assert len(result["metrics"]) == 3 # All metrics from both samples - + # Check mask creation assert "mask" in result assert result["mask"].shape == (2, 3, 3) # batch_size=2, seq_len=3 @@ -700,12 +665,12 @@ def test_collate_different_keys_error(self): """Test that different keys across samples raises ValueError""" batch = [ {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}, - {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])} + {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])}, ] - + def mock_mask_fn(doc_ids, device): return torch.ones(1, 1, 1) - + with pytest.raises(ValueError, match="All samples must have the same keys"): collate_packed(batch, mock_mask_fn, "cpu") @@ -716,28 +681,28 @@ def test_collate_mixed_tensor_non_tensor(self): "tokens": torch.tensor([1, 2]), "document_ids": torch.tensor([0, 0]), "text_data": "sample1", - "metrics": ["DummyMetric1"] + "metrics": 
["DummyMetric1"], }, { "tokens": torch.tensor([3, 4]), "document_ids": torch.tensor([1, 1]), "text_data": "sample2", - "metrics": ["DummyMetric2"] - } + "metrics": ["DummyMetric2"], + }, ] - + def mock_mask_fn(doc_ids, device): return torch.ones(2, 2, 2) - + result = collate_packed(batch, mock_mask_fn, "cpu") - + # Tensors should be stacked expected_tokens = torch.stack([torch.tensor([1, 2]), torch.tensor([3, 4])]) torch.testing.assert_close(result["tokens"], expected_tokens) - + # Non-tensors should be kept as lists assert result["text_data"] == ["sample1", "sample2"] - + # Metrics should be flattened assert result["metrics"] == ["DummyMetric1", "DummyMetric2"] @@ -753,7 +718,7 @@ def test_buffer_efficiency(self): target_tokens = 6 # With large buffer: can see all samples and pick best fit [3,1,2], [4] - dataset1 = DummyTextDataset(sample_sizes) + dataset1 = StatefulDummyTextDataset(sample_sizes) packer1 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset1 = IterablePackedDataset( dataset=dataset1, @@ -764,7 +729,7 @@ def test_buffer_efficiency(self): packs_buffered = list(packed_dataset1) # With small buffer: greedy first-fit [3], [4,1], [2] - dataset2 = DummyTextDataset(sample_sizes) + dataset2 = StatefulDummyTextDataset(sample_sizes) packer2 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset2 = IterablePackedDataset( dataset=dataset2, @@ -791,7 +756,7 @@ def test_oversized_sample_dropping(self): sample_sizes = [3, 10, 2, 8, 1] # 10 and 8 are oversized for target=6 target_tokens = 5 - dataset = DummyTextDataset(sample_sizes) + dataset = StatefulDummyTextDataset(sample_sizes) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=target_tokens @@ -877,7 +842,7 @@ def create_loader_and_aggregator(): def test_multiple_iterations_same_dataset(self): """Test that multiple iterations over same packed dataset work correctly""" sample_sizes = [2, 3, 1] - dataset = DummyTextDataset(sample_sizes) + dataset = StatefulDummyTextDataset(sample_sizes) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=4 @@ -907,7 +872,7 @@ def test_scenarios( self, sample_sizes, target_tokens, buffer_size, expected_packs, scenario ): """Parametrized edge case testing""" - dataset = DummyTextDataset(sample_sizes) + dataset = StatefulDummyTextDataset(sample_sizes) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, From 4c505e0390a8c8d254d08ab9ee885d5da81c7395 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 7 Jul 2025 21:58:56 -0700 Subject: [PATCH 48/48] improve packed testing --- .../datasets/test_iterable_packed_dataset.py | 452 +++++++++--------- torchtune/datasets/_iterable_packed.py | 12 + 2 files changed, 250 insertions(+), 214 deletions(-) diff --git a/tests/torchtune/datasets/test_iterable_packed_dataset.py b/tests/torchtune/datasets/test_iterable_packed_dataset.py index 71ed1efb30..86ff963855 100644 --- a/tests/torchtune/datasets/test_iterable_packed_dataset.py +++ b/tests/torchtune/datasets/test_iterable_packed_dataset.py @@ -4,23 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import json from functools import partial -from typing import Any, Iterator, Optional +from itertools import islice +from pathlib import Path +from typing import Any import pytest import torch -from torch.utils.data import IterableDataset -from torchdata.stateful_dataloader import Stateful, StatefulDataLoader +from torchdata.stateful_dataloader import StatefulDataLoader from torchtune.data._collate import collate_packed from torchtune.data.metrics import MetricsAggregator -from torchtune.datasets._iterable_base import DatasetInfo from torchtune.datasets._iterable_packed import ( DPOPacker, IterablePackedDataset, Packer, TextPacker, ) +from torchtune.datasets._sft import HfIterableDataset from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION from .test_iterable_utils import generate_ckpt @@ -31,72 +33,53 @@ def device(): return "cuda" -class StatefulDummyTextDataset(IterableDataset, Stateful): - """ - A dummy text dataset that is also stateful, allowing its iteration - progress to be saved and loaded. Returns tensor-based samples. - """ - - def __init__(self, sample_sizes: list[int]): - self.sample_sizes = sample_sizes - self._state_to_load: Optional[dict[str, Any]] = None - # The state is the index of the *next* sample to be processed. - self._active_iterator_state: dict[str, Any] = {"sample_idx": 0} - - @property - def info(self) -> DatasetInfo: - """Returns dataset information.""" - return DatasetInfo(name="StatefulDummyTextDataset", weight=1.0, children=()) - - def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: - # This base generator yields all samples from the beginning. - def _base_iterator(): - for i, size in enumerate(self.sample_sizes): - self._active_iterator_state = {"sample_idx": i} - yield { - "tokens": torch.full( - (size,), i, dtype=torch.long - ), # Use sample index as token value - "labels": torch.full((size,), i, dtype=torch.long), - } - # After iterating, the next sample index is out of bounds - self._active_iterator_state = {"sample_idx": len(self.sample_sizes)} - - iterator = _base_iterator() - - # If resuming, fast-forward the iterator to the correct position. - if self._state_to_load: - start_idx = self._state_to_load.get("sample_idx", 0) - self._state_to_load = None - # Consume and discard samples until the desired start point. - for _ in range(start_idx): - next(iterator, None) - - yield from iterator - - def state_dict(self) -> dict[str, Any]: - return self._active_iterator_state +# --- Test Utilities --- - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self._state_to_load = state_dict +def create_test_json_file(path: Path, samples: list[dict[str, list[int]]]) -> None: + """Creates a dummy JSON test data file.""" + with open(path, "w") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") -class DummyDPODataset(IterableDataset): - """Dummy DPO dataset that returns tensor-based samples.""" - def __init__(self, samples): - self._samples = samples +class ToTensorTransform: + """Converts lists in a sample to tensors, as expected by the packer.""" - @property - def info(self) -> DatasetInfo: - """Returns dataset information.""" - return DatasetInfo(name="DummyDPODataset", weight=1.0, children=()) + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + output = {} + for k, v in sample.items(): + if isinstance(v, list): + output[k] = torch.tensor(v, dtype=torch.long) + else: + output[k] = v + # TextPacker expects "tokens" and "labels". 
+ if "tokens" in output and "labels" not in output: + output["labels"] = output["tokens"].clone() + return output - def __iter__(self): - yield from self._samples +@pytest.fixture +def dataset_factory(tmp_path): + """Factory for creating HfIterableDataset instances for testing.""" + + def _create_dataset( + data: list[dict[str, list[int]]], + **kwargs, + ) -> HfIterableDataset: + file_path = tmp_path / "data.json" + create_test_json_file(file_path, data) + return HfIterableDataset( + path="json", + data_files=str(file_path), + split="train", + shuffle_buffer_size=0, + output_transform=ToTensorTransform(), + num_shards_per_rank=1, + **kwargs, + ) -# --- Test Utilities --- + return _create_dataset def create_dense_mask_from_mask_mod( @@ -268,19 +251,26 @@ def test_text_causal_mask(self, device): actual_mask = create_dense_mask_from_mask_mod(text_packer, batch_doc_ids) torch.testing.assert_close(actual_mask, expected_mask) - def test_text_packing_workflow_two_packs(self): + def test_text_packing_workflow_two_packs(self, dataset_factory): """Test complete text workflow that creates exactly 2 packs with multiple samples""" # Design: Pack1=[3,2], Pack2=[4] to create 2 packs - sample_sizes = [3, 2, 4] + samples = [ + {"tokens": [0] * 3}, + {"tokens": [1] * 2}, + {"tokens": [2] * 4}, + ] target_tokens = 6 - dataset = StatefulDummyTextDataset(sample_sizes) + dataset = dataset_factory(samples) text_packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( - dataset=dataset, packer=text_packer, target_tokens_per_pack=target_tokens + dataset=dataset, + packer=text_packer, + target_tokens_per_pack=target_tokens, + buffer_size=1, ) - packs = list(packed_dataset) + packs = list(islice(packed_dataset, 2)) assert len(packs) == 2 # Pack 1: samples 0(size 3) + 1(size 2) + padding(1) @@ -543,40 +533,42 @@ def test_dpo_specialized_mask(self, device): actual_mask = create_dense_mask_from_mask_mod(dpo_packer, batch_doc_ids) torch.testing.assert_close(actual_mask, expected_mask) - def test_dpo_packing_workflow_two_packs(self): + def test_dpo_packing_workflow_two_packs(self, dataset_factory): """Test complete DPO workflow that creates exactly 2 packs with multiple samples""" samples = [ { # Sample 0: total 4 tokens (1+1+2) - "prompt_ids": torch.tensor([1]), - "chosen_response_only_ids": torch.tensor([2]), - "chosen_response_only_labels": torch.tensor([2]), - "rejected_response_only_ids": torch.tensor([3, 4]), - "rejected_response_only_labels": torch.tensor([3, 4]), + "prompt_ids": [1], + "chosen_response_only_ids": [2], + "chosen_response_only_labels": [2], + "rejected_response_only_ids": [3, 4], + "rejected_response_only_labels": [3, 4], }, { # Sample 1: total 5 tokens (2+1+2) - "prompt_ids": torch.tensor([5, 6]), - "chosen_response_only_ids": torch.tensor([7]), - "chosen_response_only_labels": torch.tensor([7]), - "rejected_response_only_ids": torch.tensor([8, 9]), - "rejected_response_only_labels": torch.tensor([8, 9]), + "prompt_ids": [5, 6], + "chosen_response_only_ids": [7], + "chosen_response_only_labels": [7], + "rejected_response_only_ids": [8, 9], + "rejected_response_only_labels": [8, 9], }, { # Sample 2: total 6 tokens (2+2+2) - "prompt_ids": torch.tensor([10, 11]), - "chosen_response_only_ids": torch.tensor([12, 13]), - "chosen_response_only_labels": torch.tensor([12, 13]), - "rejected_response_only_ids": torch.tensor([14, 15]), - "rejected_response_only_labels": torch.tensor([14, 15]), + "prompt_ids": [10, 11], + "chosen_response_only_ids": [12, 13], + 
"chosen_response_only_labels": [12, 13], + "rejected_response_only_ids": [14, 15], + "rejected_response_only_labels": [14, 15], }, ] - dataset = DummyDPODataset(samples) + dataset = dataset_factory(samples) dpo_packer = DPOPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( - dataset=dataset, packer=dpo_packer, target_tokens_per_pack=10 + dataset=dataset, packer=dpo_packer, target_tokens_per_pack=10, buffer_size=1 ) - packs = list(packed_dataset) - assert len(packs) == 2 # Pack1: samples 0+1 (4+5=9), Pack2: sample 2 (6) + packs = list(islice(packed_dataset, 2)) + assert ( + len(packs) == 2 + ) # Pack1: samples 0+1 (4+5=9), Pack2: sample 2+0 from cycle (6+4=10) # Pack 1: samples 0+1 (9 tokens) + padding (1) pack1 = packs[0] @@ -586,10 +578,10 @@ def test_dpo_packing_workflow_two_packs(self): non_padding_1 = (pack1["tokens"] != 999).sum() assert non_padding_1 == 9 - # Pack 2: sample 2 (6 tokens) + padding (4) + # Pack 2: sample 2 (6 tokens) + sample 0 from cycle (4 tokens) = 10 tokens (no padding) pack2 = packs[1] non_padding_2 = (pack2["tokens"] != 999).sum() - assert non_padding_2 == 6 + assert non_padding_2 == 10 # Verify masks are mutually exclusive chosen_and_rejected_1 = ( @@ -615,7 +607,7 @@ def test_collate_empty_batch(self): def test_collate_basic_batch(self): """Test basic collation functionality""" - # Create mock packed samples + # Create mock packed samples with same tensor sizes (as expected from IterablePackedDataset) batch = [ { "tokens": torch.tensor([1, 2, 3]), @@ -628,10 +620,10 @@ def test_collate_basic_batch(self): ], }, { - "tokens": torch.tensor([7, 8]), - "labels": torch.tensor([9, 10]), - "document_ids": torch.tensor([2, 2]), - "input_pos": torch.tensor([0, 1]), + "tokens": torch.tensor([7, 8, 999]), # Padded to same size + "labels": torch.tensor([9, 10, -100]), # Padded to same size + "document_ids": torch.tensor([2, 2, 3]), # Padded to same size + "input_pos": torch.tensor([0, 1, 0]), # Padded to same size "metrics": [ type("Metric", (), {"metric_name": "test3", "value": 3.0})() ], @@ -646,9 +638,15 @@ def mock_mask_fn(doc_ids, device): result = collate_packed(batch, mock_mask_fn, "cpu") # Check tensor stacking - expected_tokens = torch.stack([torch.tensor([1, 2, 3]), torch.tensor([7, 8])]) - expected_labels = torch.stack([torch.tensor([4, 5, 6]), torch.tensor([9, 10])]) - expected_doc_ids = torch.stack([torch.tensor([0, 0, 1]), torch.tensor([2, 2])]) + expected_tokens = torch.stack( + [torch.tensor([1, 2, 3]), torch.tensor([7, 8, 999])] + ) + expected_labels = torch.stack( + [torch.tensor([4, 5, 6]), torch.tensor([9, 10, -100])] + ) + expected_doc_ids = torch.stack( + [torch.tensor([0, 0, 1]), torch.tensor([2, 2, 3])] + ) torch.testing.assert_close(result["tokens"], expected_tokens) torch.testing.assert_close(result["labels"], expected_labels) @@ -709,16 +707,20 @@ def mock_mask_fn(doc_ids, device): @pytest.mark.skipif(not _SUPPORTS_FLEX_ATTENTION, reason="Flex attention not supported") class TestIterablePackedDataset: - """Test IterablePackedDataset functionality - buffer efficiency, checkpointing, edge cases""" - - def test_buffer_efficiency(self): - """Test buffer improves packing efficiency""" - # Test case where buffer helps vs hurts - order matters for first-fit - sample_sizes = [3, 4, 1, 2] # Total 10 tokens + def test_buffer(self, dataset_factory): + """Test buffer behaves as expected, i.e. 
when next sentence doesn't fit, goes over + the buffer and see if something fits""" + samples = [ + {"tokens": [0] * 3}, # Sample 0: size 3 + {"tokens": [1] * 4}, # Sample 1: size 4 + {"tokens": [2] * 1}, # Sample 2: size 1 + {"tokens": [3] * 2}, # Sample 3: size 2 + ] target_tokens = 6 - # With large buffer: can see all samples and pick best fit [3,1,2], [4] - dataset1 = StatefulDummyTextDataset(sample_sizes) + # Test 1: Large buffer - can see all samples and optimize packing + # Expected: [0,0,0,2,3,3] (sizes 3+1+2=6), [1,1,1,1,999,999] (size 4+2 padding) + dataset1 = dataset_factory(samples) packer1 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset1 = IterablePackedDataset( dataset=dataset1, @@ -726,10 +728,27 @@ def test_buffer_efficiency(self): target_tokens_per_pack=target_tokens, buffer_size=10, ) - packs_buffered = list(packed_dataset1) + packs_buffered = list(islice(packed_dataset1, 2)) + assert len(packs_buffered) == 2 + + # Pack 1: samples 0+2+3 (3+1+2=6) - perfect fit + pack1 = packs_buffered[0] + expected_tokens_1 = torch.tensor([0, 0, 0, 2, 3, 3]) + expected_doc_ids_1 = torch.tensor([0, 0, 0, 1, 2, 2]) + torch.testing.assert_close(pack1["tokens"], expected_tokens_1) + torch.testing.assert_close(pack1["document_ids"], expected_doc_ids_1) + + # Pack 2: sample 1 (4) + sample 2 (1) + sample 2 again from cycle (1) = 6 tokens exactly + pack2 = packs_buffered[1] + expected_tokens_2 = torch.tensor([1, 1, 1, 1, 2, 2]) + expected_doc_ids_2 = torch.tensor([0, 0, 0, 0, 1, 2]) + torch.testing.assert_close(pack2["tokens"], expected_tokens_2) + torch.testing.assert_close(pack2["document_ids"], expected_doc_ids_2) - # With small buffer: greedy first-fit [3], [4,1], [2] - dataset2 = StatefulDummyTextDataset(sample_sizes) + # Test 2: Small buffer - greedy first-fit packing with infinite dataset cycling + # Expected: [0,0,0,999,999,999] (size 3+3 padding), [1,1,1,1,2,999] (size 4+1+1 padding), + # [3,3,0,0,0,999] (size 2+3+1 padding) + dataset2 = dataset_factory(samples) packer2 = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset2 = IterablePackedDataset( dataset=dataset2, @@ -737,56 +756,138 @@ def test_buffer_efficiency(self): target_tokens_per_pack=target_tokens, buffer_size=1, ) - packs_unbuffered = list(packed_dataset2) + packs_unbuffered = list(islice(packed_dataset2, 3)) + assert len(packs_unbuffered) == 3 + + # Pack 1: sample 0 (3) + padding (3) + pack1_unbuf = packs_unbuffered[0] + expected_tokens_1_unbuf = torch.tensor([0, 0, 0, 999, 999, 999]) + expected_doc_ids_1_unbuf = torch.tensor([0, 0, 0, 1, 1, 1]) + torch.testing.assert_close(pack1_unbuf["tokens"], expected_tokens_1_unbuf) + torch.testing.assert_close( + pack1_unbuf["document_ids"], expected_doc_ids_1_unbuf + ) - # Buffer should create fewer packs (more efficient) - assert len(packs_buffered) < len(packs_unbuffered) - assert len(packs_buffered) == 2 # [3,1,2], [4] - assert len(packs_unbuffered) == 3 # [3], [4,1], [2] + # Pack 2: samples 1+2 (4+1=5) + padding (1) + pack2_unbuf = packs_unbuffered[1] + expected_tokens_2_unbuf = torch.tensor([1, 1, 1, 1, 2, 999]) + expected_doc_ids_2_unbuf = torch.tensor([0, 0, 0, 0, 1, 2]) + torch.testing.assert_close(pack2_unbuf["tokens"], expected_tokens_2_unbuf) + torch.testing.assert_close( + pack2_unbuf["document_ids"], expected_doc_ids_2_unbuf + ) - # Verify both preserve all tokens - total_buffered = sum((p["labels"] != -100).sum().item() for p in packs_buffered) - total_unbuffered = sum( - (p["labels"] != -100).sum().item() for p in packs_unbuffered + # Pack 3: 
sample 3 (2) + sample 0 from cycle (3) + padding (1) + pack3_unbuf = packs_unbuffered[2] + expected_tokens_3_unbuf = torch.tensor([3, 3, 0, 0, 0, 999]) + expected_doc_ids_3_unbuf = torch.tensor([0, 0, 1, 1, 1, 2]) + torch.testing.assert_close(pack3_unbuf["tokens"], expected_tokens_3_unbuf) + torch.testing.assert_close( + pack3_unbuf["document_ids"], expected_doc_ids_3_unbuf ) - assert total_buffered == total_unbuffered == sum(sample_sizes) - def test_oversized_sample_dropping(self): + def test_buffer_size_validation(self, dataset_factory): + """Test that buffer_size < 1 raises ValueError""" + samples = [{"tokens": [0] * 3}] + dataset = dataset_factory(samples) + + with pytest.raises(ValueError, match="Buffer size must be greater than 0"): + IterablePackedDataset( + dataset=dataset, + packer=TextPacker(padding_idx=999, ignore_idx=-100), + target_tokens_per_pack=6, + buffer_size=0, + ) + + def test_info_property(self, dataset_factory): + """Test that the info property works correctly and includes child dataset info""" + samples = [{"tokens": [0] * 3}] + dataset = dataset_factory(samples) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + + # Create packed dataset with custom name + packed_dataset = IterablePackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=6, + buffer_size=1, + dataset_name="TestPackedDataset", + ) + + # Check info property + info = packed_dataset.info + assert info.name == "TestPackedDataset" + assert info.weight == 1.0 + assert len(info.children) == 1 + + # Check child dataset info is included + child_info = info.children[0] + assert ( + child_info.name == "json_train" + ) # From HfIterableDataset auto-generated name + assert child_info.weight == 1.0 + + def test_oversized_sample_dropping(self, dataset_factory): """Test that oversized samples are dropped""" - sample_sizes = [3, 10, 2, 8, 1] # 10 and 8 are oversized for target=6 + samples = [ + {"tokens": [0] * 3}, # Kept + {"tokens": [1] * 10}, # Dropped + {"tokens": [2] * 2}, # Kept + {"tokens": [3] * 8}, # Dropped + {"tokens": [4] * 1}, # Kept + ] target_tokens = 5 - dataset = StatefulDummyTextDataset(sample_sizes) + dataset = dataset_factory(samples) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( - dataset=dataset, packer=packer, target_tokens_per_pack=target_tokens + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens, + buffer_size=1, ) - packs = list(packed_dataset) + packs = list(islice(packed_dataset, 5)) # Only samples 3, 2, 1 should be packed (oversized 10, 8 dropped) - total_packed_tokens = sum((p["labels"] != -100).sum().item() for p in packs) - expected_tokens = 3 + 2 + 1 # Only non-oversized samples - assert total_packed_tokens == expected_tokens + # Verify that only samples 0, 2, 4 are packed (samples 1, 3 were dropped) + all_tokens = torch.cat([pack["tokens"] for pack in packs]) + all_tokens = set(all_tokens.tolist()) - # Should create 2 packs: [3, 2], [1] - assert len(packs) == 2 + # Check that expected tokens are present and dropped tokens are not + expected_tokens = {0, 2, 4, 999} + assert ( + all_tokens == expected_tokens + ), f"Expected {expected_tokens}, got {all_tokens}" + + def test_checkpoint_and_resume(self, dataset_factory): + """Test checkpointing and resumption functionality using StatefulDataLoader - def test_checkpoint_and_resume(self): - """Test checkpointing and resumption functionality using StatefulDataLoader""" - sample_sizes = [3, 2, 5, 4, 1, 6] # Total 21 tokens + Note: This test 
verifies that the checkpoint/resume mechanism works correctly, + but does NOT expect identical batches after resumption. The IterablePackedDataset + explicitly does NOT save buffer or partial pack state, so packing may differ + after resumption due to different buffer fill patterns. This is by design. + """ + samples = [ + {"tokens": [0] * 3}, + {"tokens": [1] * 2}, + {"tokens": [2] * 5}, + {"tokens": [3] * 4}, + {"tokens": [4] * 1}, + {"tokens": [5] * 6}, + ] target_tokens_per_pack = 6 - batch_size = 2 + batch_size = 1 # Setup dataset factory def create_loader_and_aggregator(): - dataset = StatefulDummyTextDataset(sample_sizes) + dataset = dataset_factory(samples) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = IterablePackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=target_tokens_per_pack, - buffer_size=0, # No buffer for deterministic checkpointing + buffer_size=1, # Small buffer for predictable behavior ) collate_fn = partial( @@ -802,8 +903,8 @@ def create_loader_and_aggregator(): loader1, aggregator1 = create_loader_and_aggregator() loader2, aggregator2 = create_loader_and_aggregator() - steps_before_checkpoint = 3 - steps_after_checkpoint = 3 + steps_before_checkpoint = 2 + steps_after_checkpoint = 2 # Generate checkpoint and resume result = generate_ckpt( @@ -815,83 +916,6 @@ def create_loader_and_aggregator(): resume_aggregator=aggregator2, ) - # Verify that checkpointing and resumption are identical + # Verify that checkpointing and resumption work assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint assert len(result["resumed_batches"]) == steps_after_checkpoint - - for orig_batch, resumed_batch in zip( - result["post_checkpoint_batches"], result["resumed_batches"] - ): - assert orig_batch.keys() == resumed_batch.keys() - for key in orig_batch: - if isinstance(orig_batch[key], torch.Tensor): - torch.testing.assert_close( - orig_batch[key], - resumed_batch[key], - msg=f"Mismatch in batch key: {key}", - ) - else: - assert ( - orig_batch[key] == resumed_batch[key] - ), f"Mismatch in batch key: {key}" - - assert ( - result["final_metrics"] == result["resumed_metrics"] - ), "Final metrics should match" - - def test_multiple_iterations_same_dataset(self): - """Test that multiple iterations over same packed dataset work correctly""" - sample_sizes = [2, 3, 1] - dataset = StatefulDummyTextDataset(sample_sizes) - packer = TextPacker(padding_idx=999, ignore_idx=-100) - packed_dataset = IterablePackedDataset( - dataset=dataset, packer=packer, target_tokens_per_pack=4 - ) - # First iteration - packs1 = list(packed_dataset) - # Second iteration should produce same result - packs2 = list(packed_dataset) - - assert len(packs1) == len(packs2) - for p1, p2 in zip(packs1, packs2): - torch.testing.assert_close(p1["tokens"], p2["tokens"]) - torch.testing.assert_close(p1["document_ids"], p2["document_ids"]) - - @pytest.mark.parametrize( - "sample_sizes,target_tokens,buffer_size,expected_packs,scenario", - [ - ([3, 2, 4], 8, 10, 2, "basic_packing"), # Pack1: [3,2]+pad, Pack2: [4]+pad - ([4, 3], 8, 10, 1, "partial_final_pack"), # Pack1: [4,3]+pad - ([], 8, 10, 0, "empty_dataset"), - ([5], 10, 10, 1, "single_sample"), - ([5, 5, 5], 5, 10, 3, "exact_fit"), - ([2, 3, 1], 5, 1, 2, "small_target_and_buffer"), # Pack1: [2,3], Pack2: [1] - ], - ) - def test_scenarios( - self, sample_sizes, target_tokens, buffer_size, expected_packs, scenario - ): - """Parametrized edge case testing""" - dataset = StatefulDummyTextDataset(sample_sizes) - packer = 
TextPacker(padding_idx=999, ignore_idx=-100) - packed_dataset = IterablePackedDataset( - dataset=dataset, - packer=packer, - target_tokens_per_pack=target_tokens, - buffer_size=buffer_size, - ) - - packs = list(packed_dataset) - assert len(packs) == expected_packs, f"Failed scenario: {scenario}" - - # Verify output format consistency for all scenarios - for pack in packs: - assert pack["tokens"].shape == (target_tokens,) - assert pack["labels"].shape == (target_tokens,) - assert pack["document_ids"].shape == (target_tokens,) - assert pack["input_pos"].shape == (target_tokens,) - - # Verify no token loss - if sample_sizes: # Skip for empty dataset - total_packed = sum((p["labels"] != -100).sum().item() for p in packs) - assert total_packed == sum(sample_sizes) diff --git a/torchtune/datasets/_iterable_packed.py b/torchtune/datasets/_iterable_packed.py index 3271f53afa..737364df89 100644 --- a/torchtune/datasets/_iterable_packed.py +++ b/torchtune/datasets/_iterable_packed.py @@ -256,6 +256,10 @@ class IterablePackedDataset(InfiniteTuneIterableDataset, Generic[SampleType]): This dataset is stateful and supports checkpointing (relies on child dataset to be stateful), allowing training to be resumed seamlessly. + IMPORTANT: When resuming from a checkpoint, the buffer and partial pack are + discarded. Therefore, you won't be able to get exact matching results, but + the difference should be negligible. + Args: dataset (InfiniteTuneIterableDataset): The `InfiniteTuneIterableDataset` to pack. packer (Packer[SampleType]): The `Packer` that defines the packing @@ -266,6 +270,9 @@ class IterablePackedDataset(InfiniteTuneIterableDataset, Generic[SampleType]): cost of memory. Buffer samples are discarded if resuming from a checkpoint. Default is 100. dataset_name (str): The name of this packed dataset, used for metrics. Defaults to "IterablePackedDataset". + + Raises: + ValueError: If buffer_size or target_tokens_per_pack is not greater than 0. """ def __init__( @@ -276,6 +283,11 @@ def __init__( buffer_size: int = 100, dataset_name: str = "IterablePackedDataset", ): + if buffer_size <= 0: + raise ValueError("Buffer size must be greater than 0") + if target_tokens_per_pack <= 0: + raise ValueError("Target tokens per pack must be greater than 0") + self.dataset = dataset self.packer = packer self.target_tokens_per_pack = target_tokens_per_pack
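
For orientation, here is a minimal end-to-end sketch (not part of the patch) of how the pieces exercised by these tests are wired together: an `HfIterableDataset` is wrapped by `IterablePackedDataset` with a `TextPacker`, and `collate_packed` stacks the fixed-size packs into batches for a `StatefulDataLoader`. The `data.json` file, the token sizes, and the placeholder `mask_fn` are illustrative assumptions; a real recipe would use the packer's block-mask construction and its own collate configuration.

```python
from functools import partial

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtune.data._collate import collate_packed
from torchtune.datasets._iterable_packed import IterablePackedDataset, TextPacker
from torchtune.datasets._sft import HfIterableDataset


def to_tensor(sample):
    # The packer expects "tokens"/"labels" tensors (mirrors ToTensorTransform
    # from the test file above).
    tokens = torch.tensor(sample["tokens"], dtype=torch.long)
    return {"tokens": tokens, "labels": tokens.clone()}


# Any infinite iterable dataset works here; HfIterableDataset is what the tests
# use. "data.json" is a hypothetical local file with {"tokens": [...]} rows.
dataset = HfIterableDataset(
    path="json",
    data_files="data.json",
    split="train",
    shuffle_buffer_size=0,
    num_shards_per_rank=1,
    output_transform=to_tensor,
)

packer = TextPacker(padding_idx=999, ignore_idx=-100)
packed = IterablePackedDataset(
    dataset=dataset,
    packer=packer,
    target_tokens_per_pack=2048,
    buffer_size=100,  # must be > 0 after this patch
)


def mask_fn(document_ids: torch.Tensor, device: str) -> torch.Tensor:
    # Placeholder with the same signature the tests use; a real run would build
    # the block-causal mask from document_ids via the packer instead.
    bsz, seq_len = document_ids.shape
    return torch.ones(bsz, seq_len, seq_len, dtype=torch.bool)


loader = StatefulDataLoader(
    packed,
    batch_size=4,
    collate_fn=partial(collate_packed, mask_fn=mask_fn, device="cpu"),
)

# loader.state_dict() / loader.load_state_dict(...) checkpoint the pipeline.
# Per the docstring added in this patch, the packer's buffer and any partial
# pack are discarded on resume, so post-resume packs may differ slightly from
# an uninterrupted run.
```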