diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml
index afa215ae37..1b0a094d50 100644
--- a/recipes/configs/llama2/7B_lora_single_device.yaml
+++ b/recipes/configs/llama2/7B_lora_single_device.yaml
@@ -46,6 +46,8 @@ checkpointer:
   model_type: LLAMA2
 resume_from_checkpoint: False
 save_adapter_weights_only: False
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset and Sampler
 dataset:
diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml
index 77cfef59e2..72f42db691 100644
--- a/recipes/configs/llama3/8B_lora_single_device.yaml
+++ b/recipes/configs/llama3/8B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
   model_type: LLAMA3
 resume_from_checkpoint: False
 save_adapter_weights_only: False
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset and Sampler
 dataset:
diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml
index 46b3f767ee..c09ac23307 100644
--- a/recipes/configs/llama3_1/8B_lora_single_device.yaml
+++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
   model_type: LLAMA3
 resume_from_checkpoint: False
 save_adapter_weights_only: False
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset and Sampler
 dataset:
diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml
index a5479fa724..3446743be2 100644
--- a/recipes/configs/llama3_2/1B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -43,6 +43,8 @@ checkpointer:
   model_type: LLAMA3_2
 resume_from_checkpoint: False
 save_adapter_weights_only: False
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset and Sampler
 dataset:
diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml
index 4f54caed9f..1c29f0a242 100644
--- a/recipes/configs/llama3_2/3B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml
@@ -45,6 +45,8 @@ checkpointer:
   model_type: LLAMA3_2
 resume_from_checkpoint: False
 save_adapter_weights_only: False
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset and Sampler
 dataset:
diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
index 6b434aa499..994b0e6be6 100644
--- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
   model_type: LLAMA3_VISION
 resume_from_checkpoint: False
 save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
