@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
+import tempfile
 from typing import Dict, Optional
 
 import pytest
@@ -152,7 +154,45 @@ def test_finetune_errors(self, capsys, pytestconfig):
         ):
             finetune_llm.recipe(FullFinetuneParams(**kwargs_values))
 
-    def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):
+
+class TestFullFinetuneRecipe:
+    def _fetch_loss_values(self, output) -> Dict[str, float]:
+        lines = output.splitlines()
+        loss_values = {}
+        for line in lines:
+            if "Loss:" in line:
+                splits = line.split("Loss:")
+                loss_value = float(splits[1].split(":")[0])
+                loss_values[splits[0]] = loss_value
+        return loss_values
+
+    def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]:
+        small_test_ckpt_loss_values = {
+            "1|1|": 10.5074,
+            "1|2|": 10.5563,
+            "2|1|": 10.5152,
+            "2|2|": 10.4851,
+        }
+        llama2_7b_ckpt_loss_values = {
+            "1|1|": 1.1333,
+            "1|2|": 1.1199,
+            "2|1|": 1.2614,
+            "2|2|": 0.9486,
+        }
+        if ckpt == "small_test_ckpt":
+            return small_test_ckpt_loss_values
+        if ckpt == "llama2_7b":
+            return llama2_7b_ckpt_loss_values
+        raise ValueError(f"Unknown ckpt {ckpt}")
+
+    def _fetch_ckpt_model_path(self, ckpt) -> str:
+        if ckpt == "small_test_ckpt":
+            return "/tmp/test-artifacts/small-ckpt-01242024"
+        if ckpt == "llama2_7b":
+            return "/tmp/test-artifacts/llama2-7b-01242024"
+        raise ValueError(f"Unknown ckpt {ckpt}")
+
+    def test_loss(self, capsys, pytestconfig):
         large_scale = pytestconfig.getoption("--large-scale")
         ckpt = "llama2_7b" if large_scale else "small_test_ckpt"
         expected_loss_values = self._fetch_expected_loss_values(ckpt)
@@ -195,3 +235,71 @@ def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):
             assert key in expected_loss_values
             expected_loss_value = expected_loss_values[key]
             assert value == pytest.approx(expected_loss_value, abs=0.001)
+
+    def test_training_state_on_resume(self):
+        """
+        Test whether the recipe state is correctly updated on resume. Since this
+        is model agnostic, we run this test on the small model only. The test
+        consists of two stages:
+            - Train a model for 4 epochs
+            - Resume training after epoch 3 and check the training state.
+        """
+
+        model_ckpt = "small_test_ckpt"
+        expected_loss_values = self._fetch_expected_loss_values(model_ckpt)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+
+            kwargs_values = {
+                "dataset": "alpaca",
+                "seed": 9,
+                "shuffle": True,
+                "model": model_ckpt,
+                "model_checkpoint": self._fetch_ckpt_model_path(model_ckpt),
+                "tokenizer": "llama2_tokenizer",
+                "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model",
+                "epochs": 4,
+                "max_steps_per_epoch": 2,
+                "output_dir": tmpdirname,
+                "device": "cpu",
+                "resume_from_checkpoint": False,
+                "enable_fsdp": False,
+            }
+
+            recipe_params = FullFinetuneParams(**kwargs_values)
+
+            recipe = FullFinetuneRecipe(recipe_params)
+            recipe.setup(params=recipe_params)
+            recipe.train()
+            recipe.cleanup()
+
+            # In the resumed run, omit seed and max_steps_per_epoch and
+            # check that they are correctly inferred from the checkpoint.
+            # Note: this will raise some warnings in the logs, but it is a
+            # stronger test.
+            kwargs_values_resume = {
+                "dataset": "alpaca",
+                "shuffle": True,
+                "model": model_ckpt,
+                "model_checkpoint": os.path.join(tmpdirname, "model_2.ckpt"),
+                "tokenizer": "llama2_tokenizer",
+                "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model",
+                "epochs": 4,
+                "output_dir": tmpdirname,
+                "device": "cpu",
+                "resume_from_checkpoint": True,  # set to True to resume
+                "enable_fsdp": False,
+            }
+
+            recipe_params = FullFinetuneParams(**kwargs_values_resume)
+
+            recipe = FullFinetuneRecipe(recipe_params)
+            recipe.setup(params=recipe_params)
+
+            assert recipe.epochs_run == 3
+            assert recipe.seed == kwargs_values["seed"]
+            assert recipe.max_steps_per_epoch == kwargs_values["max_steps_per_epoch"]
+            assert recipe.total_epochs == kwargs_values["epochs"]
+            assert recipe.total_training_steps == (
+                3 * kwargs_values["max_steps_per_epoch"]
+            )
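
The _fetch_loss_values helper in the diff above scrapes loss values out of the recipe's captured stdout. Below is a minimal, self-contained sketch of the log format it appears to assume; the sample line is illustrative only (the exact logging format comes from the recipe, which is not part of this diff), and fetch_loss_values is a functionally equivalent re-implementation for demonstration, not the test's own helper.

from typing import Dict


def fetch_loss_values(output: str) -> Dict[str, float]:
    # Keep everything before "Loss:" as the key (e.g. "1|1|" for epoch 1, step 1)
    # and parse the number immediately following "Loss:" as the value.
    loss_values = {}
    for line in output.splitlines():
        if "Loss:" in line:
            prefix, _, rest = line.partition("Loss:")
            loss_values[prefix] = float(rest.split(":")[0])
    return loss_values


# Assumed log lines of the form "<epoch>|<step>|Loss: <value>" (illustrative, not from a real run).
sample_output = "1|1|Loss: 10.5074\n1|2|Loss: 10.5563"
assert fetch_loss_values(sample_output) == {"1|1|": 10.5074, "1|2|": 10.5563}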
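The assertions at the end of test_training_state_on_resume follow from the first run's configuration. A short worked sketch, under the assumption (inferred from the "model_2.ckpt" path in the diff, not confirmed elsewhere) that the recipe writes model_<epoch>.ckpt at the end of each zero-indexed epoch:

# Assumption: resuming from model_2.ckpt means zero-indexed epochs 0, 1, and 2
# have finished, i.e. three full epochs are already complete on resume.
epochs_run = 3
max_steps_per_epoch = 2            # from the first run's kwargs_values
total_training_steps = epochs_run * max_steps_per_epoch
assert total_training_steps == 6   # matches 3 * kwargs_values["max_steps_per_epoch"]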
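As a usage note, test_loss branches on a custom --large-scale pytest option, which is presumably registered via pytest_addoption in the repo's conftest.py (outside this diff). A hedged sketch of how the llama2_7b path might be selected; the test file path is an assumption, not shown here:

# Hypothetical invocation; assumes the modified file is
# tests/recipes/test_full_finetune.py and that conftest.py registers "--large-scale".
import pytest

pytest.main(
    ["tests/recipes/test_full_finetune.py", "-k", "test_loss", "--large-scale"]
)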