From ef980aa03f4428050a7260e961d512be256905a1 Mon Sep 17 00:00:00 2001
From: Sergey Plotnikov
Date: Mon, 9 Jun 2025 08:19:28 -0700
Subject: [PATCH 1/4] Enable fine tuning on HPU

---
 src/instructlab/training/accelerator.py       | 19 ++++++-
 src/instructlab/training/hpu_utils.py         | 49 ++++++++++++++++
 src/instructlab/training/main_ds.py           | 57 +++++++++++++++----
 src/instructlab/training/model.py             | 27 ++++++++-
 src/instructlab/training/multipack_sampler.py |  9 ++-
 src/instructlab/training/token_dataset.py     |  5 ++
 src/instructlab/training/utils.py             | 10 +++-
 7 files changed, 158 insertions(+), 18 deletions(-)
 create mode 100644 src/instructlab/training/hpu_utils.py

diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py
index b03c4a45..49796f16 100644
--- a/src/instructlab/training/accelerator.py
+++ b/src/instructlab/training/accelerator.py
@@ -3,7 +3,12 @@
 from typing import Callable, Optional
 
 # Third Party
-from accelerate import Accelerator as TransformersAccel
+from instructlab.training.hpu_utils import is_torch_hpu_available
+if is_torch_hpu_available():
+    from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
+else:
+    from accelerate import Accelerator as TransformersAccel
+
 from torch.utils.data import DataLoader
 from transformers import get_scheduler
 import torch
@@ -124,7 +129,11 @@ def get_fsdp_config(self):
         from functools import partial
 
         # Third Party
-        from accelerate.utils import FullyShardedDataParallelPlugin
+        if is_torch_hpu_available():
+            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
+        else:
+            from accelerate.utils import FullyShardedDataParallelPlugin
+
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -152,7 +161,7 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
-        fsdp_plugin = FullyShardedDataParallelPlugin(
+        fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
             auto_wrap_policy=wrap_policy,
             limit_all_gathers=True,
             backward_prefetch=prefetch_policy,
@@ -160,6 +169,10 @@ def get_fsdp_config(self):
             cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
         )
 
+        if is_torch_hpu_available():
+            fsdp_plugin.use_orig_params = True
+            fsdp_plugin.sync_module_states = True
+
         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
         if self.model.lora_config is not None:
diff --git a/src/instructlab/training/hpu_utils.py b/src/instructlab/training/hpu_utils.py
new file mode 100644
index 00000000..15a951b7
--- /dev/null
+++ b/src/instructlab/training/hpu_utils.py
@@ -0,0 +1,49 @@
+import torch
+from functools import lru_cache
+
+
+@lru_cache(maxsize=None)
+def is_torch_hpu_available() -> bool:
+    try:
+        import habana_frameworks.torch.core  # noqa: F401
+    except ImportError:
+        return False
+    return True
+
+
+def simple_bucket(length):
+    """
+    This bucket algorithm merely relies on the given number instead of based on
+    slicing the known (min, max) range for several reasons:
+        1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
+           (min, max) sequence length of each rank will be much smaller than the
+           (min, max) sequence length of the dataset. Bucketing on the
+           (min, max) sequence length of the dataset is not practical
+        2) The (min, max) sequence length of a given rank is unknown until
+           finishing 1 epoch since the packing is done on the fly
+        3) Due to the shuffling, the (min, max) sequence length of a given rank
+           may vary between ranks. Once the (min, max) sequence length of a
+           given rank changes, the bucketing also needs adjustment
+
+    This bucket algorithm is based on the most significant set bit of the input number.
+    It first checks what's the most significant set bit, assuming it's bit "S",
+    and then slices the range [2 ** S, 2 ** (S+1)] into buckets of the same size.
+    By default the range is divided into 16 buckets, so the bucket size will be
+    2 ** (S - 4).
+    For example, 0b10001 will be padded to 0b10010.
+    This approach can limit the overhead of bucketing (at most 1/16 of the input
+    number) and also prevent recompilation due to an overly small bucket size.
+    """
+    l = length
+    msb = 0
+    while l > 0:
+        msb += 1
+        l = l // 2
+
+    align = (1 << (msb - 4)) if msb >= 4 else 1
+
+    return (length + align - 1) // align * align
+
+
+def bucket(length):
+    return simple_bucket(length)
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 4ca638d0..ece61086 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -33,6 +33,14 @@
             UserWarning,
         )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -139,10 +147,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -179,8 +196,14 @@
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -202,8 +225,8 @@
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
@@ -236,7 +259,10 @@
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -314,17 +340,24 @@ def main(args):
     args.model_type = model_conf.model_type
 
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
 
     timeout = _get_collective_timeout()