+save_last_epoch_only: False
+epochs_to_save: 'all'
 
 # Dataset
 dataset:
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index a51bfef26f..7c661a5e66 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -170,6 +170,27 @@ def __init__(self, cfg: DictConfig) -> None:
         self.global_step = 0
 
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
+        if cfg.get("save_last_epoch_only", False) and cfg.get("epochs_to_save", None):
+            utils.log_rank_zero(
+                self._logger,
+                "Both save_last_epoch_only and epochs_to_save are set. "
" + "The value for save_last_epoch_only takes precedence but will be removed in a future release.", + ) + self._save_last_epoch_only = cfg.get("save_last_epoch_only", False) + self._epochs_to_save = ( + [self.total_epochs - 1] + if self._save_last_epoch_only + else cfg.get("epochs_to_save", "all") + ) + if self._epochs_to_save == "all": + self._epochs_to_save = list(range(self.total_epochs)) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) @@ -718,17 +732,22 @@ def train(self) -> None: break self.epochs_run += 1 - start_save_checkpoint = time.perf_counter() - self._logger.info("Starting checkpoint save...") - - # Save final non-distributed ckpt - self.save_checkpoint(epoch=curr_epoch, full_tensors=True) - self._logger.info( - "Checkpoint saved in {:.2f} seconds.".format( - time.perf_counter() - start_save_checkpoint - ) - ) - + if curr_epoch in self._epochs_to_save: + start_save_checkpoint = time.perf_counter() + self._logger.info( + f"Starting checkpoint save for epoch {curr_epoch}..." + ) + self.save_checkpoint(epoch=curr_epoch) + self._logger.info( + "Checkpoint saved in {:.2f} seconds.".format( + time.perf_counter() - start_save_checkpoint + ) + ) + else: + self._log.info( + f"Skipping checkpoint save for epoch {curr_epoch}..." + ) + def cleanup(self) -> None: self._metric_logger.close() diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index 27a62682fc..50ef1f76be 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -366,6 +366,270 @@ def test_training_state_on_resume_with_async_checkpointing( loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 ) + @pytest.mark.parametrize( + "epochs_to_save, expected_folders", + [ + ("all", ["epoch_0", "epoch_1", "epoch_2"]), + ("none", []), + ("1,3", ["epoch_0", "epoch_2"]), + ], + ) + @pytest.mark.integration_test + @gpu_test(gpu_count=1) + def test_epochs_to_save( + self, tmpdir, monkeypatch, epochs_to_save, expected_folders + ): + """Test that epochs_to_save parameter controls which epoch folders are saved. + The test checks if the specified epochs are saved after training a model for 3 epochs. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for three epochs
+        cmd = f"""
+        tune run lora_finetune_single_device \
+            --config llama3/8B_lora_single_device \
+            batch_size=8 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            epochs_to_save={epochs_to_save} \
+            save_last_epoch_only=False \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=False \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+        cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Verify the checkpointing behavior:
+        # only the expected epoch folders should have been created
+        saved_epoch_folders = sorted(
+            [f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
+        )
+
+        assert (
+            saved_epoch_folders == expected_folders
+        ), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"
+
+    @pytest.mark.parametrize(
+        "epochs_to_save, expected_folders",
+        [
+            ("all", ["epoch_0", "epoch_1", "epoch_2"]),
+            ("none", []),
+            ("1,3", ["epoch_0", "epoch_2"]),
+        ],
+    )
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_epochs_to_save_with_async_checkpointing(
+        self, tmpdir, monkeypatch, epochs_to_save, expected_folders
+    ):
+        """Test that the epochs_to_save parameter controls which epoch folders are saved with async checkpointing.
+        The test trains a model for 3 epochs and checks that only the specified epochs are saved.
+        """
+
+        ckpt = "llama3_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for three epochs
+        cmd = f"""
+        tune run lora_finetune_single_device \
+            --config llama3/8B_lora_single_device \
+            batch_size=8 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            epochs_to_save={epochs_to_save} \
+            save_last_epoch_only=False \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=True \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+        cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Verify the checkpointing behavior:
+        # only the expected epoch folders should have been created
+        saved_epoch_folders = sorted(
+            [f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
+        )
+
+        assert (
+            saved_epoch_folders == expected_folders
+        ), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"
+
+    @pytest.mark.parametrize(
+        "save_last_epoch_only, expected_folders",
+        [
+            (True, ["epoch_2"]),
+            (False, ["epoch_0", "epoch_1"]),
+        ],
+    )
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_save_last_epoch_only(
+        self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders
+    ):
+        """Test that the save_last_epoch_only parameter controls checkpoint saving behavior.
+        The test trains a model for 3 epochs and checks that only the last epoch is saved
+        when save_last_epoch_only is True and that it correctly overrides epochs_to_save.
+        """
+
+        ckpt = "llama3_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for three epochs
+        cmd = f"""
+        tune run lora_finetune_single_device \
+            --config llama3/8B_lora_single_device \
+            batch_size=8 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            epochs_to_save='1,2' \
+            save_last_epoch_only={save_last_epoch_only} \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=False \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+        cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Verify the checkpointing behavior:
+        # only the expected epoch folders should have been created
+        saved_epoch_folders = sorted(
+            [f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
+        )
+
+        assert (
+            saved_epoch_folders == expected_folders
+        ), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"
+
+    @pytest.mark.parametrize(
+        "save_last_epoch_only, expected_folders",
+        [
+            (True, ["epoch_2"]),
+            (False, ["epoch_0", "epoch_1"]),
+        ],
+    )
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_save_last_epoch_only_with_async_checkpointing(
+        self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders
+    ):
+        """Test that the save_last_epoch_only parameter controls checkpoint saving behavior with async checkpointing.
+        The test trains a model for 3 epochs and checks that only the last epoch is saved
+        when save_last_epoch_only is True and that it correctly overrides epochs_to_save.
+        """
+
+        ckpt = "llama3_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for three epochs
+        cmd = f"""
+        tune run lora_finetune_single_device \
+            --config llama3/8B_lora_single_device \
+            batch_size=8 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            epochs_to_save='1,2' \
+            save_last_epoch_only={save_last_epoch_only} \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=True \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+        cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Verify the checkpointing behavior:
+        # only the expected epoch folders should have been created
+        saved_epoch_folders = sorted(
+            [f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
+        )
+
+        assert (
+            saved_epoch_folders == expected_folders
+        ), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"
+
     @pytest.mark.parametrize("use_dora", [False, True])
     @pytest.mark.integration_test
     @gpu_test(gpu_count=1)