diff --git a/.gitignore b/.gitignore
index 82f9275..b753619 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Notes
+notes/*.ipynb
diff --git a/genlm/backend/llm/vllm.py b/genlm/backend/llm/vllm.py
index ba424e8..2e08e27 100644
--- a/genlm/backend/llm/vllm.py
+++ b/genlm/backend/llm/vllm.py
@@ -1,12 +1,13 @@
 import torch
 import logging
 import warnings
+import os
 
 from genlm.backend.llm.base import AsyncLM
 from genlm.backend.cache import OutputCache
 
 try:
-    from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
+    from vllm import AsyncLLMEngine, SamplingParams
     from vllm.utils import Counter
     from vllm.inputs import TokensPrompt
 
@@ -54,27 +55,47 @@ def __call__(self, past_token_ids, logits):
         return logits
 
 
 class AsyncVirtualLM(AsyncLM):
-    default_params = {
-        "max_tokens": 1,
-        "n": 1,
-        "detokenize": False,
-        "stop": None,
-        "ignore_eos": True,
-    }
-
-    def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
+    def __init__(
+        self,
+        async_llm_engine,
+        cache_size=0,
+        cache_opts={},
+        logprobs_per_request=256,
+        v1=False,
+    ):
         """Initialize an `AsyncVirtualLM` instance.
 
         Args:
             async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
             cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
             cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.
+            v1 (bool, optional): If True, use the vLLM V1 engine; otherwise use V0. Defaults to False.
+            logprobs_per_request (int, optional): Number of top log probabilities retrieved per request. Used only with V1. Defaults to 256.
 
         Note:
             The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
         """
+        self.v1 = v1
         self.async_llm_engine = async_llm_engine
-        self.tokenizer = async_llm_engine.engine.get_tokenizer()
+        self.default_params = {
+            "max_tokens": 1,
+            "n": 1,
+            "detokenize": False,
+            "stop": None,
+            "ignore_eos": True,
+        }
+        # Version-specific modifications.
+        if self.v1:  # pragma: no cover
+            self.default_params["logprobs"] = (  # pragma: no cover
+                logprobs_per_request  # Set the number of retrieved logprobs.
+            )
+            self.tokenizer = self._wrap_tokenizer(  # pragma: no cover
+                async_llm_engine.tokenizer
+            )  # Wrap the tokenizer for V1.  # pragma: no cover
+            async_llm_engine.log_stats = False  # pragma: no cover
+        else:
+            self.tokenizer = async_llm_engine.engine.get_tokenizer()
+            async_llm_engine.engine.log_stats = False
         self.request_counter = Counter()
         self.cache = (
             OutputCache(maxsize=cache_size, **cache_opts)
@@ -82,12 +103,47 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
             else None
         )
 
-        async_llm_engine.engine.log_stats = False
 
         super().__init__(tokenizer=self.tokenizer)
 
+    def _wrap_tokenizer(self, tokenizer):  # pragma: no cover
+        """Wrap the V1 tokenizer to be compatible with base class expectations.
+        Note that in V1, `async_llm_engine.tokenizer` is a TokenizerGroup object."""
+
+        class TokenizerWrapper:  # pragma: no cover
+            def __init__(self, tokenizer):  # pragma: no cover
+                # Access the underlying tokenizer from the TokenizerGroup.
+                self._tokenizer = getattr(
+                    tokenizer, "tokenizer", tokenizer
+                )  # pragma: no cover
+                # Add compatibility attributes.
+                self.is_fast = (
+                    True  # Assume a fast tokenizer for V1.  # pragma: no cover
+                )
+                self.name_or_path = getattr(
+                    self._tokenizer,
+                    "name_or_path",
+                    "unknown",  # pragma: no cover
+                )  # pragma: no cover
+
+            def __getattr__(  # pragma: no cover
+                self, name
+            ):  # Forward attribute lookups to the underlying tokenizer.
+                return getattr(self._tokenizer, name)
+
+            def __len__(self):  # pragma: no cover
+                return len(self._tokenizer)
+
+        return TokenizerWrapper(tokenizer)
+
     @classmethod
-    def from_name(cls, model_name, engine_opts=None, **kwargs):
+    def from_name(
+        cls,
+        model_name,
+        v1=False,
+        logprobs_per_request=256,
+        engine_opts=None,
+        **kwargs,
+    ):
         """Create a `AsyncVirtualLM` instance from a model name.
 
         Args:
@@ -98,6 +154,8 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
 
         Returns:
             (AsyncVirtualLM): An `AsyncVirtualLM` instance.
+
+        Note: For GPT-OSS models, vLLM >= 0.10.2 is required.
         """
         if not HAS_VLLM:
             raise ImportError(  # pragma: no cover
@@ -105,59 +163,150 @@
             )
 
         if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
-            if engine_opts["enable_chunked_prefill"]:
+            if engine_opts["enable_chunked_prefill"]:  # pragma: no cover
                 warnings.warn(  # pragma: no cover
                     "Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
                     "custom sampling functionality."
                 )
 
-        engine_opts = {
-            "enable_prefix_caching": True,
-            "disable_log_requests": True,
-            "disable_async_output_proc": True,  # This parameter forces vLLM to use v0, which is currently what we want to do.
-            **(engine_opts or {}),
-        }
-
-        engine = AsyncLLMEngine.from_engine_args(
+        if v1:  # pragma: no cover
+            # The engine type may already have been set via the VLLM_USE_V1 environment
+            # variable, so we save the current value (to restore it later) before forcing V1.
+            original_v1_env = os.environ.get("VLLM_USE_V1")  # pragma: no cover
+            os.environ["VLLM_USE_V1"] = "1"  # pragma: no cover
+            from vllm.engine.arg_utils import (
+                AsyncEngineArgs,
+            )  # The AsyncEngineArgs import differs between V1 and V0.  # pragma: no cover
+
+            engine_opts = {
+                "enable_prefix_caching": True,
+                "max_logprobs": logprobs_per_request,
+                **(engine_opts or {}),
+            }  # pragma: no cover
+        else:
+            original_v1_env = os.environ.get("VLLM_USE_V1")
+            os.environ["VLLM_USE_V1"] = "0"
+            from vllm import (
+                AsyncEngineArgs,
+            )  # The AsyncEngineArgs import differs between V1 and V0.
+
+            engine_opts = {
+                "enable_prefix_caching": True,
+                "disable_log_requests": True,  # TODO: Can this parameter be removed? It causes problems with vLLM >= 0.10.0.
+                "disable_async_output_proc": True,  # This parameter forces vLLM to use V0, which is currently what we want to do.
+                **(engine_opts or {}),
+            }
+
+        engine = AsyncLLMEngine.from_engine_args(  # Set up the engine.
             AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
         )
 
-        return cls(engine, **kwargs)
+        # Reset the environment variable so that it does not interfere with other instances.
+        if original_v1_env is not None:
+            os.environ["VLLM_USE_V1"] = original_v1_env
+        else:
+            os.environ.pop("VLLM_USE_V1", None)
+
+        return cls(
+            engine, v1=v1, logprobs_per_request=logprobs_per_request, **kwargs
+        )
 
     @property
     def underlying_model(self):
-        return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model
+        raise NotImplementedError  # pragma: no cover
+
+    @property
+    def logits_processors(self):
+        return self._logits_processors  # pragma: no cover
 
     async def next_token_logprobs(self, token_ids):
         """Request log probabilities of next token asynchronously with output caching.
 
         Args:
-            token_ids_list (list[int]): A list of token IDs, representing a prompt to the language model.
+            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.
 
         Returns:
             result (torch.Tensor): Normalized log probability tensor.
-
-        Warning:
-            Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
-            For synchronous usage, use the `next_token_logprobs_sync()` method instead.
         """
+
+        assert isinstance(token_ids, list) and all(
+            isinstance(i, int) for i in token_ids
+        ), "token_ids must be a list of token IDs."
+
         key = tuple(token_ids)
 
         if self.cache is not None and key in self.cache:
             return self.cache[key]
 
-        result = await self._next_token_logprobs(key)
+        if self.v1:  # pragma: no cover
+            result = await self._next_token_logprobs_v1(key)  # pragma: no cover
+        else:
+            result = await self._next_token_logprobs_v0(key)
 
         if self.cache is not None:
             self.cache[key] = result
 
         return result
 
-    async def _next_token_logprobs(self, token_ids):
+    async def _next_token_logprobs_v1(self, token_ids):  # pragma: no cover
         """Request log probabilities of next token asynchronously.
 
         Args:
-            token_ids_list (list[int]): A list of token IDs, representing a prompt to the language model.
+            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.
+
+        Returns:
+            (torch.Tensor): Normalized log probability tensor.
+        """
+        req_id = str(next(self.request_counter))
+
+        # Convert token IDs to a string prompt for V1 compatibility.  # pragma: no cover
+        prompt = self.tokenizer.decode(token_ids)  # pragma: no cover
+
+        outputs = []
+        async for output in self.async_llm_engine.generate(
+            prompt=prompt,
+            sampling_params=SamplingParams(**self.default_params),
+            request_id=req_id,
+        ):  # pragma: no cover
+            if output.finished:
+                outputs.append(output)
+
+        # Extract logprobs from the output.
+        # V1 provides logprobs in the output when the logprobs parameter is set.
+        output = outputs[0].outputs[0]  # pragma: no cover
+        logprobs = output.logprobs
+
+        assert logprobs, (
+            "Log probs should have been retrieved at this point"
+        )  # pragma: no cover
+        # V1 logprobs format: list of dicts mapping token_id -> logprob.
+        vocab_size = len(self.tokenizer)  # pragma: no cover
+        logprobs_tensor = torch.full(
+            (1, vocab_size),
+            -float("inf"),
+            dtype=torch.float32,  # pragma: no cover
+        )
+
+        for token_id, logprob in logprobs[0].items():  # pragma: no cover
+            # Assign the logprobs to the top-k retrieved tokens in the vocabulary.
+            assert hasattr(logprob, "logprob"), (
+                "Logprob field is required"
+            )  # pragma: no cover
+            logprobs_tensor[0, token_id] = logprob.logprob
+
+        # Right now we don't re-normalize. We might want to change this: the remaining mass
+        # could be redistributed either among the remaining tokens or among the selected ones.
+        logprobs = logprobs_tensor  # pragma: no cover
+        return logprobs[
+            0
+        ]  # Return shape (vocab_size,) instead of (1, vocab_size).  # pragma: no cover
+
+    async def _next_token_logprobs_v0(self, token_ids):
+        """Request log probabilities of next token asynchronously.
+
+        Args:
+            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.
 
         Returns:
             (torch.Tensor): Normalized log probability tensor.
@@ -191,6 +340,7 @@ def next_token_logprobs_sync(self, token_ids):
         Returns:
             (torch.Tensor): Normalized log probability tensor.
         """
+        assert not self.v1  # Currently implemented only for V0.
         return self.batch_next_token_logprobs_sync([token_ids])[0]
 
     def batch_next_token_logprobs_sync(self, token_ids_list):
@@ -203,6 +353,7 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
         Returns:
             (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
         """
+        assert not self.v1  # Currently implemented only for V0.
         req_ids = []
         req_id2processors = {}
         for token_ids in token_ids_list:
@@ -242,7 +393,12 @@ def __del__(self):
     def _cleanup_engine(self):
         """Clean up the vLLM engine and associated resources."""
         if async_engine := getattr(self, "async_llm_engine", None):
-            async_engine.shutdown_background_loop()
+            if (
+                self.v1
+            ):  # The shutdown method differs between V1 and V0.  # pragma: no cover
+                async_engine.shutdown()  # pragma: no cover
+            else:
+                async_engine.shutdown_background_loop()
             destroy_model_parallel()
             destroy_distributed_environment()
 
@@ -266,14 +422,37 @@ async def sample(
         Returns:
             (list[int]): The sampled token IDs.
         """
+        if self.v1:  # pragma: no cover
+            if isinstance(prompt_token_ids, list):  # pragma: no cover
+                prompt_token_ids = self.tokenizer.decode(
+                    prompt_token_ids
+                )  # pragma: no cover
+            elif isinstance(prompt_token_ids, str):  # pragma: no cover
+                pass
+            else:  # pragma: no cover
+                raise ValueError(
+                    f"Invalid prompt_token_ids type: {type(prompt_token_ids)}"
+                )  # pragma: no cover
+        else:
+            prompt_token_ids = TokensPrompt(prompt_token_ids=prompt_token_ids)
+
+        # Question to check: Why do we need to use "byte_vocab"?
+        def decode_eos(eos_token_ids):
+            if self.v1:  # pragma: no cover
+                return [
+                    self.tokenizer.decode([i]) for i in eos_token_ids
+                ]  # pragma: no cover
+            else:  # What is the advantage of using "byte_vocab" over the tokenizer? Can we do this with V1 as well?
+                return [self.byte_vocab[i].decode() for i in eos_token_ids]
+
         async for output in self.async_llm_engine.generate(
-            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            prompt=prompt_token_ids,
             sampling_params=SamplingParams(
                 n=1,
                 max_tokens=max_tokens,
                 temperature=temperature,
                 seed=seed,
-                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
+                stop=decode_eos(eos_token_ids),
             ),
             request_id=str(next(self.request_counter)),
         ):
diff --git a/pyproject.toml b/pyproject.toml
index fe8605a..b95d5b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
     "bitsandbytes; sys_platform == 'linux'",
     "numba",
     "vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
-    "triton>=3.2.0; sys_platform == 'linux'",
+    "triton>=3.2.0; sys_platform == 'linux'"
 ]
 
 [project.optional-dependencies]
@@ -42,3 +42,9 @@ requires = ["setuptools>=64.0", "setuptools-scm>=8"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools_scm]
+
+[tool.coverage.report]  # Right now, V1 must be excluded from coverage because it requires compute capability >= 8, which is not available on the CI.
+exclude_lines = [
+    "pragma: no cover",
+    "if self.v1:",
+    "if v1:"]
diff --git a/tests/conftest.py b/tests/conftest.py
index f8a4566..7d83fbd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -153,6 +153,7 @@ def from_name(cls, model_name, llm_opts=None):
             "enable_prefix_caching": True,
             "disable_log_stats": True,
             "dtype": "float16",
+            "disable_async_output_proc": True,  # Force the use of V0.
             **(llm_opts or {}),
         }
         llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 81aefc7..a72e8f9 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -29,6 +29,25 @@ def async_llm(model_name):
     )
 
 
+@pytest.fixture(scope="module")
+def async_llm_v1(model_name):
+    try:
+        capability = torch.cuda.get_device_capability(0)
+        if capability[0] < 8:
+            pytest.skip("vLLM V1 requires a GPU with compute capability >= 8.0")
+    except Exception:
+        pytest.skip("CUDA unavailable or cannot access CUDA capability")
+
+    return load_model_by_name(
+        model_name,
+        backend="vllm",
+        llm_opts={
+            "v1": True,
+            "engine_opts": {"gpu_memory_utilization": 0.2, "dtype": "float16"},
+        },
+    )
+
+
 @pytest.fixture(scope="module")
 def transformer_llm(model_name):
     return load_model_by_name(
@@ -149,6 +168,29 @@ def test_batch_next_token_logprobs_agreement(
     ]
 
 
+@cuda_only
+@pytest.mark.asyncio  # V1 must run under asyncio; it misbehaves when multiple event loops are involved.
+async def test_v1_next_token_logprobs(async_llm_v1, reference_llm, token_ids_list):
+    """Test V1 logprobs against reference (on top-256 tokens only)."""
+    for token_ids in token_ids_list:
+        logprobs_v1 = await async_llm_v1.next_token_logprobs(token_ids)
+        logprobs_ref = await reference_llm.next_token_logprobs(token_ids)
+
+        # Keep only finite entries. Note that V1 retrieves only the top-k tokens and sets the others to -inf.
+        valid_mask = logprobs_v1 != -float("inf")
+        if valid_mask.sum() <= 128:
+            pytest.skip("Fewer than 128 tokens to compare!")
+
+        comparison = compare(
+            logprobs_v1[valid_mask].cpu().numpy(), logprobs_ref[valid_mask]
+        )
+
+        # The tolerance is slightly looser for V1; the V1 engine may handle the weights slightly differently.
+        assert comparison.max_rel_err < 0.1, token_ids
+        assert comparison.pearson > 0.95, token_ids
+
+
 @pytest.mark.asyncio
 async def test_mock_async_llm():
     mock_async_llm = MockAsyncLM.from_name("gpt2")
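
Usage sketch (not part of the patch): a minimal example of how the new `v1` and `logprobs_per_request` options might be exercised directly through `AsyncVirtualLM.from_name`. The model name, memory settings, prompt, and the use of `asyncio.run` are illustrative assumptions, not part of the change; V1 additionally requires a GPU with compute capability >= 8.

import asyncio

from genlm.backend.llm.vllm import AsyncVirtualLM


async def main():
    # Build an AsyncVirtualLM on the vLLM V1 engine (model and engine options are assumptions).
    llm = AsyncVirtualLM.from_name(
        "gpt2",  # illustrative model name
        v1=True,
        logprobs_per_request=256,  # number of top logprobs returned per request in V1
        engine_opts={"gpu_memory_utilization": 0.2, "dtype": "float16"},
    )
    token_ids = llm.tokenizer.encode("The capital of France is")
    logprobs = await llm.next_token_logprobs(token_ids)
    # With V1, only the top `logprobs_per_request` entries are finite; the rest are -inf
    # and no re-normalization is applied.
    finite = logprobs[logprobs != -float("inf")]
    print(logprobs.shape, finite.numel())


if __name__ == "__main__":
    asyncio.run(main())

Equivalently, the test fixture above constructs the same configuration via `load_model_by_name(model_name, backend="vllm", llm_opts={"v1": True, "engine_opts": {...}})`.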