-    if timeout is not None:
-        torch.distributed.init_process_group(timeout=timeout)
-    else:
-        torch.distributed.init_process_group()
+    backend = "hccl" if is_torch_hpu_available() else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)
 
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index 8002d2ba..6d9fe146 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -34,6 +34,8 @@
 import torch
 
 # First Party
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
 from instructlab.training.config import (  # Adjust this import if needed
     DistributedBackend,
     Optimizer,
@@ -78,6 +80,14 @@ def __init__(
 
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
+
+        if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
+            torch._dynamo.config.cache_size_limit = 10 * 1000
+            torch._dynamo.config.accumulated_cache_size_limit = 20 * 1000
+            self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
+            for layer in self.model.model.layers:
+                layer.compile(backend="hpu_backend", dynamic=False)
+
         self.reconcile_tokenizer()
         if self.lora_config:
             self.model = self.prepare_peft_model(
@@ -264,7 +274,11 @@ def _is_causal_lm_model(self) -> bool:
             bool: True if the model is a causal language model, False otherwise.
""" # Third Party - return "ForCausalLM" in self.model.__class__.__name__ + if not is_torch_hpu_available(): + class_name = self.model.__class__.__name__ + else: + class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__ + return "ForCausalLM" in class_name def reconcile_tokenizer(self): if len(self.tokenizer) > self.model.config.vocab_size: @@ -320,6 +334,17 @@ def reconcile_tokenizer(self): ): self.model.config.eos_token_id = self.tokenizer.eos_token_id + if is_torch_hpu_available(): + model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model + class_name = model.__class__.__name__ + + replace_no_split_modules = { + 'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',] + } + + if class_name in replace_no_split_modules: + model._no_split_modules = replace_no_split_modules[class_name] + if not self._is_causal_lm_model(): raise ValueError( f"Model must be a causal language model, got {type(self.model)}" diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 6b9a4941..a48f8e12 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -34,6 +34,8 @@ import torch import torch.distributed as dist +from instructlab.training.hpu_utils import is_torch_hpu_available, bucket + def find_max_pack_len_with_padding( dataset, @@ -68,9 +70,14 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu): The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches. """ + lengths=dataset.get_lengths() + if is_torch_hpu_available(): + bucket_v = np.vectorize(bucket) + lengths = bucket_v(lengths) + sampler = MultipackDistributedBatchSampler( batch_max_length=num_tokens_per_gpu, - lengths=dataset.get_lengths(), + lengths=lengths, num_replicas=torch.distributed.get_world_size(), rank=torch.distributed.get_rank(), seed=seed, diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py index da50be60..fbc1e904 100644 --- a/src/instructlab/training/token_dataset.py +++ b/src/instructlab/training/token_dataset.py @@ -13,6 +13,7 @@ from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler from instructlab.training.utils import log_rank_0, make_collate_fn +from instructlab.training.hpu_utils import is_torch_hpu_available, bucket class TokenDataset(Dataset): def __init__(self, data_path): @@ -109,6 +110,10 @@ def setup_dataloader( lengths = dataset.get_lengths() if sampler == "multipack": + if is_torch_hpu_available(): + bucket_v = np.vectorize(bucket) + lengths = bucket_v(lengths) + sampler = MultipackDistributedBatchSampler( batch_max_length=packing_max_batch_len, lengths=lengths, diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 9472884e..0f7888dd 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -51,6 +51,7 @@ TrainingArgs, ) from instructlab.training.model import Model +from instructlab.training.hpu_utils import is_torch_hpu_available, bucket logger = logging.getLogger("instructlab.training") @@ -275,6 +276,9 @@ def pad_collate_fn(batch): lens = np.array([len(item["input_ids"]) for item in batch]) max_len = max(lens) + if is_torch_hpu_available(): + max_len = bucket(max_len) + input_ids = torch.stack( [ F.pad( 
@@ -386,6 +390,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **(_deprecated_arguments if is_torch_hpu_available() else {}),
     )
 
     return_dict = isinstance(output, dict)
@@ -794,7 +799,10 @@ def set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
+        if is_torch_hpu_available():
+            torch.hpu.manual_seed_all(seed)
+        else:
+            torch.cuda.manual_seed_all(seed)
 
 
 def save_checkpoint(

From 72037a3b8a8b05a373cd22768caf4db9d05d543f Mon Sep 17 00:00:00 2001
From: Sergey Plotnikov
Date: Fri, 23 May 2025 09:27:27 -0700
Subject: [PATCH 2/4] Enable fine tuning on HPU

---
 src/instructlab/training/main_ds.py           | 501 +++++++++++++-----
 src/instructlab/training/multipack_sampler.py |   5 +
 src/instructlab/training/setup_accelerator.py | 143 +++++
 src/instructlab/training/utils.py             | 324 ++++++++++-
 4 files changed, 833 insertions(+), 140 deletions(-)
 create mode 100644 src/instructlab/training/setup_accelerator.py

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index ece61086..036ba335 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -1,14 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard
+from copy import deepcopy
+from pathlib import Path
 import argparse
 import datetime
+import functools
 import logging
+import math
 import os
+import re
 import subprocess
 import time
 import warnings
 
+# Third Party
+from accelerate import Accelerator
+
 try:
     # Third Party
     from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -24,8 +32,10 @@
 try:
     # Third Party
     from deepspeed.ops.adam import FusedAdam
+    from deepspeed.runtime.zero.utils import ZeRORuntimeException
 except ImportError:
     FusedAdam = None
+    ZeRORuntimeException = None
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if __name__ == "__main__" and (not local_rank or local_rank == 0):
         warnings.warn(
@@ -42,57 +52,349 @@
     adapt_transformers_to_gaudi()
 
 # Third Party
+from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
+from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    PreTrainedTokenizer,
+    get_scheduler,
+)
 import torch
 import torch.distributed
 
 # First Party
 from instructlab.training import config
-from instructlab.training.accelerator import Accelerator
-from instructlab.training.config import (
-    DistributedBackend,
-    ModelTypes,
-    TorchrunArgs,
-    TrainingArgs,
-)
 
 # pylint: disable=no-name-in-module
-from instructlab.training.logger import (
-    propagate_package_logs,
-    setup_metric_logger,
-    setup_root_logger,
-)
-from instructlab.training.model import (
-    CausalLMModel,
-    DolomiteModel,
-    LigerModel,
-    Model,
-    setup_optimizer,
-)
+from instructlab.training.config import DistributedBackend, TorchrunArgs, TrainingArgs
+from instructlab.training.logger import setup_metric_logger, setup_root_logger
 from instructlab.training.multipack_sampler import (
     find_packing_max_batch_len_and_grad_accum,
 )
+from instructlab.training.setup_accelerator import setup_accelerator
 from instructlab.training.token_dataset import setup_dataloader, setup_dataset
 from instructlab.training.tokenizer_utils import setup_tokenizer
 from instructlab.training.utils import (
     StreamablePopen,
+    add_noisy_embeddings,
+    apply_gradient_checkpointing,
+    check_flash_attn_enabled,
     check_valid_train_args,
+    convert_loss_to_reduce_sum,
create_lora_config, + ensure_loadable_dolomite_checkpoint, load_latest_full_state, + prepare_peft_model, + prepare_universal_checkpoint_from_latest, save_checkpoint, save_hf_format_accelerate, set_random_seed, ) import instructlab.training.data_process as dp -logger = logging.getLogger(__name__) +logger = logging.getLogger("instructlab.training") + + +def setup_optimizer(args, model): + if args.distributed_training_framework == DistributedBackend.FSDP.value: + logger.info("Using AdamW optimizer") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(0.9, 0.95), + weight_decay=0.0, + ) + elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value: + # need to use this only when the CPU offload optimizer is enabled + if args.cpu_offload_optimizer: + logger.info("!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!") + optimizer = DeepSpeedCPUAdam( + model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95) + ) + else: + logger.info("Using FusedAdam optimizer") + optimizer = FusedAdam( + model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95) + ) + else: + raise ValueError( + f"Sharding framework {args.distributed_training_framework} is not supported." + ) + return optimizer + + +def setup_model( + args, tokenizer: PreTrainedTokenizer, train_loader, grad_accum, flash_enabled +): + bnb_config = None + if args.lora_r > 0 and args.lora_quant_bits == 4: + # Third Party + from transformers import BitsAndBytesConfig + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.float16, # if not set will throw a warning about slow speeds when training + ) + + base_model_args = { + "pretrained_model_name_or_path": args.model_name_or_path, + "torch_dtype": torch.bfloat16, + "quantization_config": bnb_config, + } + if flash_enabled: + base_model_args["attn_implementation"] = "flash_attention_2" + + if args.use_dolomite: + with ensure_loadable_dolomite_checkpoint( + args.model_name_or_path, args.output_dir + ) as path: + base_model_args["pretrained_model_name_or_path"] = path + base_model_args["use_padding_free_transformer"] = True + model = GPTDolomiteForCausalLM.from_pretrained( + **base_model_args, + ) + elif args.use_liger: + # TODO(osilkin): we duplicate some checks here because someone may run this script through + # torchrun directly and not `run_training`. To fix this, we should eventually move everything + # to using `torch.multiprocessing` and simplify the CLI. + if args.lora_r > 0: + raise ValueError( + "Using LoRA and Liger kernels is not supported. Please use either LoRA or Liger kernels, but not both." + ) + try: + # Third Party + from liger_kernel.transformers import AutoLigerKernelForCausalLM + except ImportError as e: + raise ValueError( + "Liger kernels are not installed. Please install Liger kernels using the following command: pip install liger-kernel" + ) from e + + # NOTE: (jkunstle) we disable fused_linear_cross_entropy, even though it's a default for most of the models with LK support, + # because reduce_sum_loss requires the logits, and fused_linear_cross_entropy explicitly skips materializing them for + # performance. 
+ model = AutoLigerKernelForCausalLM.from_pretrained( + **base_model_args, cross_entropy=True, fused_linear_cross_entropy=False + ) + else: + model = AutoModelForCausalLM.from_pretrained(**base_model_args) + + if is_torch_hpu_available(): + torch._dynamo.config.cache_size_limit = int(1e4) + torch._dynamo.config.accumulated_cache_size_limit = int(2e4) + model = torch.compile(model, backend="hpu_backend", dynamic=False) + for layer in model.model.layers: + layer.compile(backend="hpu_backend", dynamic=False) + + # store the base model args so we can recall them later if saving a LoRA model + args.base_model_args = base_model_args + + if len(tokenizer) > model.config.vocab_size: + logger.warning( + "tokenizer has %d tokens but model has %d vocab size", + len(tokenizer), + model.config.vocab_size, + ) + model.resize_token_embeddings( + int(8 * math.ceil(len(tokenizer) / 8.0)) + ) # make the vocab size multiple of 8 for sharding the embedding layer. + + # Fix any discrepancy between model and tokenizer + if ( + model.config.pad_token_id is not None + and tokenizer.pad_token_id is not None + and model.config.pad_token_id != tokenizer.pad_token_id + ): + logger.warning( + "There is a mismatch between pad token id of model (%d) and tokenizer(%d). Fixing model pad token id to be same as tokenizer's pad token id", + model.config.pad_token_id, + tokenizer.pad_token_id, + ) + model.config.pad_token_id = tokenizer.pad_token_id + if ( + model.config.bos_token_id is not None + and tokenizer.bos_token_id is not None + and model.config.bos_token_id != tokenizer.bos_token_id + ): + logger.warning( + "There is a mismatch between bos token id of model(%d) and tokenizer(%d). Fixing model bos token id to be same as tokenizer's bos token id", + model.config.bos_token_id, + tokenizer.bos_token_id, + ) + model.config.bos_token_id = tokenizer.bos_token_id + if ( + model.config.eos_token_id is not None + and tokenizer.eos_token_id + and model.config.eos_token_id != tokenizer.eos_token_id + ): + logger.warning( + "There is a mismatch between eos token id of model(%d) and tokenizer(%d). Fixing model eos token id to be same as tokenizer's eos token id", + model.config.eos_token_id, + tokenizer.eos_token_id, + ) + model.config.eos_token_id = tokenizer.eos_token_id + + if not is_torch_hpu_available(): + class_name = model.__class__.__name__ + else: + class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__ + + replace_no_split_modules = { + 'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',] + } + + if class_name in replace_no_split_modules: + if model.__class__.__name__ == 'OptimizedModule': + model._orig_mod._no_split_modules = replace_no_split_modules[class_name] + else: + model._no_split_modules = replace_no_split_modules[class_name] + + if "ForCausalLM" not in class_name: + raise ValueError( + f"Model class name: {model.__class__.__name__} is not supported." 
+ ) + + # ensure the model has any tokens which were added to the tokenizer + if tokenizer.pad_token_id is not None and model.config.pad_token_id is None: + model.config.pad_token_id = tokenizer.pad_token_id + if tokenizer.bos_token_id is not None and model.config.bos_token_id is None: + model.config.bos_token_id = tokenizer.bos_token_id + if tokenizer.eos_token_id is not None and model.config.eos_token_id is None: + model.config.eos_token_id = tokenizer.eos_token_id + + model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite) + model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha) + + # handling of gradient checkpointing + # it is handled differently for lora and full + # - with the exception of granite, which handles it + # in the later stanza + if args.lora_r > 0: + lora_config = create_lora_config(model, args) + model = prepare_peft_model( + model, + lora_config, + args.distributed_training_framework, + gradient_checkpointing=not args.use_dolomite, + ) + args.lora_config = lora_config + elif not args.use_dolomite: + model.gradient_checkpointing_enable() + + # granite gradient checkpointing is handled uniformly + # for both lora and full here + if args.use_dolomite: + block_name = model._no_split_modules[0] + apply_gradient_checkpointing( + model, + block_name=block_name, + use_reentrant=True, # this should be the HF default mode + ) + + if args.lora_r > 0: + + def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + accelerator = setup_accelerator(args, model, grad_accum) + + if is_torch_hpu_available(): + accelerator.state.fsdp_plugin.use_orig_params=True + accelerator.state.fsdp_plugin.sync_module_states=True + + if args.distributed_training_framework == DistributedBackend.FSDP.value: + model = accelerator.prepare(model) + optimizer = setup_optimizer(args, model) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.num_epochs * len(train_loader) // grad_accum, + ) + model, optimizer, _, lr_scheduler = accelerator.prepare( + model, + optimizer, + deepcopy(train_loader), + lr_scheduler, + ) + # Necessary so that Accelerate does not step once per GPU + # see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69 + lr_scheduler.split_batches = True + return model, lr_scheduler, optimizer, accelerator + + +# this function is to check if the checkpoint provided can be resumed +def maybe_resume_training(args, model): + local_rank = int(os.environ["LOCAL_RANK"]) + + # DS's loading function will not raise if fails to reload a checkpoint + # - if lora is used, then the checkpoints will only be for the adapters + # so we need to disable load_module_strict + # - load checkpoint will find the latest checkpoint + # - it will also load the optimizer and scheduler states by default + load_module_strict = args.lora_r == 0 # can only be true if lora is not used + output_dir = Path(args.output_dir) / "ds_native" + + try: + # attempt to load a regular checkpoint first + model.load_checkpoint(output_dir, load_module_strict=load_module_strict) + except ZeRORuntimeException as e: + if str(e).startswith("The checkpoint being loaded used a DP world size of"): + # if it fails with the above exception, then a universal + # checkpoint is required + + # prepare the universal 
checkpoint + # - by reading 'latest' to get the resumable checkpoint + prepare_universal_checkpoint_from_latest(output_dir) + + # need to do this to trigger the universal checkpoint + # loading + model._config.load_universal_checkpoint = True + + # then attempt to load again + model.load_checkpoint(output_dir, load_module_strict=load_module_strict) + + # reset to regular checkpoint loading + model._config.load_universal_checkpoint = False + else: + raise e # reraise + + # do this to figure out the last_step + latest_file = output_dir / "latest" + try: + with open(latest_file) as f: + # there is some assumption here that the ds_native + # checkpoints are tagged as _(samples_seen) + step_folder = f.read() + (samples_seen,) = re.match("\w+_(\d+)", step_folder).groups() + samples_seen = int(samples_seen) + + last_step = samples_seen // args.effective_batch_size + args.__dict__["last_step"] = last_step + if local_rank == 0: + logger.info("Found checkpoint at %d, resuming training", last_step) + except FileNotFoundError: + pass + + # we will update the start step here + return model def train( args, - model: Model, - optimizer: torch.optim.Optimizer, + model, + optimizer, + lr_scheduler, accelerator: Accelerator, + tokenizer: PreTrainedTokenizer, + train_loader: DataLoader, + grad_accum, ): model.train() @@ -103,15 +405,15 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") - batch_size = args.effective_batch_size // accelerator.grad_accum + batch_size = args.effective_batch_size // grad_accum samples_seen = 0 if hasattr(args, "samples_seen"): logger.info("Updating 'samples_seen' %d", args.samples_seen) samples_seen = args.samples_seen - if accelerator.save_samples > 0: - accelerator.save_samples = (accelerator.save_samples // batch_size) * batch_size + if args.save_samples > 0: + args.save_samples = (args.save_samples // batch_size) * batch_size logger.info("Number of samples per save: %d", args.save_samples) if args.save_samples_ds is not None: @@ -121,18 +423,18 @@ def train( global_grad_norm = None for epoch in range(args.current_epoch, args.num_epochs): if args.sampler in ("multipack"): - accelerator.train_loader.batch_sampler.set_epoch(epoch) + train_loader.batch_sampler.set_epoch(epoch) elif args.sampler in ("distributed"): - accelerator.train_loader.sampler.set_epoch(epoch) + train_loader.sampler.set_epoch(epoch) else: raise NotADirectoryError - num_epoch_steps = len(accelerator.train_loader) + num_epoch_steps = len(train_loader) if local_rank == 0: inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}") # blast through the batches in the train loader up to the last step within the epoch. 
- for batch in accelerator.train_loader: + for batch in train_loader: if global_step <= args.last_step: # in the case of resuming, last_step > 0 global_step += 1 @@ -186,16 +488,16 @@ def train( ) accelerator.backward(loss) - if global_step % accelerator.grad_accum == 0: + if global_step % grad_accum == 0: global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - accelerator.lr_scheduler.step() + lr_scheduler.step() optimizer.zero_grad() if local_rank == 0: elapsed_time = time.time() - start overall_throughput = args.samples_per_gpu * world_size / elapsed_time - current_lr = accelerator.lr_scheduler.get_last_lr()[0] + current_lr = lr_scheduler.get_last_lr()[0] if is_torch_hpu_available(): mem_allocated = torch.hpu.memory_allocated() / (1024**3) @@ -233,7 +535,7 @@ def train( "total_loss": float(log_loss / num_loss_counted_tokens), "samples_seen": samples_seen, "gradnorm": global_grad_norm, - "total_samples": len(accelerator.train_loader.dataset), + "total_samples": len(train_loader.dataset), "num_epoch_steps": num_epoch_steps, # "weight_norm": weight_norm, }, @@ -248,14 +550,22 @@ def train( args=args, accelerator=accelerator, model=model, - tokenizer=model.tokenizer, + tokenizer=tokenizer, samples_seen=samples_seen, is_lora=bool(args.lora_r), hf_format=True, ) - base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) - torch.distributed.barrier() + # if ( + # args.save_samples_ds is not None + # and global_step * batch_size % args.save_samples_ds == 0 + # ): + # save_model_ds_native( + # args, + # model, + # tokenizer, + # global_step * args.samples_per_gpu * world_size, + # ) global_step += 1 if local_rank == 0: inner_pb.update(1) @@ -269,21 +579,19 @@ def train( args=args, accelerator=accelerator, model=model, - tokenizer=model.tokenizer, + tokenizer=tokenizer, samples_seen=samples_seen, is_lora=bool(args.lora_r), full_state=args.accelerate_full_state_at_epoch, hf_format=True, epoch=epoch, ) - base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) - torch.distributed.barrier() if args.save_last: save_hf_format_accelerate( args, model, - model.tokenizer, + tokenizer, accelerator, samples_seen, is_lora=bool(args.lora_r), @@ -348,8 +656,11 @@ def main(args): args.local_rank = int(os.environ["LOCAL_RANK"]) timeout = _get_collective_timeout() - backend = "hccl" if is_torch_hpu_available() else None - torch.distributed.init_process_group(backend=backend, timeout=timeout) + init = functools.partial(torch.distributed.init_process_group, "hccl" if is_torch_hpu_available() else "nccl") + if timeout is not None: + init(timeout=timeout) + else: + init() args.global_rank = torch.distributed.get_rank() @@ -361,9 +672,7 @@ def main(args): torch.distributed.all_reduce(tensor) torch.distributed.barrier() - flash_enabled = Model.check_flash_attn_enabled( - args.disable_flash_attn, args.use_dolomite - ) + flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite) dataset = setup_dataset( args.data_path, @@ -371,46 +680,6 @@ def main(args): mock_len=args.mock_len, ) - # This model class wraps the various AutoModel classes we support - # based on model_type, and model_path -> choose auto_model - lora_config = None - - if args.lora_r > 0: - lora_config = Model.create_lora_config( - lora_target_modules=args.lora_target_modules, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - lora_r=args.lora_r, - ) - - # Create model based on type - model_class_map = { - ModelTypes.LIGER: LigerModel, - 
ModelTypes.DOLOMITE: DolomiteModel, - ModelTypes.CAUSALLM: CausalLMModel, - } - - # Convert string to ModelTypes enum with fallback - try: - model_type = ModelTypes(args.model_class) - except (ValueError, AttributeError): - model_type = ModelTypes.CAUSALLM - - # Get the model class with default fallback - model_class = model_class_map.get(model_type, CausalLMModel) - m = model_class( - model_path=args.model_name_or_path, - output_dir=args.output_dir, - lora_config=lora_config, - distributed_framework=DistributedBackend(args.distributed_training_framework), - tokenizer=tokenizer, - flash_enabled=flash_enabled, - noise_alpha=args.NEFTune_alpha, - lora_quant_bits=args.lora_quant_bits, - ) - - args.base_model_args = m.base_model_args - try: packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum( num_gpus=torch.distributed.get_world_size(), @@ -484,45 +753,22 @@ def main(args): }, extra={"hparams": True}, ) - # accelerator does not need optimizer to init, in fact, the optimizer needs to be initialized AFTER the Accelerator - accelerator = Accelerator( - model=m, - samples_per_gpu=args.samples_per_gpu, - grad_accum=grad_accum, - train_loader=train_loader, - distributed_framework=DistributedBackend(args.distributed_training_framework), - fsdp_sharding_strategy=args.fsdp_sharding_strategy, - deepspeed_cpu_offload_optimizer=args.cpu_offload_optimizer, - deepspeed_cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, - deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, - fsdp_cpu_offload_params=args.cpu_offload_params_fsdp, - save_samples=args.save_samples, - ) - # optimizer needs model that has been prepared by accelerator - # and then accelerator needs to be prepared AGAIN once optimizer is initialized - optimizer = setup_optimizer( - model=m, - cpu_offload=args.cpu_offload_optimizer, - name=None, # choose based on backend - learning_rate=args.learning_rate, - ) - accelerator.prepare_with_optimizer( - optimizer=optimizer, - lr_scheduler=args.lr_scheduler, - num_epochs=args.num_epochs, - num_warmup_steps=args.num_warmup_steps, + + model, lr_scheduler, optimizer, accelerator = setup_model( + args, tokenizer, train_loader, grad_accum, flash_enabled ) - # TODO: make this work more seamlessly - optimizer = accelerator.optimizer - m = accelerator.model load_latest_full_state(args=args, accelerator=accelerator) train( args, - model=m, - optimizer=optimizer, - accelerator=accelerator, + model, + optimizer, + lr_scheduler, + accelerator, + tokenizer, + train_loader, + grad_accum, ) torch.distributed.barrier() @@ -534,15 +780,6 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: """ Wrapper around the main training job that calls torchrun. 
""" - # Set up logging first before any processing - # Enable package logging propagation before setting up loggers - propagate_package_logs(True) - setup_root_logger(train_args.log_level) - setup_metric_logger("async", None, train_args.ckpt_output_dir) - - logger = logging.getLogger("instructlab.training") - logger.info("Starting training setup...") - check_valid_train_args(train_args) # switch out generic tmpl for legacy tmpl if requested @@ -584,7 +821,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: f"--learning_rate={train_args.learning_rate}", f"--num_warmup_steps={train_args.warmup_steps}", f"--save_samples={train_args.save_samples}", - f"--log_level={train_args.log_level}", + f"--log_level=INFO", f"--max_batch_len={train_args.max_batch_len}", f"--seed={train_args.random_seed}", ] @@ -730,12 +967,6 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # Maybe switch out from argparse to something smarter parser = argparse.ArgumentParser() parser.add_argument("--model_name_or_path", type=str) - parser.add_argument( - "--model-class", - type=str, - default=ModelTypes.CAUSALLM.value, - help=f"valid model classes are {[x.value for x in ModelTypes]}.", - ) parser.add_argument("--data_path", type=str) parser.add_argument("--output_dir", type=str) parser.add_argument("--num_epochs", type=int, default=1) diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index a48f8e12..28c87334 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -402,6 +402,11 @@ def generate_batches(self, set_stats=False): ) lengths = self.lengths[indices] + + if is_torch_hpu_available(): + bucket_v = np.vectorize(bucket) + lengths = bucket_v(lengths) + lengths_cumsum = np.cumsum(lengths) batches, total_used, total_slots = allocate( diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py new file mode 100644 index 00000000..58a64a9d --- /dev/null +++ b/src/instructlab/training/setup_accelerator.py @@ -0,0 +1,143 @@ +# Standard +from functools import partial + +# Third Party +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers import PreTrainedModel +import torch + +# First Party +from instructlab.training.config import DeepSpeedOptions +from instructlab.training.utils import get_module_class_from_name, patch_target_module +from instructlab.training.hpu_utils import is_torch_hpu_available + +if is_torch_hpu_available(): + from optimum.habana.accelerate import GaudiAccelerator +else: + from accelerate import Accelerator + + +def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions): + # Third Party + from accelerate.utils import DeepSpeedPlugin + + ds_config = { + "train_batch_size": samples_per_gpu * world_size * grad_accum, + "gradient_accumulation_steps": grad_accum, + "train_micro_batch_size_per_gpu": samples_per_gpu, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + # this option is only supported with DeepSpeed ZeRO stage 3 + "offload_param": {"device": "none"}, + "offload_optimizer": {"device": "none"}, + }, + "bf16": {"enabled": True}, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + if opts.cpu_offload_optimizer: + # this only works when the 
cpu offload optimizer is enabled + ds_config["zero_optimization"]["offload_optimizer"] = { + # CPU offloading is the only option available in ZeRO stage 2 + "device": "cpu", + "pin_memory": opts.cpu_offload_optimizer_pin_memory, + "ratio": opts.cpu_offload_optimizer_ratio, + } + ds_plugin = DeepSpeedPlugin( + hf_ds_config=ds_config, + ) + return ds_plugin + + +def get_fsdp_config(args, model: PreTrainedModel): + # Third Party + if is_torch_hpu_available(): + from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin + else: + from accelerate.utils import FullyShardedDataParallelPlugin + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + + is_lora = args.lora_r > 0 + block_name = model._no_split_modules[0] + + wrap_policy = None + if is_lora > 0: + wrap_policy = fsdp_auto_wrap_policy(model) + else: + wrap_policy = partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + get_module_class_from_name(model, block_name), + }, + ) + + # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA + # We should have this be configurable in the future. + prefetch_policy = ( + BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE + ) + fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)( + auto_wrap_policy=wrap_policy, + limit_all_gathers=True, + backward_prefetch=prefetch_policy, + sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy], + cpu_offload=CPUOffload(args.cpu_offload_params_fsdp), + ) + + # `use_orig_params` must be disabled when using LoRA and FSDP together + # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts + if args.lora_r > 0: + fsdp_plugin.use_orig_params = False + + return fsdp_plugin + + +def setup_accelerator(args, model: PreTrainedModel, grad_accum): + if args.distributed_training_framework == "deepspeed": + try: + # Third Party + from deepspeed import DeepSpeedEngine + except ImportError as exc: + raise ImportError( + "DeepSpeed selected as distributed framework, but not installed" + ) from exc + + # patch deepspeed to work with quantized models. 
+ if args.lora_quant_bits is not None: + patch_target_module( + "deepspeed.DeepSpeedEngine", + partial(DeepSpeedEngine, dont_change_device=True), + ) + + accel_args = { + "deepspeed_plugin": get_ds_plugin( + world_size=torch.distributed.get_world_size(), + samples_per_gpu=args.samples_per_gpu, + grad_accum=grad_accum, + opts=DeepSpeedOptions( + cpu_offload_optimizer=args.cpu_offload_optimizer, + cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, + cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, + save_samples=args.save_samples_ds, + ), + ), + } + elif args.distributed_training_framework == "fsdp": + accel_args = { + "fsdp_plugin": get_fsdp_config(args, model), + "mixed_precision": "bf16", + } + else: + raise ValueError( + f"Unknown sharding framework: {args.distributed_training_framework}" + ) + accelerator = (GaudiAccelerator if is_torch_hpu_available() else Accelerator)( + **accel_args, + ) + accelerator.even_batches = False + return accelerator diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 0f7888dd..db37de87 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -39,7 +39,7 @@ from torch.distributed.fsdp import FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType -from transformers import AutoModelForCausalLM, PreTrainedTokenizer +from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer import numpy as np import torch import torch.nn.functional as F @@ -50,7 +50,6 @@ QuantizeDataType, TrainingArgs, ) -from instructlab.training.model import Model from instructlab.training.hpu_utils import is_torch_hpu_available, bucket logger = logging.getLogger("instructlab.training") @@ -103,9 +102,7 @@ def check_valid_train_args(train_args: TrainingArgs): "Quantization is not supported when training LoRA models with FSDP. For quantized LoRA training, please switch to DeepSpeed." ) - if Model.check_flash_attn_enabled( - train_args.disable_flash_attn, train_args.use_dolomite - ): + if check_flash_attn_enabled(train_args.disable_flash_attn, train_args.use_dolomite): # verify that the flash_attn package is actually installed try: # pylint: disable=unused-import @@ -212,6 +209,37 @@ def listen(self): break +def supports_flash_attention(device_id=0): + if is_torch_hpu_available(): + return False + + """Check if a GPU supports FlashAttention.""" + major, minor = torch.cuda.get_device_capability(device_id) + # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0) + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0] + is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942") + return is_sm8x or is_sm90 or is_compat_amd + + +def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool: + if not disable_flash_attn: + if supports_flash_attention(): + flash_enabled = True + else: + raise RuntimeError( + "ERROR: Trying to use Flash Attention on unsupported hardware. Please set disable_flash_attn to True." 
+ ) + elif use_dolomite: + raise RuntimeError( + "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported" + ) + else: + flash_enabled = False + return flash_enabled + + def make_collate_fn( pad_token_id, use_dolomite=False, flash_enabled=True, max_batch_len=60000 ): @@ -453,6 +481,47 @@ def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool: return False +def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig": + # if lora + # Third Party + from peft import LoraConfig + + # ensure we select only the modules that exist in the model + proj_layers = get_projection_layer_names(model) + if not args.lora_target_modules: + warnings.warn( + "lora_target_modules was not specified, defaulting to all of the model's projection modules" + ) + if not proj_layers: + raise RuntimeError("could not find any projection layers in the model") + args.__dict__["lora_target_modules"] = proj_layers + else: + # when the user specifies the module, we should verify that they align with what's in the model + lora_target_modules_set = set(args.lora_target_modules) + diff = lora_target_modules_set - set(proj_layers) + layers_to_target = lora_target_modules_set - diff + if len(diff) == len(args.lora_target_modules): + raise ValueError( + f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically." + ) + if diff: + warnings.warn( + "the following modules were targeted for LoRA but are not present in the model: %s. Applying LoRA only to %s modules.", + list(diff), + list(layers_to_target), + ) + args.__dict__["lora_target_modules"] = list(layers_to_target) + + return LoraConfig( + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + r=args.lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=args.lora_target_modules, + ) + + def save_fsdp_lora_model( args: Namespace, model: FSDP, @@ -517,6 +586,201 @@ def save_fsdp_lora_model( dist.barrier() +def prepare_peft_model( + model: PreTrainedModel, + peft_config, + distributed_backend: str, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": True}, + mixed_precision="bf16", +): + # will guard this + # Third Party + from peft import ( + LoraModel, + PeftConfig, + PeftModel, + get_peft_model, + prepare_model_for_kbit_training, + ) + from trl.trainer.utils import peft_module_casting_to_bf16 + + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object, " + f"and you passed a {type(peft_config)}." 
+ ) + + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr( + model, "is_loaded_in_4bit", False + ): + preprare_model_kwargs = { + "use_gradient_checkpointing": gradient_checkpointing + } + + # if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = ( + gradient_checkpointing_kwargs + ) + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + elif gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad + ) + + if distributed_backend == DistributedBackend.FSDP.value: + # FSDP doesn't like `get_peft_model` as it leads to dtype mismatches + model = LoraModel(model, peft_config, "default") + else: + model = get_peft_model(model, peft_config) + if mixed_precision == "bf16" and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + + return model + + +def prepare_universal_checkpoint_from_latest(output_dir): + """Populate the universal checkpoint in output_dir/step_folder + - 1. read output_dir/latest to get step_folder + - 2. populate tmp dir in output_dir/step_folder/tmp + - 3. populate zero checkpoints in output_dir/step_folder/zero + - 4. create output_dir/latest_universal + + Items 1, 2, 3, 4 are idempotent. There is atomicity in the sense that + only after 4 is completed, then the output_dir/latest_universal + checkpoint is created in which then the universal checkpoint + can be loaded. + + Be aware that this creates an extra dir `zero/` in the checkpoint dir, + which doubles the DS checkpoint storage requirement. + - DS checkpoints store 3X model parameters in 32bit. + - e.g., will be 6X a model parameter-only checkpoint in 16bit. + + Note that this requires a latest version of deepspeed. It kind of works if + the model is not saving universal checkpoint info, but only in the + the case where advanced features like tensor parallel (TP) and + pipeline parallel (PP) are turned off. 
+ """ + + log_rank_0( + f"\033[93mPreparing universal checkpoint in {output_dir}\033[0m", to_print=True + ) + # Third Party + from transformers.utils.import_utils import _is_package_available + + _, ds_version = _is_package_available("deepspeed", return_version=True) + if ds_version < "0.14.3": + raise ValueError("universal checkpoint only supported on deepspeed >= 0.14.3") + + start = time.time() + if torch.distributed.get_rank() == 0: + try: + # Third Party + from deepspeed.checkpoint import DeepSpeedCheckpoint + from deepspeed.checkpoint.ds_to_universal import ( + PARAM_SHAPES, + UNIVERSAL_CHECKPOINT_INFO, + _check_for_required_state, + _extract_zero_shard_files, + _merge_tp_slice_files, + _save_optimizer_state, + ) + except ImportError as exc: + raise ImportError( + "DeepSpeed-specific checkpoints cannot be saved without DeepSpeed>=0.14.3 installed" + ) from exc + + # read the latest file to get the step folder + latest_file = output_dir / "latest" + with open(latest_file) as f: + step_folder = f.read() + + # will process the checkpoint in the latest step folder + input_folder = os.path.join(output_dir, step_folder) + + # create args for the scripts below + class UniversalCheckpointArgs: + num_extract_workers: int = 1 + num_merge_workers: int = 1 + output_folder: str = input_folder # just put in same place + strict: bool = True # strict checkpoint + + args = UniversalCheckpointArgs() + + # get the checkpoint + ds_checkpoint = DeepSpeedCheckpoint(input_folder) + + # hack, force this to null if we did not properly save + # any universal checkpoint information + # - this will not support any pipeline replication and other + # replication such as TP, row parallelism, vocab, sub_params + if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: + warnings.warn( + "Universal checkpoint information not found, setting it to " + "an empty dictionary." + ) + ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} + assert ( + ds_checkpoint.tp_degree == 1 + ), "if universal checkpointing info is missing, TP must be absent" + assert ( + ds_checkpoint.pp_degree == 1 + ), "if universal checkpointing info is missing, PP must be absent" + _check_for_required_state(ds_checkpoint) + + slice_shapes = [] + for mp_rank_file in ds_checkpoint.mp_rank_files: + mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu")) + slice_shapes += mp_sd[PARAM_SHAPES] + + # fix back to normal flat dict, merge duplicates for tp>1 + slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items()) + temp_dir = os.path.join(args.output_folder, "tmp") + + log_rank_0( + f"\033[93m1. Extracting ZeRO fragments into {temp_dir}\033[0m", + to_print=True, + ) + _extract_zero_shard_files(args, ds_checkpoint, temp_dir) + + zero_output_folder = os.path.join(args.output_folder, "zero") + + log_rank_0( + f"\033[93m2. Merging slices into {zero_output_folder}\033[0m", to_print=True + ) + _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) + + log_rank_0( + f"\033[93m3. Saving common optimizer states into {zero_output_folder}\033[0m", + to_print=True, + ) + _save_optimizer_state(args, ds_checkpoint) + + log_rank_0( + f"\033[93m4. Removing temp directory {temp_dir}\033[0m", to_print=True + ) + shutil.rmtree(temp_dir, ignore_errors=True) + + latest_file = os.path.join(output_dir, "latest_universal") + log_rank_0(f"\033[93m5. 
Creating {latest_file}\033[0m", to_print=True) + with open(latest_file, "w") as f: + f.write(step_folder) + + dist.barrier() + log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds") + + @contextmanager def ensure_loadable_dolomite_checkpoint( model_name_or_path: str, @@ -794,6 +1058,44 @@ def _get_state_dict_patched(model, unwrap=False): accelerator.get_state_dict = get_state_dict_unpatched +# this is native deepspeed saving with optimizer, scheduler +def save_model_ds_native( + args, + model, + tokenizer, # pylint: disable=unused-argument + samples_seen, +): + # to get a statedict from a zero checkpoint, all you need to do is + # - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt') + # - sum([math.prod(x.shape) for x in sd.values()]) # check the size (should be correct) + + log_rank_0( + f"\033[93mSaving model+optimizer+scheduler in format at samples_seen: {samples_seen}\033[0m", + to_print=True, + ) + start = time.time() + # used to save huggingface format, so we can use it for hf.from_pretrained + output_dir = Path(args.output_dir) / "ds_native" + tag = f"samples_{samples_seen}" + use_lora = args.lora_r > 0 + + # NOTE: this is a distributed save + # if its lora, we only save the adapters + # - so we exclude frozen if use_lora==True + model.save_checkpoint( + output_dir, + exclude_frozen_parameters=use_lora, + tag=tag, # this will create the subdirectory with the correct name + ) + + # for now we are not saving tokenizer, config, eg.. + # so it is not totally "HF compatible" + + log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True) + log_rank_0(f"saving took {time.time() - start} seconds") + + def set_random_seed(seed): if seed is not None: random.seed(seed) @@ -919,3 +1221,15 @@ def load_latest_full_state(args, accelerator) -> None: # previous epoch is basis for current epoch. 
args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 args.__dict__["samples_seen"] = training_metadata["samples_seen"] + + +def get_projection_layer_names(model: PreTrainedModel) -> List[str]: + """ + Given a pretrained model, returns all of the projection layers (matching '_proj') + """ + proj_layers = set( + name.split(".")[-1] + for name, _ in model.named_modules() + if name.endswith("_proj") + ) + return list(proj_layers) From 073006219b4b2be188e99a1d6ced2eacbaad5e12 Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Thu, 29 May 2025 15:14:07 -0700 Subject: [PATCH 3/4] Refactor bucketing, disable torch.compile by default --- src/instructlab/training/main_ds.py | 2 +- src/instructlab/training/multipack_sampler.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 036ba335..601fc29d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -182,7 +182,7 @@ def setup_model( else: model = AutoModelForCausalLM.from_pretrained(**base_model_args) - if is_torch_hpu_available(): + if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False): torch._dynamo.config.cache_size_limit = int(1e4) torch._dynamo.config.accumulated_cache_size_limit = int(2e4) model = torch.compile(model, backend="hpu_backend", dynamic=False) diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 28c87334..a48f8e12 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -402,11 +402,6 @@ def generate_batches(self, set_stats=False): ) lengths = self.lengths[indices] - - if is_torch_hpu_available(): - bucket_v = np.vectorize(bucket) - lengths = bucket_v(lengths) - lengths_cumsum = np.cumsum(lengths) batches, total_used, total_slots = allocate( From c9f331c5d1a12c935a4d1f6dd62816b9553077d8 Mon Sep 17 00:00:00 2001 From: Jianhong-Zhang Date: Fri, 6 Jun 2025 10:08:36 -0700 Subject: [PATCH 4/4] Fix checkpoints for Gaudi remove _orig_mod prefix from checkpoints from torch.compile trained model Signed-off-by: Jianhong-Zhang --- src/instructlab/training/hpu_utils.py | 13 +++++++++++++ src/instructlab/training/utils.py | 21 ++++++++++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/instructlab/training/hpu_utils.py b/src/instructlab/training/hpu_utils.py index 15a951b7..4a26b204 100644 --- a/src/instructlab/training/hpu_utils.py +++ b/src/instructlab/training/hpu_utils.py @@ -1,3 +1,4 @@ +import os import torch from functools import lru_cache @@ -47,3 +48,15 @@ def simple_bucket(length): def bucket(length): return simple_bucket(length) + + +def save_hpu_model(model, output_dir): + from safetensors.torch import save_file + + state_dict = model.state_dict() + remove_prefix = "_orig_mod." 
+ clean_state_dict = { + k[len(remove_prefix) :] if k.startswith(remove_prefix) else k: v + for k, v in state_dict.items() + } + save_file(clean_state_dict, os.path.join(output_dir, "model.safetensors")) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index db37de87..cd63bca5 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -50,7 +50,11 @@ QuantizeDataType, TrainingArgs, ) -from instructlab.training.hpu_utils import is_torch_hpu_available, bucket +from instructlab.training.hpu_utils import ( + is_torch_hpu_available, + bucket, + save_hpu_model, +) logger = logging.getLogger("instructlab.training") @@ -1033,12 +1037,15 @@ def _get_state_dict_patched(model, unwrap=False): model.module.unmerge_adapter() if not is_lora: - accelerator.save_model( - model, - save_directory=output_dir, - max_shard_size="5GB", - safe_serialization=True, - ) + if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False): + save_hpu_model(model, output_dir) + else: + accelerator.save_model( + model, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) if args.use_dolomite and convert_dolomite and accelerator.is_main_process: # export doesnt like the directory to exist