Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 110 additions & 21 deletions atroposlib/envs/server_handling/trl_vllm_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
"""
This is a server that interfaces with trl's vLLM server.

TRL's vLLM server is started via `trl vllm-serve --model <model_name>` and provides
a `/generate/` endpoint for text generation. This server handler adapts Atropos's
API server interface to work with TRL's server.

Developed with much help from @winglian when they worked on integrating Atropos into Axolotl.

Limitations:
- Token-level logprobs are not available through TRL's server API.
If you need logprobs, use VLLMServer with a native vLLM server instead.
"""

import asyncio
import time
import uuid

Expand All @@ -21,7 +30,17 @@

class TrlVllmServer(APIServer):
"""
A server that interfaces with trl's vLLM server.
A server that interfaces with TRL's vLLM server.

TRL (Transformer Reinforcement Learning) provides a vLLM server via the
`trl vllm-serve` command. This class adapts that server's API to the
Atropos APIServer interface.

Note:
This server does NOT support token-level logprobs. The TRL server's
`/generate/` endpoint returns completion token IDs but not their
associated logprobs. If you need logprobs for training (e.g., for
PPO or GRPO), use `VLLMServer` with a native vLLM server instead.
"""

def __init__(self, config: APIServerConfig):
Expand All @@ -31,18 +50,56 @@ def __init__(self, config: APIServerConfig):

async def check_server_status_task(self, chat_completion: bool = True):
"""
TODO: Implement server health check for trl's vLLM server
Periodically check the health of the TRL vLLM server.

This method runs in a loop, checking server availability every second.
It attempts to make a lightweight request to the server's generate
endpoint to verify it's responsive.

Args:
chat_completion: Unused parameter, kept for API compatibility.
"""
self.server_healthy = True
while True:
try:
async with aiohttp.ClientSession() as session:
# Try to reach the generate endpoint with minimal request
# TRL's server doesn't have a dedicated /health endpoint,
# so we check if the generate endpoint is responsive
async with session.post(
f"{self.config.base_url}/generate/",
json={
"prompts": [""],
"n": 1,
"max_tokens": 1,
},
timeout=aiohttp.ClientTimeout(total=10),
) as response:
# Any response (even error) means server is reachable
# A 200 means it's fully operational
if response.status < 500:
self.server_healthy = True
else:
self.server_healthy = False
except (
aiohttp.ClientError,
asyncio.TimeoutError,
Exception,
):
self.server_healthy = False
await asyncio.sleep(1)

async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
Wrapper for the chat completion using the trl's vLLM server.
Wrapper for chat completion using TRL's vLLM server.

Converts chat messages to a prompt using the tokenizer's chat template,
sends the request to TRL's /generate/ endpoint, and returns an
OpenAI-compatible ChatCompletion object.
"""
url = f"{self.config.base_url}/generate/"
prompt = kwargs.get("messages", [])
messages = kwargs.get("messages", [])
prompt = self.tokenizer.apply_chat_template(
prompt, tokenize=False, add_generation_prompt=True
messages, tokenize=False, add_generation_prompt=True
)
async with aiohttp.ClientSession() as session:
async with session.post(
Expand All @@ -57,9 +114,12 @@ async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"min_p": kwargs.get("min_p", 0.0),
"max_tokens": kwargs.get("max_tokens", 1024),
},
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
completions = await response.json()
completions = ChatCompletion(
response.raise_for_status()
result = await response.json()

completion = ChatCompletion(
id=str(uuid.uuid4()),
object="chat.completion",
created=int(time.time()),
Expand All @@ -68,23 +128,28 @@ async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
Choice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
if self.tokenizer.eos_token_id in completion_ids
else "length"
),
index=i,
message=ChatCompletionMessage(
content=self.tokenizer.decode(completion),
content=self.tokenizer.decode(
completion_ids, skip_special_tokens=True
),
role="assistant",
),
)
for i, completion in enumerate(completions["completion_ids"])
for i, completion_ids in enumerate(result["completion_ids"])
],
)
return completions
return completion

async def _completion_wrapper(self, **kwargs) -> Completion:
"""
Wrapper for the completion using the trl's vLLM server.
Wrapper for text completion using TRL's vLLM server.

