src/instructlab/training/model.py (9 additions, 0 deletions)

@@ -55,6 +55,7 @@ def __init__(
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self._last_checkpoint_size: int | None = None
         bnb_config = None
         if lora_config and lora_config.r > 0 and lora_quant_bits == 4:
             # Third Party
@@ -76,6 +77,14 @@ def __init__(
         if flash_enabled:
             self.base_model_args["attn_implementation"] = "flash_attention_2"

+    @property
+    def last_checkpoint_size(self) -> int | None:
+        return self._last_checkpoint_size
+
+    @last_checkpoint_size.setter
+    def last_checkpoint_size(self, value: int):
+        self._last_checkpoint_size = value
+
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
         self.reconcile_tokenizer()
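
The new property pairs with its setter to cache the byte size of the most recently written checkpoint on the model object. A minimal sketch of the intended behavior (the `model` instance and the literal size are hypothetical, not from the diff):

    model.last_checkpoint_size         # None until a checkpoint has been measured
    model.last_checkpoint_size = 4096  # setter records the size in bytes
    assert model.last_checkpoint_size == 4096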
src/instructlab/training/utils.py (84 additions, 15 deletions)

@@ -11,6 +11,7 @@
 import logging
 import os
 import random
+import shutil
 import subprocess
 import sys
 import time
@@ -462,15 +463,29 @@ def get_caller(num_frames=1):
     return f"In {file_name}, line {line_number}"


-def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
+def log_rank_0(
+    msg, include_caller=False, rank=None, to_print=False, level=logging.INFO
+) -> None:
     if rank is None:
         rank = get_rank() if is_initialized() else 0
-    if rank <= 0:
-        if include_caller:
-            msg = f"{get_caller(num_frames=2)}: {msg}"
-        if to_print:
-            print(msg)
-        else:
-            logger.info(msg)
+    if rank > 0:
+        return
+
+    if include_caller:
+        msg = f"{get_caller(num_frames=2)}: {msg}"
+
+    if to_print:
+        print(msg)
+        return
+
+    match level:
+        case logging.WARNING:
+            logger.warning(msg)
+        case logging.ERROR:
+            logger.error(msg)
+        case logging.DEBUG:
+            logger.debug(msg)
+        case _:
+            logger.info(msg)


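The refactor replaces the nested rank-0 conditional with an early return and routes messages to the logger at the requested severity. A usage sketch (these calls are illustrative, not taken from the diff):

    log_rank_0("starting epoch")                            # default: logger.info
    log_rank_0("disk space is low", level=logging.WARNING)  # logger.warning
    log_rank_0("banner text", to_print=True)                # print() only, logger skipped
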
@@ -511,6 +526,13 @@ def skip_precheck_loops():
     accelerator.get_state_dict = old_get_state


+def _get_checkpoint_dir(args, samples_seen) -> Path:
+    subdir = (
+        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
+    )
+    return Path(args.output_dir) / "hf_format" / subdir
+
+
 def save_hf_format_accelerate(
     args,
     model,
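
This helper centralizes the checkpoint-path logic that save_hf_format_accelerate previously built inline, so the save path and the later size measurement agree. A quick sketch with hypothetical args:

    from types import SimpleNamespace

    args = SimpleNamespace(output_dir="/tmp/out", keep_last_checkpoint_only=False)
    _get_checkpoint_dir(args, samples_seen=1500)  # Path("/tmp/out/hf_format/samples_1500")

    args.keep_last_checkpoint_only = True
    _get_checkpoint_dir(args, samples_seen=1500)  # Path("/tmp/out/hf_format/last_epoch")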
@@ -519,20 +541,15 @@ def save_hf_format_accelerate(
     samples_seen,
     is_lora=False,
 ):
-    # Build the subdirectory name
-    subdir = (
-        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
-    )
+    # Build the final output directory path
+    final_output_dir = _get_checkpoint_dir(args, samples_seen)

     log_rank_0(
-        f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
+        f"\033[93mSaving model in huggingface format at: {final_output_dir}\033[0m",
         to_print=True,
     )
     start = time.time()

-    # Build the final output directory path
-    final_output_dir = Path(args.output_dir) / "hf_format" / subdir
-
     output_dir = final_output_dir

     CONFIG_NAME = "config.json"
@@ -611,6 +628,48 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)


+def _get_checkpoint_dir_size(checkpoint_dir) -> int:
+    total = 0
+    for dirpath, _, filenames in os.walk(checkpoint_dir):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total += os.path.getsize(fp)
+    return total
+
+
+def check_disk_space_for_next_checkpoint(
+    model: Model, output_dir: Path, warn_steps_ahead: int = 3
+) -> None:
+    checkpoint_size = model.last_checkpoint_size
+    if checkpoint_size is None:
+        # No previous checkpoint size to estimate, do nothing.
+        return
+
+    def _mb_size(num_bytes):
+        return f"{num_bytes / 1024 / 1024:.2f} MB"
+
+    try:
+        stat = shutil.disk_usage(output_dir)
+        free_bytes = stat.free
+        needed_bytes = checkpoint_size * warn_steps_ahead
+
+        log_rank_0(
+            f"Disk space info: free={_mb_size(free_bytes)}, last_checkpoint_size={_mb_size(checkpoint_size)} (output_dir={output_dir})"
+        )
+        if free_bytes < needed_bytes:
+            log_rank_0(
+                f"Estimated free disk space ({_mb_size(free_bytes)}) is less than the estimated size of the next {warn_steps_ahead} checkpoints ({_mb_size(needed_bytes)}). "
+                "The next checkpoint(s) may fail due to insufficient disk space.",
+                level=logging.WARNING,
+            )
+    except Exception as e:
+        log_rank_0(
+            f"Could not check disk space after checkpoint: {e}",
+            level=logging.ERROR,
+        )
+
+
 def save_checkpoint(
     args,
     accelerator: Accelerator,
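
_get_checkpoint_dir_size walks the checkpoint directory and sums file sizes; check_disk_space_for_next_checkpoint then compares free space on the output filesystem against last_checkpoint_size * warn_steps_ahead, and only warns, never blocks. A worked sketch of the threshold arithmetic with made-up numbers:

    import shutil

    checkpoint_size = 2 * 1024**3            # pretend the last checkpoint was 2 GiB
    needed = checkpoint_size * 3             # warn_steps_ahead=3 -> warn under 6 GiB
    free = shutil.disk_usage("/tmp").free    # free bytes on the target filesystem
    if free < needed:
        print("next checkpoints may not fit")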
@@ -622,6 +681,10 @@ def save_checkpoint(
     hf_format: bool = True,
     full_state: bool = False,
 ) -> None:
+    # Warn if disk space is low.
+    output_dir = Path(args.output_dir)
+    check_disk_space_for_next_checkpoint(model, output_dir, warn_steps_ahead=3)
+
     if hf_format:
         save_hf_format_accelerate(
             args=args,
@@ -641,6 +704,12 @@ def save_checkpoint(
             samples_seen=samples_seen,
         )

+    # Track last checkpoint size.
+    if hf_format:
+        checkpoint_dir = _get_checkpoint_dir(args, samples_seen)
+        if checkpoint_dir.exists():
+            model.last_checkpoint_size = _get_checkpoint_dir_size(checkpoint_dir)
+

 def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
     """