Changes from all commits (38 commits):
- `f5c4ed4` (Nov 20, 2025): Add lora in vllm & some tests
- `67b9651` (Nov 21, 2025): add batched method in async + more tests
- `04b6b70` (Nov 21, 2025): decrease difference error for lora because of precision issues (e.g. …
- `bec6fa5` (Nov 24, 2025): set lora_request as class attribute
- `b129c83` (Nov 25, 2025): change hf backend to support lora + add testing
- `78b2039` (Nov 25, 2025): clean hf lora tests
- `450cd2a` (Nov 25, 2025): add testing for swapping lora and no-lora
- `e2a6a81` (Nov 25, 2025): remove unnecessary import
- `3416104` (Nov 25, 2025): remove double batch method
- `2a5f93e` (Dec 3, 2025): add comments in the new methods
- `245743a` (Dec 3, 2025): remove comment
- `ab0860f` (Dec 5, 2025): add more tests
- `f357d0f` (Dec 8, 2025): update dependencies
- `31334ff` (Dec 8, 2025): cleaning
- `b345d53` (Dec 9, 2025): change model for testing
- `809bf82` (Dec 9, 2025): add lora dependencies in pytest
- `c00c8d7` (Dec 9, 2025): fix dependencies lora
- `eae7c8a` (Dec 9, 2025): change test model
- `797c8d6` (Dec 10, 2025): fix lora test on transformer
- `998af61` (Dec 10, 2025): increase gpu memory util
- `18e114e` (Dec 10, 2025): decrease gpu memory util
- `daddc42` (Dec 11, 2025): check gpu github
- `a951037` (Dec 11, 2025): change gpu memory util
- `bff9a75` (Dec 11, 2025): debug github
- `42f0402` (Dec 11, 2025): decrease tests
- `51f2e08` (Dec 12, 2025): downgrade triton
- `ca8bd0f` (Dec 12, 2025): trition 3.2
- `ec5ceb1` (Dec 12, 2025): debug models github
- `d766049` (Dec 12, 2025): change model on tests
- `5568324` (Dec 12, 2025): remove test for cache reasons
- `73955d2` (Dec 12, 2025): free disk space
- `bda3699` (Dec 12, 2025): triton
- `709d279` (Dec 12, 2025): add testing for error path
- `cc0ba95` (Dec 12, 2025): add readme
- `c52faf2` (Dec 12, 2025): cleaning
- `565f87f` (Dec 15, 2025): triton reinstall
- `bc151d5` (Dec 15, 2025): triton 3.2
- `24c8dee` (Dec 15, 2025): remove unnecessary reinstall
`.github/workflows/coverage.yml` (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ jobs:
run: |
python -m venv venv
source venv/bin/activate
- pip install -e .[test]
+ pip install -e .[lora]
pip install -r requirements-dev.txt
- name: Run tests
`.github/workflows/pytest.yml` (12 changes: 11 additions & 1 deletion)
@@ -17,6 +17,16 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/[email protected]
with:
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: false

- uses: actions/setup-python@v4
with:
@@ -27,6 +37,6 @@ jobs:
run: |
python -m venv venv
source venv/bin/activate
- pip install -e .[test]
+ pip install -e .[lora]
pip install -r requirements-dev.txt
python -m pytest tests --ignore=tests/test_mlx.py
`README.md` (5 changes: 5 additions & 0 deletions)
@@ -35,6 +35,11 @@ Or to install with MLX support, run:
pip install genlm-backend[mlx]
```

Or to install with LoRA support, run:

```bash
pip install genlm-backend[lora]
```

## 🧪 Example: Autobatched Sequential Importance Sampling with LLMs

`genlm/backend/llm/hf.py` (31 changes: 31 additions & 0 deletions)
@@ -162,6 +162,37 @@ def cache_kv(self, prompt_tokens):
result = self.model(torch.tensor([prompt_tokens]).to(self.device))
node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
node.past_key_values = result.past_key_values

def load_lora(self, lora_path, lora_name='lora_1'):
"""Load a LoRA adapter into the base model.

Args:
lora_path (str): Path to a directory containing the adapter weights, or a Hugging Face Hub identifier.
lora_name (str): Name to assign to the loaded adapter.

Notes:
This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
"""
self.model.load_adapter(lora_path, lora_name)

def set_lora(self, lora_name='lora_1'):
"""Activate a previously loaded LoRA adapter.

Args:
lora_name (str): Name of the LoRA adapter to activate.

Notes:
The output cache and KV cache are cleared before activation, since cached results were computed without this adapter.
"""
self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter(lora_name)

def clear_lora(self):
"""
Deactivate all LoRA adapters.
"""
self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter([])

@torch.no_grad()
def batch_evaluate_queries(self):
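A minimal usage sketch of the new HF-backend methods. Names outside the diff are assumptions for illustration: the HF backend class is taken to be `AsyncTransformer` as exposed by genlm.backend, `gpt2` is just an example base model, and `path/to/adapter` is a hypothetical LoRA checkpoint; the `lora` extra (peft) must be installed.

```python
# Sketch only: AsyncTransformer, gpt2, and path/to/adapter are illustrative
# assumptions, not part of this diff.
import asyncio

from genlm.backend import AsyncTransformer


async def main():
    llm = AsyncTransformer.from_name("gpt2")  # example base model

    # Load adapter weights; this alone does not activate them.
    llm.load_lora("path/to/adapter", lora_name="my_lora")

    # Activate the adapter; the KV and output caches are cleared automatically.
    llm.set_lora("my_lora")
    with_lora = await llm.next_token_logprobs([0, 1, 2])

    # Fall back to the base model.
    llm.clear_lora()
    without_lora = await llm.next_token_logprobs([0, 1, 2])
    return with_lora, without_lora


asyncio.run(main())
```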
`genlm/backend/llm/vllm.py` (23 changes: 22 additions & 1 deletion)
@@ -7,6 +7,7 @@

try:
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
from vllm.lora.request import LoRARequest
from vllm.utils import Counter
from vllm.inputs import TokensPrompt

@@ -81,6 +82,7 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
if cache_size > 0
else None
)
self.lora_request = None

async_llm_engine.engine.log_stats = False

@@ -128,6 +130,22 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
def underlying_model(self):
return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model

def clear_lora(self):
"""
Disable any active LoRA adapter for the vLLM engine.
"""
self.lora_request = None

def set_lora(self, lora_path, lora_name="current_lora", lora_id=1):
"""Configure a LoRA adapter request for the vLLM engine.

Args:
lora_path (str): Path to a directory containing the adapter weights, or a Hugging Face Hub identifier.
lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.
lora_id (int): Globally unique ID for the adapter.

Notes:
The adapter is not applied immediately; the stored request is attached to subsequent generation and log-probability requests.
"""
self.lora_request = LoRARequest(lora_name, lora_id, lora_path)

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.

@@ -172,6 +190,7 @@ async def _next_token_logprobs(self, token_ids):
sampling_params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
):
if output.finished:
@@ -215,11 +234,12 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
)

while self.async_llm_engine.engine.has_unfinished_requests():
- output = self.async_llm_engine.engine.step()
+ output = self.async_llm_engine.engine.step()
for out in output:
if out.finished:
assert out.request_id in req_id2processors, (
@@ -275,6 +295,7 @@ async def sample(
seed=seed,
stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
),
lora_request=self.lora_request,
request_id=str(next(self.request_counter)),
):
if output.finished:
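On the vLLM side, `set_lora` only stores a `LoRARequest`, which then rides along with every log-probability and sampling request until `clear_lora` is called. A minimal sketch, with assumptions flagged: the vLLM backend class is taken to be `AsyncVirtualLM` from genlm.backend, the model name and adapter path are placeholders, and vLLM is assumed to need LoRA enabled at engine construction (`enable_lora=True` in the engine options).

```python
# Sketch only: AsyncVirtualLM, the model name, and path/to/adapter are
# illustrative assumptions, not part of this diff.
import asyncio

from genlm.backend import AsyncVirtualLM


async def main():
    # vLLM needs to know up front that LoRA adapters will be used.
    llm = AsyncVirtualLM.from_name(
        "Qwen/Qwen2.5-0.5B-Instruct",  # example base model
        engine_opts={"enable_lora": True},
    )

    # Attach the adapter; it is passed with every request from here on.
    llm.set_lora("path/to/adapter", lora_name="my_lora", lora_id=1)
    with_lora = await llm.next_token_logprobs([1, 2, 3])

    # Detach it; requests go back to the base weights.
    llm.clear_lora()
    without_lora = await llm.next_token_logprobs([1, 2, 3])
    return with_lora, without_lora


asyncio.run(main())
```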
`pyproject.toml` (5 changes: 4 additions & 1 deletion)
@@ -16,14 +16,17 @@ dependencies = [
"bitsandbytes; sys_platform == 'linux'",
"numba",
"vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
"triton>=3.2.0; sys_platform == 'linux'",
"triton==3.2.0; sys_platform == 'linux'",
]

[project.optional-dependencies]
mlx = [
"mlx",
"mlx-lm"
]
lora = [
'peft'
]
docs = [
"mkdocs",
"mkdocstrings[python]",
`tests/conftest.py` (10 changes: 10 additions & 0 deletions)
@@ -9,6 +9,7 @@
destroy_model_parallel,
destroy_distributed_environment,
)
from vllm.lora.request import LoRARequest

HAS_VLLM = True
except ImportError:
@@ -142,6 +143,7 @@ def __init__(self, llm):
stop=None,
ignore_eos=True,
)
self.lora_request = None

self.llm.llm_engine.log_stats = False

@@ -158,11 +160,18 @@ def from_name(cls, model_name, llm_opts=None):
llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
return cls(llm)

def clear_lora(self):
self.lora_request = None

def set_lora(self, lora_path, lora_name="current_lora", lora_id=1):
self.lora_request = LoRARequest(lora_name, lora_id, lora_path)

def next_token_logprobs_sync(self, token_ids):
outputs = self.llm.generate(
prompts=TokensPrompt(prompt_token_ids=token_ids),
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
@@ -185,6 +194,7 @@ async def batch_next_token_logprobs(self, token_ids_list):
prompts=prompts,
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
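The conftest wrapper mirrors the backend API so the synchronous vLLM reference path can carry the same `LoRARequest`. A sketch of the kind of adapter-swap test this enables, with assumptions flagged: class and attribute names (`AsyncVirtualLM`, `tokenizer`, `next_token_logprobs` returning a torch tensor) follow the genlm-backend API as understood here, `pytest-asyncio` is assumed available, and the model name, adapter path, and tolerance are placeholders.

```python
# Sketch only: the model name, adapter path, and tolerance are illustrative
# assumptions, not taken from this diff.
import numpy as np
import pytest

from genlm.backend import AsyncVirtualLM


@pytest.mark.asyncio
async def test_lora_swap_changes_and_restores_logprobs():
    llm = AsyncVirtualLM.from_name(
        "Qwen/Qwen2.5-0.5B-Instruct",  # example base model
        engine_opts={"enable_lora": True},
    )
    token_ids = llm.tokenizer.encode("hello world")

    base = (await llm.next_token_logprobs(token_ids)).float().cpu().numpy()

    llm.set_lora("path/to/adapter")
    adapted = (await llm.next_token_logprobs(token_ids)).float().cpu().numpy()

    llm.clear_lora()
    restored = (await llm.next_token_logprobs(token_ids)).float().cpu().numpy()

    # The adapter should move the distribution, and clearing it should
    # restore the base model's outputs up to numerical precision.
    assert not np.allclose(base, adapted, atol=1e-3)
    np.testing.assert_allclose(base, restored, atol=1e-3)
```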