3 changes: 3 additions & 0 deletions .gitignore
@@ -160,3 +160,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Notes
notes/*.ipynb
251 changes: 215 additions & 36 deletions genlm/backend/llm/vllm.py
@@ -1,12 +1,13 @@
import torch
import logging
import warnings
import os

from genlm.backend.llm.base import AsyncLM
from genlm.backend.cache import OutputCache

try:
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
from vllm import AsyncLLMEngine, SamplingParams
from vllm.utils import Counter
from vllm.inputs import TokensPrompt

@@ -54,40 +55,95 @@ def __call__(self, past_token_ids, logits):
return logits

class AsyncVirtualLM(AsyncLM):
default_params = {
"max_tokens": 1,
"n": 1,
"detokenize": False,
"stop": None,
"ignore_eos": True,
}

def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
def __init__(
self,
async_llm_engine,
cache_size=0,
cache_opts={},
logprobs_per_request=256,
v1=False,
):
"""Initialize an `AsyncVirtualLM` instance.

Args:
async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.
v1 (bool, optional): If True, use the vLLM V1 engine; otherwise use V0. Defaults to False.
logprobs_per_request (int, optional): Number of top log probabilities to retrieve per request. Used only with V1. Defaults to 256.

Note:
The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
"""
self.v1 = v1
self.async_llm_engine = async_llm_engine
self.tokenizer = async_llm_engine.engine.get_tokenizer()
self.default_params = {
"max_tokens": 1,
"n": 1,
"detokenize": False,
"stop": None,
"ignore_eos": True,
}
# Version specific modifications
if self.v1: # pragma: no cover
self.default_params["logprobs"] = ( # pragma: no cover
logprobs_per_request # Set the retrieved logprobs.
)
self.tokenizer = self._wrap_tokenizer( # pragma: no cover
async_llm_engine.tokenizer
) # Wrap tokenizer for V1. # pragma: no cover
async_llm_engine.log_stats = False # pragma: no cover
else:
self.tokenizer = async_llm_engine.engine.get_tokenizer()
async_llm_engine.engine.log_stats = False
self.request_counter = Counter()
self.cache = (
OutputCache(maxsize=cache_size, **cache_opts)
if cache_size > 0
else None
)

async_llm_engine.engine.log_stats = False

super().__init__(tokenizer=self.tokenizer)

def _wrap_tokenizer(self, tokenizer): # pragma: no cover
"""Wrap v1 tokenizer to be compatible with base class expectations.
Note that in V1 async_llm_engine.tokenizer is a TokenizerGroup object"""

class TokenizerWrapper: # pragma: no cover
def __init__(self, tokenizer): # pragma: no cover
# Access the underlying tokenizer from TokenizerGroup.
self._tokenizer = getattr(
tokenizer, "tokenizer", tokenizer
) # pragma: no cover
# Add compatibility attributes.
self.is_fast = (
True # Assume fast tokenizer for v1 # pragma: no cover
)
self.name_or_path = getattr(
self._tokenizer,
"name_or_path",
"unknown", # pragma: no cover
) # pragma: no cover

def __getattr__( # pragma: no cover
self, name
): # Retrieve the tokenizer from the TokenizerGroup object.
return getattr(self._tokenizer, name)

def __len__(self): # pragma: no cover
return len(self._tokenizer)

return TokenizerWrapper(tokenizer)

@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
def from_name(
cls,
model_name,
v1=False,
logprobs_per_request=256,
engine_opts=None,
**kwargs,
):
"""Create a `AsyncVirtualLM` instance from a model name.

Args:
@@ -98,66 +154,159 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):

Returns:
(AsyncVirtualLM): An `AsyncVirtualLM` instance.

Note: For GPT-OSS models, vLLM >= 0.10.2 is required.
"""
if not HAS_VLLM:
raise ImportError( # pragma: no cover
"vLLM not available. Install vLLM or use AsyncTransformer instead."
)

if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
if engine_opts["enable_chunked_prefill"]:
if engine_opts["enable_chunked_prefill"]: # pragma: no cover
warnings.warn( # pragma: no cover
"Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
"custom sampling functionality."
)

engine_opts = {
"enable_prefix_caching": True,
"disable_log_requests": True,
"disable_async_output_proc": True, # This parameter forces vLLM to use v0, which is currently what we want to do.
**(engine_opts or {}),
}

engine = AsyncLLMEngine.from_engine_args(
if v1: # pragma: no cover
original_v1_env = os.environ.get(
"VLLM_USE_V1" # pragma: no cover
) # The engine type may already have been set via an environment variable, so save it (and restore it later) before forcing V1 or V0.
os.environ["VLLM_USE_V1"] = "1" # pragma: no cover
from vllm.engine.arg_utils import (
AsyncEngineArgs,
) # The AsyncEngineArgs import is different in V1 and V0. # pragma: no cover

engine_opts = {
"enable_prefix_caching": True,
"max_logprobs": logprobs_per_request,
**(engine_opts or {}),
} # pragma: no cover
else:
original_v1_env = os.environ.get("VLLM_USE_V1")
os.environ["VLLM_USE_V1"] = "0"
from vllm import (
AsyncEngineArgs,
) # The AsyncEngineArgs import is different in V1 and V0.

engine_opts = {
"enable_prefix_caching": True,
"disable_log_requests": True, # Is it possible to remove this parameter? it is causing problems with vllm >= v 0.10.0.
"disable_async_output_proc": True, # This parameter forces vLLM to use v0, which is currently what we want to do.
**(engine_opts or {}),
}

engine = AsyncLLMEngine.from_engine_args( # Set up the engine.
AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
)

return cls(engine, **kwargs)
# Restore the environment variable so that it does not interfere with other instances.
if original_v1_env is not None:
os.environ["VLLM_USE_V1"] = original_v1_env
else:
os.environ.pop("VLLM_USE_V1", None)

return cls(
engine, v1=v1, logprobs_per_request=logprobs_per_request, **kwargs
)
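
For context, a hedged usage sketch of the updated `from_name` entry point (not part of this file's code; the model name, token IDs, and cache settings below are illustrative assumptions, and the import path is assumed from this module's location):

from genlm.backend.llm.vllm import AsyncVirtualLM  # assumed import path

# Illustrative sketch only: model name and options are placeholders.
llm = AsyncVirtualLM.from_name(
    "gpt2",                    # any model that vLLM can load
    v1=False,                  # default: keep the V0 engine; set True to opt into V1
    logprobs_per_request=256,  # only consulted when v1=True
    cache_size=128,            # forwarded to __init__ to enable the OutputCache
)

# Synchronous entry point (V0 only, per the asserts further down).
logprobs = llm.next_token_logprobs_sync([1, 2, 3])  # normalized next-token log probabilities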

@property
def underlying_model(self):
return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model
raise NotImplementedError # pragma: no cover

@property
def logits_processors(self):
return self._logits_processors # pragma: no cover

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.

Args:
token_ids_list (list[int]): A list of token IDs, representing a prompt to the language model.
token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

Returns:
result (torch.Tensor): Normalized log probability tensor.

Warning:
Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
For synchronous usage, use the `next_token_logprobs_sync()` method instead.
"""

assert isinstance(token_ids, list) and all(
isinstance(i, int) for i in token_ids
), "token_ids must be a list of token IDs."

key = tuple(token_ids)

if self.cache is not None and key in self.cache:
return self.cache[key]

result = await self._next_token_logprobs(key)
if self.v1: # pragma: no cover
result = await self._next_token_logprobs_v1(key) # pragma: no cover
else:
result = await self._next_token_logprobs_v0(key)

if self.cache is not None:
self.cache[key] = result

return result
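
A minimal sketch of how the output cache short-circuits repeated prompts (assuming an `llm` instance built with `cache_size > 0`); per the warning above, the coroutine is awaited inside an existing event loop rather than wrapped in `asyncio.run`:

# Sketch: the second await for the same prompt is served from the OutputCache.
async def score_twice(llm, token_ids):
    first = await llm.next_token_logprobs(token_ids)   # computed by the vLLM engine
    second = await llm.next_token_logprobs(token_ids)  # returned from the cache
    return first, second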

async def _next_token_logprobs(self, token_ids):
async def _next_token_logprobs_v1(self, token_ids): # pragma: no cover
"""Request log probabilities of next token asynchronously.

Args:
token_ids_list (list[int]): A list of token IDs, representing a prompt to the language model.
token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

Returns:
(torch.Tensor): Normalized log probability tensor.
"""
req_id = str(next(self.request_counter))

# Convert token IDs to string for v1 compatibility. # pragma: no cover
prompt = self.tokenizer.decode(token_ids) # pragma: no cover

outputs = []
async for output in self.async_llm_engine.generate(
prompt=prompt,
sampling_params=SamplingParams(**self.default_params),
request_id=req_id,
): # pragma: no cover
if output.finished:
outputs.append(output)

# Extract logprobs from the output
# v1 provides logprobs in the output when logprobs parameter is set
output = outputs[0].outputs[0] # pragma: no cover
logprobs = output.logprobs

assert logprobs, (
"Log probs should have been retrieved at this point"
) # pragma: no cover
# v1 logprobs format: list of dicts with token_id -> logprob
vocab_size = len(self.tokenizer) # pragma: no cover
logprobs_tensor = torch.full(
(1, vocab_size),
-float("inf"),
dtype=torch.float32, # pragma: no cover
)

for token_id, logprob in logprobs[0].items(): # pragma: no cover
# Assign the logprobs to the top-k retrieved tokens in the vocabulary.
assert hasattr(logprob, "logprob"), (
"Logprob field is required"
) # pragma: no cover
logprobs_tensor[0, token_id] = logprob.logprob

# Right now we don't re-normalize! We might want to change this:
# the remaining mass could be redistributed either among the remaining
# (non-retrieved) tokens or among the selected (retrieved) ones.
logprobs = logprobs_tensor # pragma: no cover
return logprobs[
0
] # Return shape (vocab_size,) instead of (1, vocab_size) # pragma: no cover
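
As a hedged illustration of the re-normalization option mentioned in the comment above, one way to redistribute the mass among only the retrieved top-k tokens is a log-sum-exp correction over the partially filled tensor. This is a sketch of an alternative, not what the method currently does, and `renormalize_topk` is a hypothetical helper:

import torch

def renormalize_topk(logprobs_tensor: torch.Tensor) -> torch.Tensor:
    # `logprobs_tensor` has shape (1, vocab_size) with -inf for non-retrieved tokens.
    # logsumexp ignores the -inf entries, so subtracting it renormalizes the
    # retrieved logprobs to sum to probability 1; -inf entries stay -inf.
    return logprobs_tensor - torch.logsumexp(logprobs_tensor, dim=-1, keepdim=True)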

async def _next_token_logprobs_v0(self, token_ids):
"""Request log probabilities of next token asynchronously.

Args:
token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

Returns:
(torch.Tensor): Normalized log probability tensor.
@@ -191,6 +340,7 @@ def next_token_logprobs_sync(self, token_ids):
Returns:
(torch.Tensor): Normalized log probability tensor.
"""
assert not self.v1 # Currently implemented only for V0.
return self.batch_next_token_logprobs_sync([token_ids])[0]

def batch_next_token_logprobs_sync(self, token_ids_list):
@@ -203,6 +353,7 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
Returns:
(torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
"""
assert not self.v1 # Currently implemented only for V0.
req_ids = []
req_id2processors = {}
for token_ids in token_ids_list:
@@ -242,7 +393,12 @@ def __del__(self):
def _cleanup_engine(self):
"""Clean up the vLLM engine and associated resources."""
if async_engine := getattr(self, "async_llm_engine", None):
async_engine.shutdown_background_loop()
if (
self.v1
): # The shutdown method is different in V1 and V0. # pragma: no cover
async_engine.shutdown() # pragma: no cover
else:
async_engine.shutdown_background_loop()
destroy_model_parallel()
destroy_distributed_environment()

@@ -266,14 +422,37 @@ async def sample(
Returns:
(list[int]): The sampled token IDs.
"""
if self.v1: # pragma: no cover
if isinstance(prompt_token_ids, list): # pragma: no cover
prompt_token_ids = self.tokenizer.decode(
prompt_token_ids
) # pragma: no cover
elif isinstance(prompt_token_ids, str): # pragma: no cover
pass
else: # pragma: no cover
raise ValueError(
f"Invalid prompt_ids_Type: {type(prompt_token_ids)}"
) # pragma: no cover
else:
prompt_token_ids = TokensPrompt(prompt_token_ids=prompt_token_ids)

# Question to check: Why do we need to use "byte_vocab"?
def decode_eos(eos_token_ids):
if self.v1: # pragma: no cover
return [
self.tokenizer.decode([i]) for i in eos_token_ids
] # pragma: no cover
else: # What is the advantage of using "byte_vocab" instead of the tokenizer? Can we do this also with V1?
return [self.byte_vocab[i].decode() for i in eos_token_ids]

async for output in self.async_llm_engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
prompt=prompt_token_ids,
sampling_params=SamplingParams(
n=1,
max_tokens=max_tokens,
temperature=temperature,
seed=seed,
stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
stop=decode_eos(eos_token_ids),
),
request_id=str(next(self.request_counter)),
):
8 changes: 7 additions & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
"bitsandbytes; sys_platform == 'linux'",
"numba",
"vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
"triton>=3.2.0; sys_platform == 'linux'",
"triton>=3.2.0; sys_platform == 'linux'"
]

[project.optional-dependencies]
@@ -42,3 +42,9 @@ requires = ["setuptools>=64.0", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"

[tool.setuptools_scm]

[tool.coverage.report] # For now, V1 must be excluded from coverage as it requires compute capability >= 8, which is not available on the CI.
exclude_lines = [
"pragma: no cover",
"if self.v1:",
"if v1:"]