122 changes: 85 additions & 37 deletions examples/speculative_decoding/eagle_utils.py
@@ -30,6 +30,7 @@
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import is_master
+from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator, ShardedDataset

 try:
     import wandb
@@ -227,75 +228,122 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
 class OfflineSupervisedDataset(Dataset):
     """Lazy offline dataset for supervised fine-tuning.

-    This dataset loads data on-the-fly from pre-processed .pt data files as well as
-    input conversations in JSON format.
+    This dataset loads data on-the-fly from pre-processed .pt data files.

     Args:
-        data_entries (list): A list of tuples (raw_data_example, file_path).
+        dumped_files (list): A list of file paths to the dumped .pt files.
         tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
     """

     def __init__(
         self,
-        data_entries,
+        dumped_files,
         tokenizer: transformers.PreTrainedTokenizer,
-        vlm_processor=None,
-        img_dir=None,
     ):
         super().__init__()
         print_rank_0("Formatting inputs...Skip in offline mode")
         self.tokenizer = tokenizer
-        self.data_entries = data_entries
-        self.vlm_processor = vlm_processor
-        self.img_dir = img_dir
-        self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess
+        self.dumped_files = dumped_files
+        # self.vlm_processor = vlm_processor
+        # self.img_dir = img_dir
+        # self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess

         # Does not cache the hidden states, as those have an extremely large memory footprint.
         self.cached_data_dict = {}

     def __len__(self):
-        return len(self.data_entries)
+        return len(self.dumped_files)

     def __getitem__(self, i) -> dict[str, torch.Tensor]:
-        # Load the conversational data, using the cache
-        raw_data, offline_file_path = self.data_entries[i]
         if i in self.cached_data_dict:
-            preprocessed_base = self.cached_data_dict[i]
+            ret = self.cached_data_dict[i]
         else:
-            ret = self.preprocess_fn(
-                [raw_data], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir
-            )
-            preprocessed_base = {k: ret[k][0] for k in ret}
-            self.cached_data_dict[i] = preprocessed_base
-
-        # Extend the data sample with the hidden states from the .pt file
-        max_length = self.tokenizer.model_max_length
-        offline_data = torch.load(offline_file_path)
-        offline_data["input_ids"] = offline_data["input_ids"][:max_length]
-        offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
-        offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
-
-        # Make sure the input_ids have the same shape
-        if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape:
-            msg = f"""Input IDs from offline data do not match the preprocessed input IDs
-            for offline data sample at {offline_file_path}."""
-            raise ValueError(msg)
-
-        ret = {**preprocessed_base}  # Shallow copy so we don't accidentally modify the cache
-        ret["input_ids"] = offline_data["input_ids"]
-        ret["kwargs"] = {
-            "base_model_outputs": {
-                "base_model_hidden_states": offline_data["hidden_states"],
-                "aux_hidden_states": offline_data["aux_hidden_states"],
-            }
-        }
+            offline_file_path = self.dumped_files[i]
+            # Extend the data sample with the hidden states from the .pt file
+            max_length = self.tokenizer.model_max_length
+            offline_data = torch.load(offline_file_path)
+            ret = {
+                "input_ids": offline_data["input_ids"][:max_length],
+                "kwargs": {
+                    "base_model_outputs": {
+                        "base_model_hidden_states": offline_data["hidden_states"][:max_length, :],
+                        "aux_hidden_states": offline_data["aux_hidden_states"][:max_length, :],
+                    }
+                },
+            }
+            self.cached_data_dict[i] = ret
         return ret


 def make_eagle_supervised_data_module(
     tokenizer: transformers.PreTrainedTokenizer,
     data_args,
     max_length=None,
 ) -> dict:
+    if data_args.offline_data_path is not None:
+        print_rank_0("Loading pre-processed data for offline training...")
+
+        # Glob for all .pt files in the data_path directory
+        assert data_args.offline_data_path is not None, (
+            "offline_data_path must be provided for offline training."
+        )
+        offline_data_path = Path(data_args.offline_data_path)
+        all_files = [str(p) for p in offline_data_path.glob("*.pt")]
+        if not all_files:
+            raise ValueError(f"No .pt files found in {data_args.offline_data_path}")
+
+        # # Filter to conversations that exist in the offline data and in the provided json
+        # valid_entries = []
+        # for entry in train_dataset:
+        #     conv_id = entry.get("conversation_id")
+        #     if conv_id is None:
+        #         conv_id = entry.get("uuid")
+        #     if conv_id is None:
+        #         conv_id = entry.get("id")
+        #     if conv_id is None:
+        #         raise ValueError(f"Conversation ID required but not found for entry {entry}")
+        #     file_path = str(offline_data_path / f"{conv_id}.pt")
+        #     if file_path in all_files:
+        #         valid_entries.append((entry, file_path))
+
+        # if len(valid_entries) == 0:
+        #     msg = """No valid files found in the offline data path that match the conversation IDs
+        #     in the provided data json. Please ensure that the offline data path is correct and
+        #     contains .pt files named after the conversation IDs, and that the input conversations
+        #     json has the correct format (with 'conversation_id' or 'id' fields)."""
+        #     raise ValueError(msg)
+        # elif len(valid_entries) < len(data_json):
+        #     print_rank_0(
+        #         f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations"
+        #         " have corresponding .pt files in the offline data path. Continuing..."
+        #     )
+
+        train_dataset = OfflineSupervisedDataset(
+            all_files,
+            tokenizer=tokenizer,
+        )
+
+        data_collator = DataCollatorForOffline(max_length=max_length)
+    else:
+        train_dataset = ShardedDataset("nvidia/Daring-Anteater")
+        data_collator = LanguageDataCollator(
+            tokenizer=tokenizer,
+            max_length=max_length,
+        )
+
+    return {
+        "train_dataset": train_dataset,
+        "data_collator": data_collator,
+    }
+
+
+def make_eagle_supervised_data_module_old(
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    max_length=None,
+) -> dict:
     """Make dataset and collator for supervised fine-tuning.
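For reviewers who want to exercise the new offline path locally, here is a minimal sketch (not part of this PR). It assumes it runs from `examples/speculative_decoding/` so `eagle_utils` is importable; the tensor shapes, vocabulary size, stub tokenizer, and dump directory are illustrative assumptions — only the `.pt` payload keys and the truncation to `model_max_length` are taken from the diff above.

```python
from pathlib import Path

import torch

from eagle_utils import OfflineSupervisedDataset  # assumes cwd is examples/speculative_decoding/


class StubTokenizer:
    """Illustrative stand-in; the dataset only reads model_max_length here."""

    model_max_length = 512


# Write one fake dumped sample with exactly the keys __getitem__ expects.
dump_dir = Path("offline_dump_demo")
dump_dir.mkdir(exist_ok=True)
torch.save(
    {
        "input_ids": torch.randint(0, 32000, (600,)),  # longer than max_length on purpose
        "hidden_states": torch.randn(600, 4096),  # (seq_len, hidden_size) -- sizes are made up
        "aux_hidden_states": torch.randn(600, 3 * 4096),
    },
    dump_dir / "sample_0.pt",
)

dataset = OfflineSupervisedDataset(
    dumped_files=[str(p) for p in dump_dir.glob("*.pt")],
    tokenizer=StubTokenizer(),
)
item = dataset[0]
assert item["input_ids"].shape[0] == 512  # truncated to model_max_length
outputs = item["kwargs"]["base_model_outputs"]
print(outputs["base_model_hidden_states"].shape, outputs["aux_hidden_states"].shape)
```

Design note: the `else` branch of `make_eagle_supervised_data_module` now streams `nvidia/Daring-Anteater` through `ShardedDataset` with `LanguageDataCollator`, so the JSON-conversation plumbing that the commented-out block used to handle is no longer needed for the online path.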
7 changes: 6 additions & 1 deletion modelopt/torch/speculative/plugins/transformers.py
@@ -767,7 +767,12 @@ def forward(
         assert past_key_values is None, "past_key_values should be None in training"

         if loss_mask is None:
-            loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device)
+            # By default, mask out padding tokens in loss computation
+            loss_mask = (
+                attention_mask.clone().detach()
+                if attention_mask is not None
+                else torch.ones_like(input_ids, dtype=torch.bool)
+            )

         # ====First, we run base model forward====
         if "base_model_outputs" in kwargs:
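The `transformers.py` change means that when no explicit `loss_mask` is passed, padding positions (zeros in `attention_mask`) are now excluded from the loss instead of being counted. A self-contained sketch of just that fallback, reimplemented for illustration (not the PR code itself):

```python
import torch


def default_loss_mask(input_ids, attention_mask=None, loss_mask=None):
    """Illustrative reimplementation of the fallback added above."""
    if loss_mask is None:
        # Prefer the attention mask so padding is excluded; otherwise keep every position.
        loss_mask = (
            attention_mask.clone().detach()
            if attention_mask is not None
            else torch.ones_like(input_ids, dtype=torch.bool)
        )
    return loss_mask


# A right-padded batch: the last two positions of the second row are padding.
input_ids = torch.tensor([[11, 12, 13, 14], [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

print(default_loss_mask(input_ids, attention_mask))
# -> [[1, 1, 1, 1], [1, 1, 0, 0]]: padding no longer contributes to the loss
print(default_loss_mask(input_ids))
# -> all-True bool mask, i.e. the previous behavior, kept when no attention_mask exists
```

One thing reviewers may want to confirm: the two branches produce different dtypes (the attention mask's integer dtype vs. `bool`), which is fine if downstream code treats the mask numerically.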