diff --git a/atroposlib/envs/server_handling/trl_vllm_server.py b/atroposlib/envs/server_handling/trl_vllm_server.py
index 81917713b..53e33624b 100644
--- a/atroposlib/envs/server_handling/trl_vllm_server.py
+++ b/atroposlib/envs/server_handling/trl_vllm_server.py
@@ -1,9 +1,18 @@
 """
 This is a server that interfaces with trl's vLLM server.
 
+TRL's vLLM server is started via `trl vllm-serve --model <model_name>` and provides
+a `/generate/` endpoint for text generation. This server handler adapts Atropos's
+API server interface to work with TRL's server.
+
 Developed with much help from @winglian when they worked on integrating Atropos
 into Axolotl.
+
+Limitations:
+  - Token-level logprobs are not available through TRL's server API.
+    If you need logprobs, use VLLMServer with a native vLLM server instead.
 """
+import asyncio
 import time
 import uuid
 
@@ -21,7 +30,17 @@
 
 class TrlVllmServer(APIServer):
     """
-    A server that interfaces with trl's vLLM server.
+    A server that interfaces with TRL's vLLM server.
+
+    TRL (Transformer Reinforcement Learning) provides a vLLM server via the
+    `trl vllm-serve` command. This class adapts that server's API to the
+    Atropos APIServer interface.
+
+    Note:
+        This server does NOT support token-level logprobs. The TRL server's
+        `/generate/` endpoint returns completion token IDs but not their
+        associated logprobs. If you need logprobs for training (e.g., for
+        PPO or GRPO), use `VLLMServer` with a native vLLM server instead.
     """
 
     def __init__(self, config: APIServerConfig):
@@ -31,18 +50,56 @@ def __init__(self, config: APIServerConfig):
 
     async def check_server_status_task(self, chat_completion: bool = True):
         """
-        TODO: Implement server health check for trl's vLLM server
+        Periodically check the health of the TRL vLLM server.
+
+        This method runs in a loop, checking server availability every second.
+        It attempts to make a lightweight request to the server's generate
+        endpoint to verify it's responsive.
+
+        Args:
+            chat_completion: Unused parameter, kept for API compatibility.
         """
-        self.server_healthy = True
+        while True:
+            try:
+                async with aiohttp.ClientSession() as session:
+                    # Try to reach the generate endpoint with minimal request
+                    # TRL's server doesn't have a dedicated /health endpoint,
+                    # so we check if the generate endpoint is responsive
+                    async with session.post(
+                        f"{self.config.base_url}/generate/",
+                        json={
+                            "prompts": [""],
+                            "n": 1,
+                            "max_tokens": 1,
+                        },
+                        timeout=aiohttp.ClientTimeout(total=10),
+                    ) as response:
+                        # Any response (even error) means server is reachable
+                        # A 200 means it's fully operational
+                        if response.status < 500:
+                            self.server_healthy = True
+                        else:
+                            self.server_healthy = False
+            except (
+                aiohttp.ClientError,
+                asyncio.TimeoutError,
+                Exception,
+            ):
+                self.server_healthy = False
+            await asyncio.sleep(1)
 
     async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
         """
-        Wrapper for the chat completion using the trl's vLLM server.
+        Wrapper for chat completion using TRL's vLLM server.
+
+        Converts chat messages to a prompt using the tokenizer's chat template,
+        sends the request to TRL's /generate/ endpoint, and returns an
+        OpenAI-compatible ChatCompletion object.
         """
         url = f"{self.config.base_url}/generate/"
-        prompt = kwargs.get("messages", [])
+        messages = kwargs.get("messages", [])
         prompt = self.tokenizer.apply_chat_template(
-            prompt, tokenize=False, add_generation_prompt=True
+            messages, tokenize=False, add_generation_prompt=True
         )
         async with aiohttp.ClientSession() as session:
             async with session.post(
@@ -57,9 +114,12 @@ async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
                     "min_p": kwargs.get("min_p", 0.0),
                     "max_tokens": kwargs.get("max_tokens", 1024),
                 },
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
             ) as response:
-                completions = await response.json()
-                completions = ChatCompletion(
+                response.raise_for_status()
+                result = await response.json()
+
+                completion = ChatCompletion(
                     id=str(uuid.uuid4()),
                     object="chat.completion",
                     created=int(time.time()),
@@ -68,23 +128,28 @@ async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
                         Choice(
                             finish_reason=(
                                 "stop"
-                                if self.tokenizer.eos_token_id in completion
+                                if self.tokenizer.eos_token_id in completion_ids
                                 else "length"
                             ),
                             index=i,
                             message=ChatCompletionMessage(
-                                content=self.tokenizer.decode(completion),
+                                content=self.tokenizer.decode(
+                                    completion_ids, skip_special_tokens=True
+                                ),
                                 role="assistant",
                             ),
                         )
-                        for i, completion in enumerate(completions["completion_ids"])
+                        for i, completion_ids in enumerate(result["completion_ids"])
                     ],
                 )
-                return completions
+                return completion
 
     async def _completion_wrapper(self, **kwargs) -> Completion:
         """
-        Wrapper for the completion using the trl's vLLM server.
+        Wrapper for text completion using TRL's vLLM server.
+
+        Sends a prompt to TRL's /generate/ endpoint and returns an
+        OpenAI-compatible Completion object.
         """
         url = f"{self.config.base_url}/generate/"
         prompt = kwargs.get("prompt", "")
@@ -101,9 +166,12 @@ async def _completion_wrapper(self, **kwargs) -> Completion:
                     "min_p": kwargs.get("min_p", 0.0),
                     "max_tokens": kwargs.get("max_tokens", 1024),
                 },
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
-                completions = await response.json()
-                completions = Completion(
+                response.raise_for_status()
+                result = await response.json()
+
+                completion = Completion(
                     id=str(uuid.uuid4()),
                     object="text_completion",
                     created=int(time.time()),
@@ -112,21 +180,42 @@ async def _completion_wrapper(self, **kwargs) -> Completion:
                         CompletionChoice(
                             finish_reason=(
                                 "stop"
-                                if self.tokenizer.eos_token_id in completion
+                                if self.tokenizer.eos_token_id in completion_ids
                                 else "length"
                             ),
                             index=i,
-                            text=self.tokenizer.decode(completion),
+                            text=self.tokenizer.decode(
+                                completion_ids, skip_special_tokens=True
+                            ),
                         )
-                        for i, completion in enumerate(completions["completion_ids"])
+                        for i, completion_ids in enumerate(result["completion_ids"])
                     ],
                 )
-                return completions
+                return completion
 
     async def _tokens_and_logprobs_completion_wrapper(
         self, **kwargs
     ) -> tuple[list, list, list, list]:
         """
-        Wrapper for the tokens and logprobs completion using the openai client.
+        Token-level logprobs completion - NOT SUPPORTED by TRL's vLLM server.
+
+        TRL's vLLM server (started via `trl vllm-serve`) does not expose
+        token-level logprobs through its `/generate/` endpoint. The server
+        returns only the generated token IDs, not their associated probabilities.
+
+        If you need token-level logprobs for training algorithms like PPO or GRPO,
+        use one of these alternatives:
+          - `VLLMServer`: Direct interface to native vLLM server with full
+            logprob support via the `/generate` endpoint.
+          - `SGLangServer`: Interface to SGLang server with logprob support.
+
+        Raises:
+            NotImplementedError: Always raised as this functionality is not
+                available through TRL's server API.
         """
-        raise NotImplementedError("Not implemented for trl's vLLM server yet.")
+        raise NotImplementedError(
+            "Token-level logprobs are not supported by TRL's vLLM server. "
+            "TRL's /generate/ endpoint returns only completion token IDs, "
+            "not their associated logprobs. If you need logprobs for training, "
+            "use VLLMServer or SGLangServer with a native vLLM/SGLang server instead."
+        )
diff --git a/atroposlib/tests/test_trl_vllm_server.py b/atroposlib/tests/test_trl_vllm_server.py
new file mode 100644
index 000000000..a3748629a
--- /dev/null
+++ b/atroposlib/tests/test_trl_vllm_server.py
@@ -0,0 +1,138 @@
+"""Tests for TrlVllmServer implementation."""
+
+from unittest.mock import patch
+
+import pytest
+
+from atroposlib.envs.server_handling.server_baseline import APIServerConfig
+from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
+
+
+class MockTokenizer:
+    """Mock tokenizer for testing."""
+
+    def __init__(self):
+        self.eos_token_id = 2
+        self.bos_token_id = 1
+
+    def encode(self, text, add_special_tokens=True):
+        """Simple character-based encoding for testing."""
+        tokens = [ord(c) for c in text]
+        if add_special_tokens:
+            tokens = [self.bos_token_id] + tokens
+        return tokens
+
+    def decode(self, tokens, skip_special_tokens=False):
+        """Simple character-based decoding for testing."""
+        if skip_special_tokens:
+            tokens = [
+                t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
+            ]
+        return "".join([chr(t) if 31 < t < 127 else "" for t in tokens])
+
+    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+        """Simple chat template for testing."""
+        result = ""
+        for msg in messages:
+            result += f"<{msg['role']}>{msg['content']}"
+        if add_generation_prompt:
+            result += "<assistant>"
+        if tokenize:
+            return self.encode(result)
+        return result
+
+
+@pytest.fixture
+def mock_config():
+    """Create a mock server config."""
+    return APIServerConfig(
+        api_key="test-key",
+        base_url="http://localhost:8000",
+        model_name="test-model",
+        timeout=30,
+    )
+
+
+@pytest.fixture
+def mock_server(mock_config):
+    """Create a TrlVllmServer with mocked tokenizer."""
+    with patch.object(TrlVllmServer, "__init__", lambda self, config: None):
+        server = TrlVllmServer.__new__(TrlVllmServer)
+        server.config = mock_config
+        server.tokenizer = MockTokenizer()
+        server.server_healthy = False
+        return server
+
+
+class TestTrlVllmServer:
+    """Tests for TrlVllmServer."""
+
+    def test_init_creates_tokenizer(self, mock_config):
+        """Test that __init__ creates tokenizer from config."""
+        with patch(
+            "atroposlib.envs.server_handling.trl_vllm_server.AutoTokenizer"
+        ) as mock_auto:
+            mock_auto.from_pretrained.return_value = MockTokenizer()
+            with patch.object(
+                TrlVllmServer.__bases__[0], "__init__", return_value=None
+            ):
+                TrlVllmServer(mock_config)
+                mock_auto.from_pretrained.assert_called_once_with(
+                    mock_config.model_name
+                )
+
+    @pytest.mark.asyncio
+    async def test_tokens_and_logprobs_raises_not_implemented(self, mock_server):
+        """Test tokens_and_logprobs method raises NotImplementedError with message."""
+        with pytest.raises(NotImplementedError) as exc_info:
+            await mock_server._tokens_and_logprobs_completion_wrapper()
+
+        error_msg = str(exc_info.value)
+        assert "Token-level logprobs are not supported" in error_msg
+        assert "VLLMServer" in error_msg or "SGLangServer" in error_msg
+
+    def test_completion_includes_text_attribute(self, mock_server):
+        """Verify that completion response has text attribute (addresses issue #183)."""
+        from openai.types.completion import CompletionChoice
+
+        # Create a CompletionChoice like the _completion_wrapper does
+        choice = CompletionChoice(
+            finish_reason="stop",
+            index=0,
+            text="Hello World",
+        )
+
+        assert hasattr(choice, "text")
+        assert choice.text == "Hello World"
+
+    def test_chat_completion_has_message_not_text(self, mock_server):
+        """Verify chat completion uses message attribute, not text (issue #183)."""
+        from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+
+        choice = Choice(
+            finish_reason="stop",
+            index=0,
+            message=ChatCompletionMessage(
+                content="Hello",
+                role="assistant",
+            ),
+        )
+
+        assert hasattr(choice, "message")
+        assert hasattr(choice.message, "content")
+        assert choice.message.content == "Hello"
+
+    def test_server_has_required_methods(self, mock_server):
+        """Verify that TrlVllmServer has all required methods."""
+        assert hasattr(mock_server, "check_server_status_task")
+        assert hasattr(mock_server, "_chat_completion_wrapper")
+        assert hasattr(mock_server, "_completion_wrapper")
+        assert hasattr(mock_server, "_tokens_and_logprobs_completion_wrapper")
+        assert callable(mock_server.check_server_status_task)
+        assert callable(mock_server._chat_completion_wrapper)
+        assert callable(mock_server._completion_wrapper)
+        assert callable(mock_server._tokens_and_logprobs_completion_wrapper)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
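
For reference, below is a minimal usage sketch for the handler added in this diff. It is illustrative only and not part of the change: the model name, port, and the direct call to the private `_chat_completion_wrapper` are assumptions made for the example (the public request path goes through the `APIServer` base class, which is not shown here), and it assumes a TRL server is already running, e.g. `trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --port 8000`.

# Usage sketch (assumed model name/port; calls the private wrapper directly
# only to illustrate the parameters TrlVllmServer forwards to TRL's /generate/).
import asyncio

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer


async def main() -> None:
    config = APIServerConfig(
        api_key="unused",  # placeholder; requests go straight to TRL's /generate/
        base_url="http://localhost:8000",  # where `trl vllm-serve` is listening
        model_name="Qwen/Qwen2.5-0.5B-Instruct",  # should match the served model
        timeout=120,
    )
    server = TrlVllmServer(config)

    # The wrapper applies the model's chat template to the messages and posts
    # the resulting prompt to TRL's /generate/ endpoint.
    completion = await server._chat_completion_wrapper(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        n=1,
        temperature=0.7,
        max_tokens=64,
    )
    print(completion.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())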