From 890591332e39d95a895b80ba0773b0eb97330e80 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 21 Oct 2025 12:59:15 +0200
Subject: [PATCH 1/2] CI Add simple FSDP tests

Normal and QLoRA with 4bit bnb. 8bit doesn't work.
---
 .github/workflows/nightly.yml   |   5 ++
 Makefile                        |   4 +
 examples/sft/utils.py           |   8 +-
 tests/training/fsdp_config.yaml |  25 ++++++
 tests/training/training.py      | 141 ++++++++++++++++++++++++++++++++
 5 files changed, 180 insertions(+), 3 deletions(-)
 create mode 100644 tests/training/fsdp_config.yaml
 create mode 100644 tests/training/training.py

diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index dabfea9035..f3095a4c60 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -108,6 +108,11 @@ jobs:
           source activate peft
           make tests_core_multi_gpu
 
+      - name: Run training on multi GPU
+        run: |
+          source activate peft
+          make tests_training
+
       - name: Generate Report
         if: always()
         run: |
diff --git a/Makefile b/Makefile
index 70ba4b7f8e..3221649b5e 100644
--- a/Makefile
+++ b/Makefile
@@ -64,3 +64,7 @@ tests_regression:
 
 tests_torch_compile:
 	python -m pytest tests/test_torch_compile.py $(if $(IS_GITHUB_CI),--report-log "compile_tests.log",)
+
+tests_training:
+	accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py $(if $(IS_GITHUB_CI),--report-log "training_fsdp.log",)
+	accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py --quant 4bit $(if $(IS_GITHUB_CI),--report-log "training_fsdp_4bit.log",)
diff --git a/examples/sft/utils.py b/examples/sft/utils.py
index 0e24796de4..c55697f7c3 100644
--- a/examples/sft/utils.py
+++ b/examples/sft/utils.py
@@ -97,7 +97,9 @@ def create_and_prepare_model(args, data_args, training_args):
     ):
         raise NotImplementedError("Unsloth is not supported in distributed training")
 
-    if args.use_4bit_quantization:
+    if args.use_4bit_quantization and args.use_8bit_quantization:
+        raise ValueError("You configured 4bit and 8bit quantization at the same time, please choose only one of them.")
+    elif args.use_4bit_quantization:
         compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
         quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)
 
@@ -115,8 +117,8 @@ def create_and_prepare_model(args, data_args, training_args):
                 print("=" * 80)
                 print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                 print("=" * 80)
-        elif args.use_8bit_quantization:
-            bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)
+    elif args.use_8bit_quantization:
+        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)
 
     if args.use_unsloth:
         if torch.xpu.is_available():
diff --git a/tests/training/fsdp_config.yaml b/tests/training/fsdp_config.yaml
new file mode 100644
index 0000000000..4a688645f6
--- /dev/null
+++ b/tests/training/fsdp_config.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/tests/training/training.py b/tests/training/training.py
new file mode 100644
index 0000000000..6c0ae467a8
--- /dev/null
+++ b/tests/training/training.py
@@ -0,0 +1,141 @@
+# Copyright 2025-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This is a simple example of training a model with LoRA or QLoRA (bitsandbytes quantization).
+"""
+
+import argparse
+import os
+import tempfile
+from typing import Literal
+
+import torch
+from accelerate import PartialState
+from datasets import load_dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)
+
+from peft import LoraConfig, get_peft_model
+
+
+def print_if_process_zero(*args, **kwargs):
+    PartialState().print(*args, **kwargs)
+
+
+def main(model_id: str, quant: Literal["4bit", "8bit"] | None = None, target_modules: list[str] | None = None):
+    if target_modules == ["all-linear"]:
+        target_modules = "all-linear"
+
+    print_if_process_zero("=" * 50)
+    print_if_process_zero(f"{model_id=}, {quant=}, {target_modules=}")
+    print_if_process_zero("=" * 50)
+
+    data = load_dataset("ybelkada/english_quotes_copy")
+
+    if quant == "4bit":
+        quant_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype="bfloat16",
+            bnb_4bit_quant_storage="bfloat16",
+            bnb_4bit_use_double_quant=True,
+        )
+    elif quant == "8bit":
+        quant_config = BitsAndBytesConfig(load_in_8bit=True)
+        # cannot use llm_int8_skip_modules=target_modules, see:
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1634
+    elif quant is None:
+        quant_config = None
+    else:
+        raise ValueError(f"Unsupported quantization: {quant}, expected one of '4bit', '8bit', or None")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, quantization_config=quant_config, dtype=torch.bfloat16, device_map={"": PartialState().process_index}
+    )
+
+    peft_config = LoraConfig(
+        r=16,
+        lora_alpha=32,
+        target_modules=target_modules,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, peft_config)
+    print_if_process_zero(model)
+    if PartialState().is_local_main_process:
+        model.print_trainable_parameters()
+
+    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = Trainer(
+            model=model,
+            train_dataset=data["train"],
+            optimizer_cls_and_kwargs=(torch.optim.SGD, {"lr": 2e-4}),
+            # FSDP with AdamW:
+            # > RuntimeError: output with shape [] doesn't match the broadcast shape [1]
+            args=TrainingArguments(
+                per_device_train_batch_size=4,
+                gradient_accumulation_steps=4,
+                warmup_steps=2,
+                max_steps=15,
+                learning_rate=2e-4,
+                bf16=True,
+                logging_steps=5,
+                output_dir=tmp_dir,
+            ),
+            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+        )
+        trainer.train()
+
+        if trainer.is_fsdp_enabled:
+            trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+        trainer.save_model(tmp_dir)
+
+        # some checks
+        if PartialState().is_local_main_process:
+            files = os.listdir(tmp_dir)
+            assert "adapter_model.safetensors" in files
+            assert "adapter_config.json" in files
+
+            final_log = trainer.state.log_history[-1]
+            assert final_log["train_loss"] < 10.0, f"Final loss is too high: {final_log['train_loss']}"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str, required=False, default="Qwen/Qwen3-0.6B")
+    parser.add_argument("--quant", type=str, choices=["4bit", "8bit"], required=False, default=None)
+    parser.add_argument(
+        "--target_modules",
+        type=str,
+        nargs="+",
+        required=False,
+        default=None,
+        help="List of target modules for LoRA adaptation",
+    )
+    args = parser.parse_args()
+    main(model_id=args.model_id, quant=args.quant, target_modules=args.target_modules)

From 659b724e4a059a7fc3bc52bd5d5744ce55694cf1 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 21 Oct 2025 13:01:13 +0200
Subject: [PATCH 2/2] Undo unrelated change

---
 examples/sft/utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/examples/sft/utils.py b/examples/sft/utils.py
index c55697f7c3..0e24796de4 100644
--- a/examples/sft/utils.py
+++ b/examples/sft/utils.py
@@ -97,9 +97,7 @@ def create_and_prepare_model(args, data_args, training_args):
     ):
         raise NotImplementedError("Unsloth is not supported in distributed training")
 
-    if args.use_4bit_quantization and args.use_8bit_quantization:
-        raise ValueError("You configured 4bit and 8bit quantization at the same time, please choose only one of them.")
-    elif args.use_4bit_quantization:
+    if args.use_4bit_quantization:
         compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
         quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)
 
@@ -117,8 +115,8 @@ def create_and_prepare_model(args, data_args, training_args):
                 print("=" * 80)
                 print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                 print("=" * 80)
-    elif args.use_8bit_quantization:
-        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)
+        elif args.use_8bit_quantization:
+            bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)
 
     if args.use_unsloth:
         if torch.xpu.is_available():