Skip to content

Commit aaf43de

Browse files
authored
Alpaca Dataset Updates and Fixes (#303)
1 parent f1537ee commit aaf43de

File tree

6 files changed

+288
-31
lines changed

6 files changed

+288
-31
lines changed

recipes/finetune_llm.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ def recipe(
106106
grad_scaler = GradScaler(enabled=False)
107107

108108
# ---- Load dataset, set up sampler, and dataloader ---- #
109-
ds = datasets.get_dataset(params.dataset, split="train", tokenizer=tokenizer)
109+
ds = datasets.get_dataset(
110+
params.dataset,
111+
split="train",
112+
tokenizer=tokenizer,
113+
train_on_input=params.train_on_input,
114+
)
110115
sampler = DistributedSampler(
111116
ds,
112117
num_replicas=world_size,

recipes/full_finetune.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def setup(self, params: FullFinetuneParams) -> None:
130130
# setup after both of these are initialized
131131
self._sampler, self._dataloader = self._setup_data(
132132
dataset=params.dataset,
133+
train_on_input=params.train_on_input,
133134
shuffle=params.shuffle,
134135
batch_size=params.batch_size,
135136
)
@@ -240,15 +241,20 @@ def _setup_loss(self, loss: str) -> nn.Module:
240241
return loss_fn
241242

242243
def _setup_data(
243-
self, dataset: str, shuffle: bool, batch_size: int
244+
self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
244245
) -> Tuple[DistributedSampler, DataLoader]:
245246
"""
246247
All data related setup happens here. Currently this recipe only supports the
247248
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
248249
iterable datasets and streaming datasets are not supported.
249250
"""
250251
world_size, rank = utils.get_world_size_and_rank()
251-
ds = datasets.get_dataset(dataset, split="train", tokenizer=self._tokenizer)
252+
ds = datasets.get_dataset(
253+
dataset,
254+
split="train",
255+
tokenizer=self._tokenizer,
256+
train_on_input=train_on_input,
257+
)
252258
sampler = DistributedSampler(
253259
ds,
254260
num_replicas=world_size,

recipes/params.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class FullFinetuneParams:
5757

5858
# Dataset and Sampler
5959
dataset: str = ""
60+
train_on_input: bool = True
6061
shuffle: bool = True
6162
batch_size: int = 2
6263

recipes/tests/test_finetune_llm.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ def _fetch_loss_values(self, output) -> Dict[str, float]:
4848

4949
def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]:
5050
small_test_ckpt_loss_values = {
51-
"1|1|": 10.5011,
52-
"1|2|": 10.5740,
53-
"2|1|": 10.5221,
54-
"2|2|": 10.4835,
51+
"1|1|": 10.5074,
52+
"1|2|": 10.5563,
53+
"2|1|": 10.5152,
54+
"2|2|": 10.4851,
5555
}
5656
llama2_7b_ckpt_loss_values = {
57-
"1|1|": 1.2381,
58-
"1|2|": 1.1042,
59-
"2|1|": 1.3086,
60-
"2|2|": 0.9908,
57+
"1|1|": 1.1333,
58+
"1|2|": 1.1199,
59+
"2|1|": 1.2614,
60+
"2|2|": 0.9486,
6161
}
6262
if ckpt == "small_test_ckpt":
6363
return small_test_ckpt_loss_values
@@ -79,6 +79,7 @@ def test_finetune_llm_loss(self, capsys, pytestconfig):
7979

8080
kwargs_values = {
8181
"dataset": "alpaca",
82+
"train_on_input": False,
8283
"seed": 9,
8384
"shuffle": True,
8485
"model": ckpt,
@@ -120,6 +121,7 @@ def test_finetune_errors(self, capsys, pytestconfig):
120121

121122
kwargs_values = {
122123
"dataset": "alpaca",
124+
"train_on_input": False,
123125
"seed": 9,
124126
"shuffle": True,
125127
"model": ckpt,
@@ -157,6 +159,7 @@ def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):
157159

158160
kwargs_values = {
159161
"dataset": "alpaca",
162+
"train_on_input": False,
160163
"seed": 9,
161164
"shuffle": True,
162165
"model": ckpt,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from unittest.mock import patch
8+
9+
import pytest
10+
11+
from torchtune import datasets
12+
from torchtune.datasets.alpaca import CROSS_ENTROPY_IGNORE_IDX
13+
from torchtune.modules.tokenizer import Tokenizer
14+
15+
from tests.test_utils import get_assets_path
16+
17+
18+
class TestAlpacaDataset:
19+
@pytest.fixture
20+
def tokenizer(self):
21+
# m.model is a pretrained SentencePiece model, trained using the following command:
22+
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
23+
return Tokenizer.from_file(str(get_assets_path() / "m.model"))
24+
25+
@patch("torchtune.datasets.alpaca.load_dataset")
26+
def test_prompt_generation(self, load_dataset, tokenizer):
27+
"""
28+
Test that the prompt generated from the Alpaca template is correct.
29+
"""
30+
31+
# mock the call to HF datasets
32+
load_dataset.return_value = [
33+
{
34+
"instruction": "Give three tips for staying healthy.",
35+
"input": "",
36+
"output": (
37+
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
38+
"2. Exercise regularly to keep your body active and strong."
39+
"3. Get enough sleep and maintain a consistent sleep schedule."
40+
),
41+
},
42+
{
43+
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
44+
"input": "He finnished his meal and left the resturant",
45+
"output": "He finished his meal and left the restaurant.",
46+
},
47+
]
48+
49+
# Expected prompts match the prompt portion (everything before the response) of the "text" field in
50+
# https://huggingface.co/datasets/tatsu-lab/alpaca
51+
expected_prompts = [
52+
(
53+
"Below is an instruction that describes a task. Write a response that appropriately "
54+
"completes the request.\n\n"
55+
"### Instruction:\nGive three tips for staying healthy.\n\n"
56+
"### Response:\n"
57+
),
58+
(
59+
"Below is an instruction that describes a task, paired with an input that provides further context. "
60+
"Write a response that appropriately completes the request.\n\n"
61+
"### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n"
62+
"### Input:\nHe finnished his meal and left the resturant\n\n"
63+
"### Response:\n"
64+
),
65+
]
66+
67+
alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)
68+
69+
# alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data
70+
# to test the prompt generation since calling __getitem__ on the alpaca_dataset object will
71+
# return the encoded input and label
72+
for idx, sample in enumerate(alpaca_dataset._data):
73+
assert expected_prompts[idx] == alpaca_dataset._generate_prompt(
74+
sample["instruction"], sample["input"]
75+
)
76+
77+
@patch("torchtune.datasets.alpaca.load_dataset")
78+
def test_label_no_masking(self, load_dataset, tokenizer):
79+
"""
80+
Test whether the input and the labels are correctly created when the input is not masked.
81+
"""
82+
83+
# mock the call to HF datasets
84+
load_dataset.return_value = [
85+
{
86+
"instruction": "Give three tips for staying healthy.",
87+
"input": "",
88+
"output": (
89+
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
90+
"2. Exercise regularly to keep your body active and strong."
91+
"3. Get enough sleep and maintain a consistent sleep schedule."
92+
),
93+
}
94+
]
95+
96+
alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)
97+
input, labels = alpaca_dataset[0]
98+
99+
assert len(input) == len(labels)
100+
assert labels[-1] == tokenizer.eos_id
101+
assert input[0] == tokenizer.bos_id
102+
assert CROSS_ENTROPY_IGNORE_IDX not in labels
103+
104+
@patch("torchtune.datasets.alpaca.load_dataset")
105+
def test_label_masking(self, load_dataset, tokenizer):
106+
"""
107+
Test whether the input and the labels are correctly created when the input is masked.
108+
"""
109+
110+
# mock the call to HF datasets
111+
load_dataset.return_value = [
112+
{
113+
"instruction": "Give three tips for staying healthy.",
114+
"input": "",
115+
"output": (
116+
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
117+
"2. Exercise regularly to keep your body active and strong."
118+
"3. Get enough sleep and maintain a consistent sleep schedule."
119+
),
120+
}
121+
]
122+
123+
alpaca_dataset = datasets.get_dataset(
124+
"alpaca", tokenizer=tokenizer, train_on_input=False
125+
)
126+
127+
# Extract the prompt and tokenize it; we'll need this to test whether we're masking the
128+
# input correctly
129+
sample = alpaca_dataset._data[0]
130+
prompt = alpaca_dataset._generate_prompt(sample["instruction"], sample["input"])
131+
encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
132+
133+
# Generate the input and labels
134+
input, labels = alpaca_dataset[0]
135+
136+
assert len(input) == len(labels)
137+
assert labels[-1] == tokenizer.eos_id
138+
assert input[0] == tokenizer.bos_id
139+
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
140+
141+
@patch("torchtune.datasets.alpaca.load_dataset")
142+
def test_alpaca_clean(self, load_dataset, tokenizer):
143+
"""
144+
Test whether the input and the labels are correctly created when the cleaned dataset is requested (use_clean=True) and the input is not masked.
145+
"""
146+
147+
# mock the call to HF datasets
148+
load_dataset.return_value = [
149+
{
150+
"instruction": "Give three tips for staying healthy.",
151+
"input": "",
152+
"output": (
153+
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
154+
"2. Exercise regularly to keep your body active and strong."
155+
"3. Get enough sleep and maintain a consistent sleep schedule."
156+
),
157+
}
158+
]
159+
160+
alpaca_dataset = datasets.get_dataset(
161+
"alpaca", tokenizer=tokenizer, use_clean=True
162+
)
163+
input, labels = alpaca_dataset[0]
164+
165+
assert len(input) == len(labels)
166+
assert labels[-1] == tokenizer.eos_id
167+
assert input[0] == tokenizer.bos_id
168+
assert CROSS_ENTROPY_IGNORE_IDX not in labels

0 commit comments

Comments
 (0)