
Commit 834feaa

Metric Logging, Fixes and Tests for full finetune Recipe (#304)
1 parent aaf43de commit 834feaa

File tree

5 files changed: +155 -6 lines changed

.gitignore (+3)

@@ -184,3 +184,6 @@ cover/
 
 # VSCode
 .vscode/
+
+# wandb
+wandb/

recipes/full_finetune.py (+42 -4)

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 import sys
 
@@ -65,18 +66,27 @@ def __init__(self, params: FullFinetuneParams) -> None:
 
         # logging attributes
         self._output_dir = params.output_dir
+        self._metric_logger = utils.get_metric_logger(
+            metric_logger_type=params.metric_logger_type,
+            project=params.project,
+            log_dir=params.output_dir,
+        )
+        self._log_every_n_steps = (
+            params.log_every_n_steps if params.log_every_n_steps else 1
+        )
 
         # _is_rank_zero is used primarily for logging. In the future, the logger
         # should directly take care of this
         _, rank = utils.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
         # These are public properties which are updated by the checkpoint loader
-        # when ``resume_from_checkpoint`` is `True`
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = utils.set_seed(seed=params.seed)
         self.epochs_run = 0
         self.total_epochs = params.epochs
         self.max_steps_per_epoch = params.max_steps_per_epoch
+        self.total_training_steps = 0
 
         self._resume_from_checkpoint = params.resume_from_checkpoint
 
@@ -143,6 +153,20 @@ def setup(self, params: FullFinetuneParams) -> None:
         else:
             self._grad_scaler = GradScaler(enabled=False)
 
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader and the max_steps_per_epoch param set by the user and is used
+        # for logging and tracking training state. This should be computed after the dataloader
+        # has been setup
+        steps_per_epoch = len(self._dataloader)
+        if self.max_steps_per_epoch and self.max_steps_per_epoch < len(
+            self._dataloader
+        ):
+            steps_per_epoch = self.max_steps_per_epoch
+        self.total_training_steps = self.epochs_run * steps_per_epoch
+
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.

@@ -333,6 +357,8 @@ def train(self) -> None:
                     and idx == self.max_steps_per_epoch
                 ):
                     break
+
+                self.total_training_steps += 1
                 self._optimizer.zero_grad()
 
                 input_ids, labels = batch

@@ -350,13 +376,26 @@
 
                 pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
 
+                if self.total_training_steps % self._log_every_n_steps == 0:
+                    self._metric_logger.log_dict(
+                        {
+                            "loss": loss.item(),
+                            "lr": self._optimizer.param_groups[0]["lr"],
+                            "gpu_resources": torch.cuda.memory_allocated(),
+                        },
+                        step=self.total_training_steps,
+                    )
+
                 self._grad_scaler.scale(loss).backward()
                 self._grad_scaler.step(self._optimizer)
                 self._grad_scaler.update()
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
+    def cleanup(self) -> None:
+        self._metric_logger.close()
+
 
 def recipe_main() -> None:
     """

@@ -368,13 +407,11 @@ def recipe_main() -> None:
        - Overwritten by arguments from the command-line using ``TuneArgumentParser``
    """
    parser = utils.TuneArgumentParser(
-        description=recipe.__doc__,
+        description=FullFinetuneParams.__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    args, _ = parser.parse_known_args()
-    parser.log_args(args)
    args = vars(args)
-
    recipe_params = FullFinetuneParams(**args)
 
    # Env variables set by torch run; only need to initialize process group

@@ -383,6 +420,7 @@
    recipe = FullFinetuneRecipe(params=recipe_params)
    recipe.setup(params=recipe_params)
    recipe.train()
+    recipe.cleanup()
 
 
 if __name__ == "__main__":
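
The recipe-state bookkeeping added in setup() and train() is plain arithmetic; the sketch below is a standalone illustration (the compute_steps_per_epoch helper name and the batch count are made up, and the numbers mirror the resume test further down: 2 steps per epoch, resuming after 3 epochs).

# Sketch of the recipe-state arithmetic from setup(): the dataloader length is
# optionally capped by max_steps_per_epoch, and on resume total_training_steps
# restarts from epochs_run * steps_per_epoch.
def compute_steps_per_epoch(num_batches, max_steps_per_epoch=None):
    if max_steps_per_epoch and max_steps_per_epoch < num_batches:
        return max_steps_per_epoch
    return num_batches

steps_per_epoch = compute_steps_per_epoch(num_batches=26, max_steps_per_epoch=2)
epochs_run = 3  # e.g. resuming from the checkpoint written after epoch 3
total_training_steps = epochs_run * steps_per_epoch
assert total_training_steps == 6  # the value asserted in test_training_state_on_resume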

recipes/params.py (+1)

@@ -86,6 +86,7 @@ class FullFinetuneParams:
     output_dir: str = "/tmp/full_finetune_output"
     metric_logger_type: str = "disk"
     project: Optional[str] = None
+    log_every_n_steps: Optional[int] = None
 
     def __post_init__(self):
         for param in fields(self):
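
The new field defaults to None, which the recipe's __init__ treats as "log every step" before the modulo check in train(). A minimal standalone sketch of that behaviour (the should_log helper is illustrative, not part of the torchtune API):

# None -> fall back to logging on every step; otherwise log every N-th step.
def should_log(total_training_steps, log_every_n_steps=None):
    effective = log_every_n_steps if log_every_n_steps else 1
    return total_training_steps % effective == 0

assert all(should_log(step) for step in range(1, 5))  # default: every step
assert [s for s in range(1, 7) if should_log(s, log_every_n_steps=3)] == [3, 6]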

recipes/tests/test_finetune_llm.py (+109 -1)

@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
+import tempfile
 from typing import Dict, Optional
 
 import pytest

@@ -152,7 +154,45 @@ def test_finetune_errors(self, capsys, pytestconfig):
         ):
             finetune_llm.recipe(FullFinetuneParams(**kwargs_values))
 
-    def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):
+
+class TestFullFinetuneRecipe:
+    def _fetch_loss_values(self, output) -> Dict[str, float]:
+        lines = output.splitlines()
+        loss_values = {}
+        for line in lines:
+            if "Loss:" in line:
+                splits = line.split("Loss:")
+                loss_value = float(splits[1].split(":")[0])
+                loss_values[splits[0]] = loss_value
+        return loss_values
+
+    def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]:
+        small_test_ckpt_loss_values = {
+            "1|1|": 10.5074,
+            "1|2|": 10.5563,
+            "2|1|": 10.5152,
+            "2|2|": 10.4851,
+        }
+        llama2_7b_ckpt_loss_values = {
+            "1|1|": 1.1333,
+            "1|2|": 1.1199,
+            "2|1|": 1.2614,
+            "2|2|": 0.9486,
+        }
+        if ckpt == "small_test_ckpt":
+            return small_test_ckpt_loss_values
+        if ckpt == "llama2_7b":
+            return llama2_7b_ckpt_loss_values
+        raise ValueError(f"Unknown ckpt {ckpt}")
+
+    def _fetch_ckpt_model_path(self, ckpt) -> str:
+        if ckpt == "small_test_ckpt":
+            return "/tmp/test-artifacts/small-ckpt-01242024"
+        if ckpt == "llama2_7b":
+            return "/tmp/test-artifacts/llama2-7b-01242024"
+        raise ValueError(f"Unknown ckpt {ckpt}")
+
+    def test_loss(self, capsys, pytestconfig):
         large_scale = pytestconfig.getoption("--large-scale")
         ckpt = "llama2_7b" if large_scale else "small_test_ckpt"
         expected_loss_values = self._fetch_expected_loss_values(ckpt)

@@ -195,3 +235,71 @@ def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):
             assert key in expected_loss_values
             expected_loss_value = expected_loss_values[key]
             assert value == pytest.approx(expected_loss_value, abs=0.001)
+
+    def test_training_state_on_resume(self):
+        """
+        Test whether the recipe state is correctly updated on resume. Since this
+        is model agnostic, we should run this on the small model only. The test
+        consists of two stages:
+            - Train a model for 4 epochs
+            - Resume training after epoch 3 and check training state.
+        """
+
+        model_ckpt = "small_test_ckpt"
+        expected_loss_values = self._fetch_expected_loss_values(model_ckpt)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+
+            kwargs_values = {
+                "dataset": "alpaca",
+                "seed": 9,
+                "shuffle": True,
+                "model": model_ckpt,
+                "model_checkpoint": self._fetch_ckpt_model_path(model_ckpt),
+                "tokenizer": "llama2_tokenizer",
+                "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model",
+                "epochs": 4,
+                "max_steps_per_epoch": 2,
+                "output_dir": tmpdirname,
+                "device": "cpu",
+                "resume_from_checkpoint": False,
+                "enable_fsdp": False,
+            }
+
+            recipe_params = FullFinetuneParams(**kwargs_values)
+
+            recipe = FullFinetuneRecipe(recipe_params)
+            recipe.setup(params=recipe_params)
+            recipe.train()
+            recipe.cleanup()
+
+            # In the new run, remove seed and max_steps_per_epoch and
+            # check if these are correctly inferred from the checkpoint
+            # Note this will raise some warnings in the logs, but is a
+            # stronger test
+            kwargs_values_resume = {
+                "dataset": "alpaca",
+                "shuffle": True,
+                "model": model_ckpt,
+                "model_checkpoint": os.path.join(tmpdirname, "model_2.ckpt"),
+                "tokenizer": "llama2_tokenizer",
+                "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model",
+                "epochs": 4,
+                "output_dir": tmpdirname,
+                "device": "cpu",
+                "resume_from_checkpoint": True,  # set to True to resume
+                "enable_fsdp": False,
+            }
+
+            recipe_params = FullFinetuneParams(**kwargs_values_resume)
+
+            recipe = FullFinetuneRecipe(recipe_params)
+            recipe.setup(params=recipe_params)
+
+            assert recipe.epochs_run == 3
+            assert recipe.seed == kwargs_values["seed"]
+            assert recipe.max_steps_per_epoch == kwargs_values["max_steps_per_epoch"]
+            assert recipe.total_epochs == kwargs_values["epochs"]
+            assert recipe.total_training_steps == (
+                3 * kwargs_values["max_steps_per_epoch"]
+            )
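
The _fetch_loss_values helper above depends on the pbar description format set in train(), i.e. "{epoch}|{step}|Loss: {value}". A standalone sketch of that parsing on captured output (the sample strings below are made up for illustration):

# Parse "epoch|step|Loss: value" lines the same way _fetch_loss_values does.
output = "1|1|Loss: 10.5074\n1|2|Loss: 10.5563\nsome unrelated log line\n"
loss_values = {}
for line in output.splitlines():
    if "Loss:" in line:
        prefix, rest = line.split("Loss:")
        loss_values[prefix] = float(rest.split(":")[0])

assert loss_values == {"1|1|": 10.5074, "1|2|": 10.5563}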

torchtune/utils/metric_logging.py (-1)

@@ -155,7 +155,6 @@ def __init__(
                 "``wandb`` package not found. Please install wandb using `pip install wandb` to use WandBLogger."
                 "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
             ) from e
-
         self._wandb = wandb
         self._wandb.init(
             project=project,
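
For context, a minimal sketch of the logger lifecycle this commit wires into the recipe: create via utils.get_metric_logger, emit per-step metrics with log_dict, close in cleanup(). It uses metric_logger_type="stdout" so the optional wandb dependency mentioned in the error message is not required; the import path is an assumption based on the recipe's use of utils.get_metric_logger.

# Assumed import path; the recipe itself calls utils.get_metric_logger.
from torchtune import utils

logger = utils.get_metric_logger(
    metric_logger_type="stdout",  # "disk" and "wandb" are the other types referenced in this commit
    project=None,
    log_dir="/tmp/full_finetune_output",
)
for step in range(1, 4):
    # Same call shape as the recipe's per-step logging in train()
    logger.log_dict({"loss": 1.0 / step, "lr": 2e-5}, step=step)
logger.close()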