Sends a prompt to TRL's /generate/ endpoint and returns an
OpenAI-compatible Completion object.
"""
url = f"{self.config.base_url}/generate/"
prompt = kwargs.get("prompt", "")
Expand All @@ -101,9 +166,12 @@ async def _completion_wrapper(self, **kwargs) -> Completion:
"min_p": kwargs.get("min_p", 0.0),
"max_tokens": kwargs.get("max_tokens", 1024),
},
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
completions = await response.json()
completions = Completion(
response.raise_for_status()
result = await response.json()

completion = Completion(
id=str(uuid.uuid4()),
object="text_completion",
created=int(time.time()),
Expand All @@ -112,21 +180,42 @@ async def _completion_wrapper(self, **kwargs) -> Completion:
CompletionChoice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
if self.tokenizer.eos_token_id in completion_ids
else "length"
),
index=i,
text=self.tokenizer.decode(completion),
text=self.tokenizer.decode(
completion_ids, skip_special_tokens=True
),
)
for i, completion in enumerate(completions["completion_ids"])
for i, completion_ids in enumerate(result["completion_ids"])
],
)
return completions
return completion

async def _tokens_and_logprobs_completion_wrapper(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Wrapper for the tokens and logprobs completion using the openai client.
Token-level logprobs completion - NOT SUPPORTED by TRL's vLLM server.

TRL's vLLM server (started via `trl vllm-serve`) does not expose
token-level logprobs through its `/generate/` endpoint. The server
returns only the generated token IDs, not their associated probabilities.

If you need token-level logprobs for training algorithms like PPO or GRPO,
use one of these alternatives:
- `VLLMServer`: Direct interface to native vLLM server with full
logprob support via the `/generate` endpoint.
- `SGLangServer`: Interface to SGLang server with logprob support.

Raises:
NotImplementedError: Always raised as this functionality is not
available through TRL's server API.
"""
raise NotImplementedError("Not implemented for trl's vLLM server yet.")
raise NotImplementedError(
"Token-level logprobs are not supported by TRL's vLLM server. "
"TRL's /generate/ endpoint returns only completion token IDs, "
"not their associated logprobs. If you need logprobs for training, "
"use VLLMServer or SGLangServer with a native vLLM/SGLang server instead."
)
138 changes: 138 additions & 0 deletions atroposlib/tests/test_trl_vllm_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for TrlVllmServer implementation."""

from unittest.mock import patch

import pytest

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer


class MockTokenizer:
    """Minimal stand-in for a HuggingFace tokenizer, used in tests.

    Encodes text as raw character ordinals, with token id 1 acting as BOS
    and token id 2 acting as EOS.
    """

    def __init__(self):
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        """Encode *text* one character at a time, optionally prepending BOS."""
        encoded = list(map(ord, text))
        if not add_special_tokens:
            return encoded
        return [self.bos_token_id, *encoded]

    def decode(self, tokens, skip_special_tokens=False):
        """Decode token ids back to text; non-printable ids become empty strings."""
        special_ids = {self.bos_token_id, self.eos_token_id}
        chars = []
        for tok in tokens:
            if skip_special_tokens and tok in special_ids:
                continue
            # Only printable ASCII (32..126) round-trips; everything else is dropped.
            chars.append(chr(tok) if 31 < tok < 127 else "")
        return "".join(chars)

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        """Render messages as ``<role>content</role>`` tags for testing."""
        parts = [f"<{msg['role']}>{msg['content']}</{msg['role']}>" for msg in messages]
        if add_generation_prompt:
            parts.append("<assistant>")
        rendered = "".join(parts)
        return self.encode(rendered) if tokenize else rendered


@pytest.fixture
def mock_config():
    """Build an APIServerConfig pointed at a local test endpoint."""
    config = APIServerConfig(
        api_key="test-key",
        base_url="http://localhost:8000",
        model_name="test-model",
        timeout=30,
    )
    return config


@pytest.fixture
def mock_server(mock_config):
    """Build a TrlVllmServer without running its real __init__.

    The instance is created via ``__new__`` (with ``__init__`` patched to a
    no-op for safety) and populated by hand with a mock tokenizer and config.
    """
    with patch.object(TrlVllmServer, "__init__", lambda self, config: None):
        instance = TrlVllmServer.__new__(TrlVllmServer)
    instance.config = mock_config
    instance.tokenizer = MockTokenizer()
    instance.server_healthy = False
    return instance


class TestTrlVllmServer:
    """Unit tests for TrlVllmServer."""

    def test_init_creates_tokenizer(self, mock_config):
        """__init__ must load the tokenizer named in the config."""
        auto_tokenizer_patch = patch(
            "atroposlib.envs.server_handling.trl_vllm_server.AutoTokenizer"
        )
        with auto_tokenizer_patch as mock_auto:
            mock_auto.from_pretrained.return_value = MockTokenizer()
            parent_cls = TrlVllmServer.__bases__[0]
            # Stub out the APIServer base __init__ so no real I/O happens.
            with patch.object(parent_cls, "__init__", return_value=None):
                TrlVllmServer(mock_config)
            mock_auto.from_pretrained.assert_called_once_with(
                mock_config.model_name
            )

    @pytest.mark.asyncio
    async def test_tokens_and_logprobs_raises_not_implemented(self, mock_server):
        """The logprobs wrapper must raise NotImplementedError with guidance."""
        with pytest.raises(NotImplementedError) as raised:
            await mock_server._tokens_and_logprobs_completion_wrapper()

        message = str(raised.value)
        assert "Token-level logprobs are not supported" in message
        assert ("VLLMServer" in message) or ("SGLangServer" in message)

    def test_completion_includes_text_attribute(self, mock_server):
        """Completion choices expose a ``text`` attribute (issue #183)."""
        from openai.types.completion import CompletionChoice

        # Mirror the object built inside _completion_wrapper.
        built = CompletionChoice(
            finish_reason="stop",
            index=0,
            text="Hello World",
        )

        assert hasattr(built, "text")
        assert built.text == "Hello World"

    def test_chat_completion_has_message_not_text(self, mock_server):
        """Chat choices carry ``message``, not ``text`` (issue #183)."""
        from openai.types.chat.chat_completion import ChatCompletionMessage, Choice

        built = Choice(
            finish_reason="stop",
            index=0,
            message=ChatCompletionMessage(
                content="Hello",
                role="assistant",
            ),
        )

        assert hasattr(built, "message")
        assert hasattr(built.message, "content")
        assert built.message.content == "Hello"

    def test_server_has_required_methods(self, mock_server):
        """The server exposes every method the APIServer contract needs."""
        required = (
            "check_server_status_task",
            "_chat_completion_wrapper",
            "_completion_wrapper",
            "_tokens_and_logprobs_completion_wrapper",
        )
        for method_name in required:
            method = getattr(mock_server, method_name, None)
            assert method is not None
            assert callable(method)


# Allow running this test module directly (python test_trl_vllm_server.py)
# in addition to the usual `pytest` invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])