Implement vLLM FSDP LoRA hot-swapping integration #10
Open: jacobthebanana wants to merge 92 commits into master from jjt/lora-vllm-hotswap.
Changes from all commits (92 commits):
904d1e1  Implemented baseline LoRA peft for one Nvidia GPU. (jacobthebanana)
2ace67e  Added support for saving lora adapters. (jacobthebanana)
a25e667  save_utils: added support for non-FSDP optimizers. (jacobthebanana)
65a2dbf  example_lora: highlighted current lora (non-fsdp) limitations. (jacobthebanana)
ed4c84f  Added instructions on LoRA on one GPU. (jacobthebanana)
5a72392  Added example script for launching lora. (jacobthebanana)
e176ac8  Revised instructions on LoRA on one GPU. (jacobthebanana)
2d869b0  Implemented LoRA FSDP. (jacobthebanana)
dc098d6  Reverted automatic formatter changes in README.md (jacobthebanana)
5a1fd76  Eliminated non-FSDP logic from save_utils. (jacobthebanana)
7e187bc  Moved lora config out of example config.yaml. (jacobthebanana)
3eea331  Implemented LoRA benchmarking logic for worker. (jacobthebanana)
906e4f3  model_utils: Refactored get_lora_model to reduce interface width. (th… (jacobthebanana)
0c41535  test_modelling: moved text output to data/. (jacobthebanana)
f24d2fa  added example yaml config for lora benchmarking. (jacobthebanana)
7d27d90  launch_benchmark: marked qos flag as optional. (jacobthebanana)
d22ea85  launch_benchmark: added option to limit number of jobs launched. (jacobthebanana)
84b953a  launch_benchmark: implemented torch profiler integration. (jacobthebanana)
e1cda07  Merged changes from low CPU memory usage feature (#6) into jjt/lora-b… (adil-a)
48f61d9  Revised launch_benchmark.py to use new profiling path. (jacobthebanana)
9876ebe  Enabled automatic creation of data/trace folder. (jacobthebanana)
5330871  Added instructions for profiling tools. (jacobthebanana)
17e24bd  Merge remote-tracking branch 'origin/master' into jjt/lora-baseline (jacobthebanana)
9982791  Cleaned up duplicate imports from merge. (jacobthebanana)
9a76e80  Cleaned up duplicate imports from merge. (jacobthebanana)
ffa7067  Cleaned up parse_benchmark.py (jacobthebanana)
bd893e1  Integrated LoRA logic into llama_example.py. (jacobthebanana)
c2f346f  Moved lora_configs into train_parameters in config yaml. Adjusted doc… (jacobthebanana)
56cb750  Revised handling of nproc-per-node in benchmark script. (jacobthebanana)
97ddd8c  Included parameter_count info in benchmark output. (jacobthebanana)
7c7a000  Implemented basic util for parsing benchmarking output. (jacobthebanana)
f33e89a  model_utils: Enabled low_cpu_mem_usage in auto model from_pretrained… (jacobthebanana)
35bdbcd  launch_lora_benchmark.sh: implemented automatic identification of num… (jacobthebanana)
e6b2e59  requirements.txt: included accelerate to support low_cpu_mem loading. (jacobthebanana)
db148fa  benchmark.py: adjusted BenchmarkingDataset to avoid StopIteration exc… (jacobthebanana)
35f6c5d  benchmark.py: added env var flag to toggle export_trace (jacobthebanana)
4a1251b  parse_benchmark: included profiler table in output file. (jacobthebanana)
79fd79b  get_lora_model_from_base_model: enabled peft for models loaded via lo… (jacobthebanana)
5c25397  model_utils: revised dtype handling for peft-wrapped models. (jacobthebanana)
c19de82  parse_benchmark: implemented sorting of profiler table output. (jacobthebanana)
7e13cde  Merged example_lora into examples/llama_example.pu (jacobthebanana)
28d4ede  Added instructions related to parse_benchmark (jacobthebanana)
a863ed2  parse_benchmark: implemented aggregation across repeated metrics. (jacobthebanana)
eb3721a  Implemented non-LoRA profiling and benchmarking. (jacobthebanana)
37f5dec  Various static typechecking and formatting fixes. (jacobthebanana)
78c6faf  Implemented restoring LoRA train state from filesystem. (jacobthebanana)
aea2ed8  Included train step number in LoRA adapter output path. (jacobthebanana)
dad6553  Added reference throughput table to documentation. (jacobthebanana)
bbcda75  Added unit description to reference throughput table. (jacobthebanana)
d397488  Added unit description to reference throughput table. (jacobthebanana)
35b97b8  Benchmark: added option to override max_length of pre-trained model. (jacobthebanana)
6af7791  Deleted unused `accelerate` dependency from requirements.txt (jacobthebanana)
97be477  Benchmark: added comment on max_length. (jacobthebanana)
b43e565  Benchmark: added comment on batch size. (jacobthebanana)
607de70  Benchmark: added option to override batch size. (jacobthebanana)
bdef48f  Benchmark throughput documentation: revised word choices. (jacobthebanana)
3294a39  LoRA Hot-Swap: Implemented vLLM integration test scaffolding and PyTe… (jacobthebanana)
2bb7bad  LoRA Hot-Swap: Implemented vLLM LoRA hot-swap integration proof-of-co… (jacobthebanana)
5d93afe  LoRA Hot-Swap: added additional fixtures to enhance readability. (jacobthebanana)
02988a5  LoRA Hot-Swap: Deleted redundant np.asarray call in integration test … (jacobthebanana)
5ad5d90  LoRA Hot-Swap: Updated test case documentations to reflect code reuse… (jacobthebanana)
afb321c  Moved profiling-tracking logic out of Trainer. (jacobthebanana)
5babf6b  Eliminated hasattr check related to no_sync since FSDP is always enab… (jacobthebanana)
c1b31c4  Replaced peft fsdp_auto_wrap_policy to eliminate implicit `accelerate… (jacobthebanana)
f0b201c  Configured LoRA auto-wrap policy as off by default- enable the policy… (jacobthebanana)
429ec5e  Revised punctuation in lora_requires_grad_policy_fn. (jacobthebanana)
afbc061  Renamed declarative `enable_lora` with descriptive `is_lora_enabled`. (jacobthebanana)
7bc6f89  Merge commit 'afbc061' from jjt/lora-baseline into jjt/lora-vllm-hotswap (jacobthebanana)
4936b1d  Added (request for comment) AbstractInferenceEngine interface and LoR… (jacobthebanana)
aa1fe8b  Renamed "inference" to "sampling". (jacobthebanana)
675367b  Added reference sampling steps to llama_example. Added example sampli… (jacobthebanana)
ca2cad8  Added train_parameters.get("sampler"). (jacobthebanana)
649a4b8  [WIP] Implemented vLLM wrapper combining vectorlm and vLLM workers. (jacobthebanana)
ebb7bc9  vllm integration: Eliminated duplicate vllm ResultHandler. (jacobthebanana)
1f1f88e  vllm integration [WIP]: Revised vectorlm-vllm concurrency handling. (jacobthebanana)
11a1ba5  vllm integration [WIP]: Implemented inference during training. (jacobthebanana)
b697dc0  vllm integration [WIP]: Implemented lora hotswap. (jacobthebanana)
112ea3c  vllm integration [WIP]: Moved sampler-related logic into Trainer. (jacobthebanana)
07405dc  Merge remote-tracking branch 'origin/master' into jjt/lora-vllm-hotswap (jacobthebanana)
e707987  vllm integration: Added documentation on sampling engine. (jacobthebanana)
61c39ad  vllm integration: Added documentation on sampling engine. (jacobthebanana)
609c023  [WIP] vllm hotswapping: Implement minimum-viable wrapper for vllm/main. (jacobthebanana)
9585c01  [WIP] vllm hotswapping: Reduced area of vLLM integration interface. (jacobthebanana)
31464aa  vllm hotswapping [WIP]: Reduced area of vLLM integration interface. (jacobthebanana)
059d57f  vllm hotswapping [WIP]: Refactored vLLM integration interface to mini… (jacobthebanana)
b5c6389  vllm hotswapping [WIP]: deleted unneded torch dist.barrier from llama… (jacobthebanana)
f506812  vllm hotswapping [WIP]: documentation fixes and cleanup. (jacobthebanana)
3e27e84  vllm hotswapping [WIP]: cleaned up documentation related to multiproc… (jacobthebanana)
879399f  vllm hotswapping [WIP]: cleaned up changes in llama_example.py. (jacobthebanana)
bc0ae52  vllm hotswapping [WIP]: added example gemma sampling config. (jacobthebanana)
5e8944d  vllm hotswapping: Refactoring and cleanup. (jacobthebanana)
2005a7d  vllm hotswapping: Moved Sampler import into conditional block to avoi… (jacobthebanana)
@@ -9,4 +9,5 @@ data/
 **/*.pyc
 /.cache
 /.vscode
-/data
+/data
+/env
@@ -0,0 +1,26 @@
# Efficient Sampling during training

Some training objectives, notably PPO, require "sampling" from the language model many times during training. The most straightforward approach might be to invoke model.generate on the model from within the training loop. However, a number of alternative inference solutions, vLLM among them, promise over 10x the sampling throughput (tokens generated per second) when using a large sampling batch size. If model.generate takes up too much of the training time, it may be worthwhile to look into these third-party solutions to speed up sampling.

One main challenge of running these third-party solutions is that most of them assume the weights of the language model are fixed, so there is no straightforward way to update those weights. Usually, updating the weights requires restarting the sampling engine, which can take minutes. At the same time, the performance of PPO and similar techniques relies heavily on the ability to replace the weights efficiently; otherwise training is no longer on-policy, and convergence takes substantially more training steps. To resolve this issue, we implemented techniques to "hot-swap" the model parameters used in the sampling process.
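For orientation, here is a minimal sketch (not part of the file this PR adds) of what swapping a freshly saved LoRA adapter into a standalone vLLM engine looks like using vLLM's public `LoRARequest` API; the model name and adapter paths below are placeholders, and the `LoRASamplingEngine` in this PR presumably wraps a similar mechanism behind its sampling interface.

```python
# Illustration only: hot-swapping LoRA adapters into a standalone vLLM engine
# without restarting it. Model name and adapter paths are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
params = SamplingParams(temperature=0.7, max_tokens=64)

for step in (100, 200):  # pretend the trainer just wrote a checkpoint at each step
    adapter_path = f"/checkpoints/lora/step-{step}"  # placeholder path
    # A new lora_int_id makes vLLM load this adapter from disk rather than
    # reuse a cached copy, so generation reflects the latest training step.
    request = LoRARequest(f"adapter-step-{step}", step, adapter_path)
    outputs = llm.generate(["Vector Institute is"], params, lora_request=request)
    print(outputs[0].outputs[0].text)
```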
Additionally, it is not straightforward to maintain consistently high GPU utilization when combining sampling with training. This repository enables you to make the most of all your GPUs by fitting vLLM and your training loop onto the same set of devices. This way, no GPU sits idle: if a GPU is not running training, it is busy sampling with vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach.

## Example: Supervised fine-tuning

We provide a basic example that samples from the language model while fine-tuning with a basic causal language modelling objective. To run the example, uncomment the "sampler" section in your configuration yaml, choose a port for `nccl` coordination, and run the following command (not using torchrun):

```
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=19132
python3 examples/llama_example_mp.py \
    --yaml_path configs/config.yaml \
    --world_size 2
```
## Bring your own training loop

While the reference implementation covers only supervised fine-tuning, we provide abstractions that make it easier for you to implement your own training loop, be it PPO RLHF, TWIST, or something else. The goal is to abstract away all the synchronization logic, so that a training loop you've built on one GPU can scale to multiple GPUs on the same server with minimal modification.

To get started, refer to examples/llama_example.py and vectorlm/trainer.py. Usually, the vLLM Engine is accessible only from rank 0, which makes synchronization challenging. When invoked through llama_example_mp, the `SamplingEngine` interface in VectorLM enables your training loop to access vLLM.LLM.generate from all ranks, returning the same result on every rank. Note that because the synchronization barriers require all ranks to reach the synchronization point, you need to invoke `generate` from all ranks.
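As a rough sketch of what such a loop might look like (not taken from this PR; the exact `generate` signature and trainer wiring may differ, so treat the names below as assumptions and consult vectorlm/trainer.py for the real interfaces):

```python
# Hypothetical custom training loop that samples every few steps.
# `sampling_engine.generate`, the HF-style `model(**batch).loss`, and
# `is_rank_zero` are assumptions for illustration, not VectorLM's actual API.
def train_loop(model, optimizer, dataloader, sampling_engine,
               is_rank_zero: bool, sample_every: int = 50) -> None:
    for step, batch in enumerate(dataloader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % sample_every == 0:
            # Collective call: every rank must reach this line, or the ranks
            # that did will block at the synchronization barrier.
            completions = sampling_engine.generate(["Vector Institute is"])
            if is_rank_zero:
                # All ranks receive identical results, so rank 0 alone logs.
                print(f"step {step}: {completions[0]}")
```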
@@ -0,0 +1,26 @@
#!/bin/bash
#SBATCH --job-name=llama7b-2
#SBATCH --nodes=1
#SBATCH --mem=0
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:4
#SBATCH --output=llama-2-7b.%j.out
#SBATCH --error=llama-2-7b.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos  # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; this flag disables it.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags for debugging communication.
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} example_lora.py --yaml_path configs/config-lora.yaml
@@ -0,0 +1,26 @@
#!/bin/bash
#SBATCH --job-name=llama7b-2-lora
#SBATCH --nodes=1
#SBATCH --mem=32GB
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:1
#SBATCH --output=llama-2-7b-lora.%j.out
#SBATCH --error=llama-2-7b-lora.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos  # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; this flag disables it.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags for debugging communication.
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=1 example_lora.py --yaml_path configs/config-lora.yaml
@@ -0,0 +1,85 @@
"""Supply LoRASamplingEngine to llama_example.

Each non-rank-0 worker process should spawn vectorlm logic in a
separate thread (but same process), yet must not run the actual
vectorlm logic until the vLLM Engine is initialized (inference
weights loaded into each worker).

To do so without duplicating vLLM code, observe that only the main process
(rank 0) is aware that the vLLM engine was initialized properly
(when LLMEngine.__init__ returns). Hence, one way to implement this
setup is to block the vectorlm thread with a multiprocessing synchronization
feature (e.g., a Barrier shared across all processes) that the rank 0 process
can remotely unblock.

See docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30
for more detail on this architecture.
"""

from __future__ import annotations

import argparse
import os
from functools import partial

from llama_example import main
from vllm import EngineArgs
from vllm.executor.multiproc_worker_utils import ResultHandler, mp

from vectorlm.sampling import (
    LoRASamplingEngine,
    SamplingEngineProvider,
    SynchronizationBarriers,
)
from vectorlm.utils.data_utils import Config

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--yaml_path", type=str, required=True)
    args = parser.parse_args()

    world_size: int = args.world_size
    vectorlm_config = Config(yaml_path=args.yaml_path)
    sampler_config = vectorlm_config.train_parameters.sampler  # type: ignore[reportAttributeAccessIssue]
    vllm_engine_config = EngineArgs(
        model=vectorlm_config.model,  # type: ignore[reportAttributeAccessIssue]
        gpu_memory_utilization=sampler_config.get(
            "gpu_memory_utilization",
            0.35,
        ),
        tensor_parallel_size=world_size,
        dtype=sampler_config.get("vllm_dtype", "auto"),
        enable_lora=True,
    ).create_engine_config()
    os.environ["WORLD_SIZE"] = str(world_size)

    # Block all N vectorlm threads until the main process has finished
    # initializing the vLLM Engine. Additionally, block vectorlm
    # threads as long as vLLM tasks are running.
    barriers = SynchronizationBarriers(
        # (n+1) threads: __main__, and n vectorlm threads (including main).
        vllm_init=mp.Barrier(world_size + 1),
        # n vectorlm threads.
        before_generation=mp.Barrier(world_size),
        after_generation=mp.Barrier(world_size),
    )
    vllm_result_handler = ResultHandler()

    # The rank 0 worker runs in the __main__ process;
    # all other ranks use one process each.
    # vectorlm logic in each rank (including rank 0) is in a separate thread
    # from the vLLM worker logic.
    vllm_callback_wrapper = SamplingEngineProvider(
        vllm_engine_config,
        barriers,
        LoRASamplingEngine,
        partial(main, vectorlm_config),
    )

    vllm_callback_wrapper.initialize_engine()
    assert vllm_callback_wrapper.llm is not None
    output = vllm_callback_wrapper.llm.generate("Vector Institute is")
    print(output[0].prompt + output[0].outputs[0].text)

    vllm_callback_wrapper.join_vectorlm_thread()
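To make the barrier pattern described in this file's docstring concrete, here is a self-contained sketch of the same idea in plain multiprocessing (no vLLM; names are illustrative, not taken from the PR): worker threads park on a shared barrier that the main process releases once its slow initialization has finished.

```python
# Stand-alone illustration of the synchronization pattern: training threads
# wait on a shared multiprocessing.Barrier until the main process (the only
# one that knows when engine initialization finished) also reaches it.
import multiprocessing as mp
import threading
import time


def training_thread(rank: int, engine_ready) -> None:
    engine_ready.wait()  # park here until the main process releases the barrier
    print(f"rank {rank}: engine is up, starting the training loop")


def worker_process(rank: int, engine_ready) -> None:
    # Non-rank-0 workers run their training logic in a thread of the same
    # process, mirroring how this file co-locates vLLM workers and the
    # vectorlm threads on the same devices.
    thread = threading.Thread(target=training_thread, args=(rank, engine_ready))
    thread.start()
    thread.join()


if __name__ == "__main__":
    world_size = 2
    # world_size training threads plus the main process itself.
    engine_ready = mp.Barrier(world_size + 1)
    workers = [
        mp.Process(target=worker_process, args=(rank, engine_ready))
        for rank in range(1, world_size)
    ]
    for worker in workers:
        worker.start()

    time.sleep(1.0)  # stand-in for the (slow) vLLM engine initialization
    # Rank 0's training thread lives in the main process, as in this file.
    rank_zero = threading.Thread(target=training_thread, args=(0, engine_ready))
    rank_zero.start()
    engine_ready.wait()  # releases every waiting training thread at once

    rank_zero.join()
    for worker in workers:
        worker.join()
```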
Empty file.
@@ -0,0 +1,10 @@
from .abstract import AbstractSamplingEngine
from .sampling_lora import LoRASamplingEngine
from .utils import (
    ManagedLLM,
    ManagedMultiProcGPUExecutor,
    SamplingEngineProvider,
    SynchronizationBarriers,
    handle_sample,
    multiprocess_wrap,
)