5 changes: 5 additions & 0 deletions .github/workflows/nightly.yml
@@ -108,6 +108,11 @@ jobs:
          source activate peft
          make tests_core_multi_gpu

      - name: Run training on multi GPU
        run: |
          source activate peft
          make tests_training

      - name: Generate Report
        if: always()
        run: |
4 changes: 4 additions & 0 deletions Makefile
@@ -64,3 +64,7 @@ tests_regression:

tests_torch_compile:
	python -m pytest tests/test_torch_compile.py $(if $(IS_GITHUB_CI),--report-log "compile_tests.log",)

tests_training:
	accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py $(if $(IS_GITHUB_CI),--report-log "training_fsdp.log",)
	accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py --quant 4bit $(if $(IS_GITHUB_CI),--report-log "training_fsdp_4bit.log",)
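Note: for local debugging outside CI, the same command can be run directly; this assumes a machine with at least two GPUs (matching num_processes in tests/training/fsdp_config.yaml), and the --report-log flag is only appended on GitHub CI:

accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py --quant 4bit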
25 changes: 25 additions & 0 deletions tests/training/fsdp_config.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
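A note on this config: num_processes: 2 means accelerate spawns two workers, so the test expects two GPUs, and SHARDED_STATE_DICT keeps intermediate checkpoints sharded (the training script switches to FULL_STATE_DICT right before saving the adapter). For a local run with a different GPU count, accelerate's standard --num_processes flag should override the config value, with the flag placed before the script path, e.g.:

accelerate launch --config_file tests/training/fsdp_config.yaml --num_processes 2 tests/training/training.py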
141 changes: 141 additions & 0 deletions tests/training/training.py
@@ -0,0 +1,141 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is a simple example of training a model with QLoRA.
"""

import argparse
import os
import tempfile
from typing import Literal

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model


def print_if_process_zero(*args, **kwargs):
    # PartialState().print only prints once (on the main process), so logs are not repeated per rank
    PartialState().print(*args, **kwargs)


def main(model_id: str, quant: Literal["4bit", "8bit"] | None = None, target_modules: list[str] | None = None):
    # argparse passes a list; a single "all-linear" entry is the special string value understood by LoraConfig
    if target_modules == ["all-linear"]:
        target_modules = "all-linear"

    print_if_process_zero("=" * 50)
    print_if_process_zero(f"{model_id=}, {quant=}, {target_modules=}")
    print_if_process_zero("=" * 50)

    data = load_dataset("ybelkada/english_quotes_copy")

    if quant == "4bit":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            # the quant storage dtype matches the bf16 model dtype so that FSDP can flatten quantized layers
            bnb_4bit_quant_storage="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
    elif quant == "8bit":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
        # cannot use llm_int8_skip_modules=target_modules, see:
        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1634
    elif quant is None:
        quant_config = None
    else:
        raise ValueError(f"Unsupported quantization: {quant}, expected one of '4bit', '8bit', or None")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # load the whole model onto this rank's device; FSDP sharding is applied later by the Trainer/accelerate
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=quant_config, dtype=torch.bfloat16, device_map={"": PartialState().process_index}
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=target_modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    print_if_process_zero(model)
    if PartialState().is_local_main_process:
        model.print_trainable_parameters()

    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = Trainer(
            model=model,
            train_dataset=data["train"],
            # use SGD instead of the default AdamW, which fails under FSDP with:
            # > RuntimeError: output with shape [] doesn't match the broadcast shape [1]
            optimizer_cls_and_kwargs=(torch.optim.SGD, {"lr": 2e-4}),
            args=TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=15,
                learning_rate=2e-4,
                bf16=True,
                logging_steps=5,
                output_dir=tmp_dir,
            ),
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )
        trainer.train()

        if trainer.is_fsdp_enabled:
            # gather the sharded weights so that a single, complete adapter checkpoint is written
            trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        trainer.save_model(tmp_dir)

        # some checks
        if PartialState().is_local_main_process:
            files = os.listdir(tmp_dir)
            assert "adapter_model.safetensors" in files
            assert "adapter_config.json" in files

            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] < 10.0, f"Final loss is too high: {final_log['train_loss']}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, required=False, default="Qwen/Qwen3-0.6B")
    parser.add_argument("--quant", type=str, choices=["4bit", "8bit"], required=False, default=None)
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        required=False,
        default=None,
        help="List of target modules for LoRA adaptation",
    )
    args = parser.parse_args()
    main(model_id=args.model_id, quant=args.quant, target_modules=args.target_modules)
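Usage sketch (not part of the Makefile targets): besides --quant, the script accepts --target_modules; passing the single value all-linear is special-cased in main() so that LoRA targets all linear layers, e.g.:

accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py --quant 4bit --target_modules all-linear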