diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index 24eac063..cc5cd468 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -55,6 +55,7 @@ def __init__(
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self._last_checkpoint_size: int | None = None
         bnb_config = None
         if lora_config and lora_config.r > 0 and lora_quant_bits == 4:
             # Third Party
@@ -76,6 +77,14 @@ def __init__(
         if flash_enabled:
             self.base_model_args["attn_implementation"] = "flash_attention_2"
 
+    @property
+    def last_checkpoint_size(self) -> int | None:
+        return self._last_checkpoint_size
+
+    @last_checkpoint_size.setter
+    def last_checkpoint_size(self, value: int):
+        self._last_checkpoint_size = value
+
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
         self.reconcile_tokenizer()
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 15dd2897..d1e86a47 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -11,6 +11,7 @@
 import logging
 import os
 import random
+import shutil
 import subprocess
 import sys
 import time
@@ -462,15 +463,29 @@ def get_caller(num_frames=1):
     return f"In {file_name}, line {line_number}"
 
 
-def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
+def log_rank_0(
+    msg, include_caller=False, rank=None, to_print=False, level=logging.INFO
+) -> None:
     if rank is None:
         rank = get_rank() if is_initialized() else 0
-    if rank <= 0:
-        if include_caller:
-            msg = f"{get_caller(num_frames=2)}: {msg}"
-        if to_print:
-            print(msg)
-        else:
+    if rank > 0:
+        return
+
+    if include_caller:
+        msg = f"{get_caller(num_frames=2)}: {msg}"
+
+    if to_print:
+        print(msg)
+        return
+
+    match level:
+        case logging.WARNING:
+            logger.warning(msg)
+        case logging.ERROR:
+            logger.error(msg)
+        case logging.DEBUG:
+            logger.debug(msg)
+        case _:
             logger.info(msg)
 
 
@@ -511,6 +526,13 @@ def skip_precheck_loops():
     accelerator.get_state_dict = old_get_state
 
 
+def _get_checkpoint_dir(args, samples_seen) -> Path:
+    subdir = (
+        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
+    )
+    return Path(args.output_dir) / "hf_format" / subdir
+
+
 def save_hf_format_accelerate(
     args,
     model,
@@ -519,20 +541,15 @@ def save_hf_format_accelerate(
     samples_seen,
     is_lora=False,
 ):
-    # Build the subdirectory name
-    subdir = (
-        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
-    )
+    # Build the final output directory path
+    final_output_dir = _get_checkpoint_dir(args, samples_seen)
 
     log_rank_0(
-        f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
+        f"\033[93mSaving model in huggingface format at: {final_output_dir}\033[0m",
         to_print=True,
     )
     start = time.time()
 
-    # Build the final output directory path
-    final_output_dir = Path(args.output_dir) / "hf_format" / subdir
-
     output_dir = final_output_dir
 
     CONFIG_NAME = "config.json"
@@ -611,6 +628,48 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)
 
 
+def _get_checkpoint_dir_size(checkpoint_dir) -> int:
+    total = 0
+    for dirpath, _, filenames in os.walk(checkpoint_dir):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total += os.path.getsize(fp)
+    return total
+
+
+def check_disk_space_for_next_checkpoint(
+    model: Model, output_dir: Path, warn_steps_ahead: int = 3
+) -> None:
+    checkpoint_size = model.last_checkpoint_size
+    if checkpoint_size is None:
+        # No previous checkpoint size to estimate from, so do nothing.
+        return
+
+    def _mb_size(num_bytes):
+        return f"{num_bytes / 1024 / 1024:.2f} MB"
+
+    try:
+        stat = shutil.disk_usage(output_dir)
+        free_bytes = stat.free
+        needed_bytes = checkpoint_size * warn_steps_ahead
+
+        log_rank_0(
+            f"Disk space info: free={_mb_size(free_bytes)}, last_checkpoint_size={_mb_size(checkpoint_size)} (output_dir={output_dir})"
+        )
+        if free_bytes < needed_bytes:
+            log_rank_0(
+                f"Free disk space ({_mb_size(free_bytes)}) is less than the estimated size of the next {warn_steps_ahead} checkpoints ({_mb_size(needed_bytes)}). "
+                "The next checkpoint(s) may fail due to insufficient disk space.",
+                level=logging.WARNING,
+            )
+    except Exception as e:
+        log_rank_0(
+            f"Could not check disk space for the next checkpoint: {e}",
+            level=logging.ERROR,
+        )
+
+
 def save_checkpoint(
     args,
     accelerator: Accelerator,
@@ -622,6 +681,10 @@ def save_checkpoint(
     hf_format: bool = True,
     full_state: bool = False,
 ) -> None:
+    # Warn if disk space is low.
+    output_dir = Path(args.output_dir)
+    check_disk_space_for_next_checkpoint(model, output_dir, warn_steps_ahead=3)
+
     if hf_format:
         save_hf_format_accelerate(
             args=args,
@@ -641,6 +704,12 @@ def save_checkpoint(
             samples_seen=samples_seen,
         )
 
+    # Track last checkpoint size.
+    if hf_format:
+        checkpoint_dir = _get_checkpoint_dir(args, samples_seen)
+        if checkpoint_dir.exists():
+            model.last_checkpoint_size = _get_checkpoint_dir_size(checkpoint_dir)
+
 
 def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
     """
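
For reviewers who want to exercise the heuristic outside the training loop, below is a minimal, self-contained sketch of the same idea: sum the on-disk size of the last checkpoint directory, then warn when free space under the output directory cannot hold a few more checkpoints of that size. The helper names (`dir_size_bytes`, `warn_if_low_disk`) and the example path are hypothetical and not part of this patch; only the walk-and-sum size estimate and the `shutil.disk_usage` check mirror the patched `_get_checkpoint_dir_size` and `check_disk_space_for_next_checkpoint`.

```python
# Standalone sketch (hypothetical names) of the disk-space warning heuristic.
import os
import shutil
from pathlib import Path


def dir_size_bytes(path: Path) -> int:
    """Sum the sizes of all regular files under `path`."""
    total = 0
    for dirpath, _, filenames in os.walk(path):
        for name in filenames:
            fp = os.path.join(dirpath, name)
            if os.path.isfile(fp):
                total += os.path.getsize(fp)
    return total


def warn_if_low_disk(output_dir: Path, last_checkpoint_bytes: int, steps_ahead: int = 3) -> bool:
    """Return True when free space cannot hold `steps_ahead` more checkpoints of the last size."""
    free = shutil.disk_usage(output_dir).free
    return free < last_checkpoint_bytes * steps_ahead


if __name__ == "__main__":
    ckpt = Path("output/hf_format/samples_1024")  # hypothetical checkpoint directory
    if ckpt.exists() and warn_if_low_disk(ckpt.parent, dir_size_bytes(ckpt)):
        print("Warning: the next few checkpoints may not fit on disk.")
```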