10 changes: 9 additions & 1 deletion src/instructlab/training/accelerator.py
@@ -43,6 +43,7 @@ def __init__(
deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
fsdp_cpu_offload_params: Optional[bool] = False,
fsdp_use_orig_params: Optional[bool] = False,
device: Optional[str] = None,
Member:
Instead of just str, you might want to create a new type that constrains the choices between cuda and hpu:

AcceleratorDevice = Literal["cuda", "hpu"]

# ...

# safe to infer cuda, since this will be the most common option
device: Optional[AcceleratorDevice] = "cuda"  
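
A minimal sketch of that suggestion, assuming typing.Literal plus a small runtime guard (the helper name below is illustrative, not part of the PR):

# Sketch only: constrain the accepted device strings at type-check time
# and validate them at runtime.
from typing import Literal, Optional, get_args

AcceleratorDevice = Literal["cuda", "hpu"]

def resolve_device(device: Optional[str]) -> AcceleratorDevice:
    resolved = device or "cuda"  # safe to infer cuda, the most common option
    if resolved not in get_args(AcceleratorDevice):
        raise ValueError(f"Unsupported device {resolved!r}; expected one of {get_args(AcceleratorDevice)}")
    return resolved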

):
self.samples_per_gpu = samples_per_gpu
self.save_samples = save_samples
@@ -61,6 +62,7 @@ def __init__(
self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
self.fsdp_use_orig_params = fsdp_use_orig_params
self.lr_scheduler = None
self.device_str = device  # must be set before its first use in self.get_fsdp_config()
if self.distributed_framework == DistributedBackend.DEEPSPEED:
# Standard
accel_args = {
@@ -81,6 +83,10 @@ def __init__(
"fsdp_plugin": self.get_fsdp_config(),
"mixed_precision": "bf16",
}
if device == "hpu":
from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
else:
from accelerate import Accelerator as TransformersAccel
self.accelerator = TransformersAccel(
**accel_args,
)
@@ -159,7 +165,9 @@ def get_fsdp_config(self):
use_orig_params=self.fsdp_use_orig_params,
# TODO(osilkin): expose switch for fp32 reduction
)

if self.device_str == "hpu":
fsdp_plugin.use_orig_params=True
fsdp_plugin.sync_module_states=True
return fsdp_plugin

def get_ds_plugin(
7 changes: 5 additions & 2 deletions src/instructlab/training/batch_loss_manager.py
@@ -46,7 +46,7 @@ class BatchLossManager:
- Computing average losses for logging
"""

def __init__(self, model, accelerator, world_size: int, local_rank: int):
def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str):
"""
Initialize the BatchLossManager.

@@ -60,7 +60,10 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
self.accelerator: Accelerator = accelerator
self.world_size: int = world_size
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)
if device == "hpu":
self.torch_device = torch.device("hpu", local_rank)
else:
self.torch_device = torch.device("cuda", local_rank)

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
"""
2 changes: 2 additions & 0 deletions src/instructlab/training/config.py
@@ -254,3 +254,5 @@ class TrainingArgs(BaseModel):
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO"
)

device: Optional[str] = None
56 changes: 43 additions & 13 deletions src/instructlab/training/main_ds.py
@@ -111,7 +111,7 @@ def train(
global_grad_norm = None

# Initialize the batch loss manager
batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank)
batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device)

# Blast through batches
for epoch in range(args.current_epoch, args.num_epochs):
@@ -150,8 +150,12 @@ def train(
elapsed_time = time.time() - start
overall_throughput = batch_metrics.total_samples / elapsed_time
current_lr = accelerator.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
if args.device == "hpu":
mem_allocated = torch.hpu.memory_allocated() / (1024**3)
malloc_retries = 0
else:
mem_allocated = torch.cuda.memory_allocated() / (1024**3)
malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
global_grad_norm = (
model.get_global_grad_norm()
if hasattr(model, "get_global_grad_norm")
@@ -173,8 +177,8 @@ def train(
"rank": dist.get_rank(),
"overall_throughput": overall_throughput,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
"num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens,
"num_tokens_rank0": batch_metrics.total_length,
"batch_size": batch_metrics.total_samples,
@@ -206,7 +210,8 @@ def train(
global_step += 1
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()
if args.device != "hpu":
torch.cuda.empty_cache()
if args.checkpoint_at_epoch:
base_logger.debug(f"Saving checkpoint at epoch {epoch}")
save_checkpoint(
@@ -284,17 +289,22 @@ def main(args):
args.model_type = model_conf.model_type

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if args.device == "hpu":
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
else:
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])

timeout = _get_collective_timeout()
if timeout is not None:
dist.init_process_group(timeout=timeout)
else:
dist.init_process_group()
backend = "hccl" if args.device == "hpu" else None
torch.distributed.init_process_group(backend=backend, timeout=timeout)


args.global_rank = dist.get_rank()
tensor = torch.ByteTensor([False]).cuda()
if args.device == "hpu":
tensor = torch.ByteTensor([False]).to('hpu')
else:
tensor = torch.ByteTensor([False]).cuda()
dist.all_reduce(tensor)
dist.barrier()

@@ -335,6 +345,7 @@ def main(args):
flash_enabled=flash_enabled,
noise_alpha=args.NEFTune_alpha,
lora_quant_bits=args.lora_quant_bits,
device=args.device,
)

args.base_model_args = m.base_model_args
@@ -410,6 +421,7 @@ def main(args):
fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
save_samples=args.save_samples,
fsdp_use_orig_params=fsdp_should_use_orig_params,
device=args.device,
)
# optimizer needs model that has been prepared by accelerator
# and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -588,6 +600,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

command.append(
f"--device={train_args.device}"
)

logger.info("Running training command as subprocess: %s", " ".join(command))
process = None
interrupt: KeyboardInterrupt | Exception | None = None
@@ -789,8 +805,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
action="store_true",
help="Use Liger kernels for training.",
)
parser.add_argument(
"--device",
Member:
In the context of PyTorch, we typically assume device refers to a singular device on a node containing several accelerators, e.g.: cuda:0, cuda:1, etc.

Since you're looking to allow for switching between HPU and CUDA style training, maybe we could change this to be --device_type? This way, we can also constrain the options to be either cuda (by default) or hpu, and error out when given an option which doesn't fall in the set of allowed devices.

type=str,
default=None,
help="PyTorch device to use.",
)
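
A minimal sketch of the reviewer's --device_type suggestion above, relying on argparse's built-in choices validation (the flag name and default here follow the reviewer's proposal, not what the PR currently implements):

# Sketch only: argparse rejects any value outside choices with a clear error.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device_type",
    type=str,
    choices=["cuda", "hpu"],
    default="cuda",
    help="Accelerator family to train on.",
)
args = parser.parse_args(["--device_type", "hpu"])  # example invocation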

args = parser.parse_args()
set_random_seed(args.seed)

if args.device == "hpu":
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.distributed.hccl
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()

set_random_seed(args.seed, args.device)
main(args)

"""
44 changes: 42 additions & 2 deletions src/instructlab/training/model.py
@@ -30,7 +30,6 @@
# Third Party
from peft import LoraConfig
from torch.optim import AdamW
from transformers import Mxfp4Config # pylint: disable=no-name-in-module
from transformers import (
AutoModelForCausalLM,
BitsAndBytesConfig,
@@ -57,17 +56,20 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
device: Optional[str] = None,
Member:
@splotnikv Is device here intended to be a specific device, or the broader device type (cuda/hpu)? If it's the latter, we should rename this to device_type to be consistent with the expected usage. Otherwise, it will be easy to get this confused with a specific torch.device instance such as cuda:0.

):
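
A short illustration of the device vs. device_type distinction raised in the comment above (assumed example, not PR code):

# Sketch only: a device *type* names the accelerator family, while
# torch.device pins a specific accelerator index on a node.
import torch

device_type = "cuda"                            # or "hpu": the broad family
specific_device = torch.device(device_type, 0)  # a concrete device, "cuda:0"
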
self.lora_config = lora_config
self.noise_alpha = noise_alpha
self.tokenizer = tokenizer
self.distributed_framework = distributed_framework
self.device = device
quant_config = None

# check model type & set it on the class
self.is_gpt_oss = is_gpt_oss(model_path)
if self.is_gpt_oss:
# Third Party
from transformers import Mxfp4Config # pylint: disable=no-name-in-module
Member:
@splotnikv How come this import is being moved from module-level to method-level?

quant_config = Mxfp4Config(dequantize=True)

# TODO: Add support for 8bit quantization
@@ -102,6 +104,19 @@ def __init__(

def _post_model_init(self):
"""Common initialization steps that should happen after model initialization."""

if self.device == "hpu" and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):

I think it would be better UX to handle reading this env var differently because it's not intuitive now.
For example: HPU_ENABLE_TORCH_COMPILE=False is still enabling t.compile
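
A minimal sketch of one way to read the flag so that an explicit False or 0 actually disables it (the helper name is illustrative):

# Sketch only: treat only explicit truthy strings as "enabled".
import os

def env_flag(name: str, default: bool = False) -> bool:
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

compile_enabled = env_flag("HPU_ENABLE_TORCH_COMPILE")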

Author:
Agree, but it is a pretty common pattern, widely used in other repositories. In any case, I am going to change it to command-line control; do you have any objections?

cache_size_limit = 10*1000
torch._dynamo.config.cache_size_limit = cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit
self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
for layer in self.model.model.layers:
layer.compile(backend="hpu_backend", dynamic=False)
if os.environ.get("RANK", '0') == '0':
Member:
@splotnikv Are you trying to log from the main process on a given node here? The RANK env var is the global rank of this process across all nodes, not its rank on this node. If you just want to log a message from the main process (local rank 0), consider using the log_rank_0 function from instructlab.training.utils (a short sketch of the RANK/LOCAL_RANK distinction follows the logging call below).

logger.info(
f"torch.compile has been enabled"
)
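
A brief sketch of the RANK/LOCAL_RANK distinction referenced in the comment above (assumes the standard torchrun environment variables):

# Sketch only: torchrun exports RANK (global rank of this process across
# the whole job) and LOCAL_RANK (this process's rank on its own node).
import os

global_rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

if local_rank == 0:
    print("runs once per node (each node's main process)")
if global_rank == 0:
    print("runs once per job (the single global main process)")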

self.reconcile_tokenizer()
if self.lora_config:
self.model = self.prepare_peft_model()
@@ -270,7 +285,11 @@ def _is_causal_lm_model(self) -> bool:
bool: True if the model is a causal language model, False otherwise.
"""
# Third Party
return "ForCausalLM" in self.model.__class__.__name__
if self.device != "hpu":
class_name = self.model.__class__.__name__
else:
class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__
return "ForCausalLM" in class_name

def reconcile_tokenizer(self):
if len(self.tokenizer) > self.model.config.vocab_size:
@@ -326,6 +345,17 @@ def reconcile_tokenizer(self):
):
self.model.config.eos_token_id = self.tokenizer.eos_token_id

if self.device == "hpu":
model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model
class_name = model.__class__.__name__

replace_no_split_modules = {
'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
}

if class_name in replace_no_split_modules:
model._no_split_modules = replace_no_split_modules[class_name]

if not self._is_causal_lm_model():
raise ValueError(
f"Model must be a causal language model, got {type(self.model)}"
@@ -386,9 +416,17 @@ def compute_loss(
- Dataclass containing the raw pre-scaled losses
"""
# Forward pass to get logits
hpu_args = {}
if self.device == "hpu":
hpu_args = {
"use_flash_attention":True,
"lazy_mode":False,
}

output = self(
**inputs,
use_cache=False,
**hpu_args,
)

# Manual loss computation with reduction="none" following mini_trainer's exact approach
@@ -490,6 +528,7 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
device: Optional[str] = None,
):
super().__init__(
model_path=model_path,
@@ -499,6 +538,7 @@
flash_enabled=flash_enabled,
lora_config=lora_config,
lora_quant_bits=lora_quant_bits,
device=device,
)
self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
self._post_model_init()
10 changes: 7 additions & 3 deletions src/instructlab/training/utils.py
@@ -318,6 +318,7 @@ def reduce_sum_forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**(_deprecated_arguments if model.device=="hpu" else {}),
)

return_dict = isinstance(output, dict)
@@ -775,13 +776,16 @@ def _get_state_dict_patched(model, unwrap=False):
accelerator.get_state_dict = get_state_dict_unpatched


def set_random_seed(seed):
def set_random_seed(seed, device: str):
Member:
Would you mind adjusting this so that it uses either is_hpu or device_type?

if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

if device == "hpu":
torch.hpu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)


# TODO: move this to also live in the `Model` object
def save_checkpoint(