From 1bb444312f98eabc84e3b186cfc529f387b13d6b Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:45:28 -0500 Subject: [PATCH 01/15] Initial scaffold of mock OpenAI-compatible server --- .../benchmark/mock_llm_server/__init__.py | 14 + .../mock_llm_server/example_usage.py | 206 +++++++ .../mock_llm_server/mock_llm_server.py | 406 +++++++++++++ .../benchmark/mock_llm_server/run_server.py | 79 +++ tests/benchmark/test_mock_llm_server.py | 531 ++++++++++++++++++ 5 files changed, 1236 insertions(+) create mode 100644 nemoguardrails/benchmark/mock_llm_server/__init__.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/example_usage.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/run_server.py create mode 100644 tests/benchmark/test_mock_llm_server.py diff --git a/nemoguardrails/benchmark/mock_llm_server/__init__.py b/nemoguardrails/benchmark/mock_llm_server/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/benchmark/mock_llm_server/example_usage.py b/nemoguardrails/benchmark/mock_llm_server/example_usage.py new file mode 100644 index 000000000..278ab8d94 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/example_usage.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage of the Mock LLM Server. + +This script demonstrates how to interact with the running mock server +using standard HTTP requests and the OpenAI Python client. 
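As a rough sketch of the interaction (assuming the server was started with the defaults in run_server.py, i.e. http://localhost:8000), a single chat completion request looks like:

    import requests

    response = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello"}],
        },
        timeout=5,
    )
    print(response.json()["choices"][0]["message"]["content"])

The functions below wrap this pattern and add error handling, an OpenAI-client variant, and a small benchmark.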
+""" + +import json +import time + +import requests + + +def test_with_requests(): + """Test the server using the requests library.""" + base_url = "http://localhost:8000" + + print("Testing Mock LLM Server with requests library...") + print("=" * 50) + + # Test health endpoint + try: + response = requests.get(f"{base_url}/health", timeout=5) + print(f"Health check: {response.status_code} - {response.json()}") + except requests.RequestException as e: + print(f"Health check failed: {e}") + print("Make sure the server is running: python run_server.py") + return + + # Test models endpoint + try: + response = requests.get(f"{base_url}/v1/models", timeout=5) + print(f"\\nModels: {response.status_code}") + models_data = response.json() + for model in models_data["data"]: + print(f" - {model['id']}") + except requests.RequestException as e: + print(f"Models request failed: {e}") + + # Test chat completion + try: + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + } + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nChat completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['message']['content']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Chat completion failed: {e}") + + # Test text completion + try: + completion_payload = { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + } + response = requests.post( + f"{base_url}/v1/completions", + json=completion_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nText completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['text']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Text completion failed: {e}") + + +def test_with_openai_client(): + """Test the server using the OpenAI Python client.""" + try: + import openai + except ImportError: + print("\\nOpenAI client not available. 
Install with: pip install openai") + return + + print("\\n" + "=" * 50) + print("Testing with OpenAI client library...") + print("=" * 50) + + # Configure client to use local server + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="dummy-key", # Server doesn't validate, but client requires it + ) + + try: + # Test chat completion + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello from OpenAI client!"}], + ) + print(f"Chat completion response: {response.choices[0].message.content}") + print( + f"Usage: prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}" + ) + + # Test text completion (if supported by client version) + try: + response = client.completions.create( + model="text-davinci-003", prompt="OpenAI client test: ", max_tokens=10 + ) + print(f"Text completion response: {response.choices[0].text}") + except Exception as e: + print(f"Text completion not supported in this OpenAI client version: {e}") + + except Exception as e: + print(f"OpenAI client test failed: {e}") + + +def benchmark_performance(): + """Simple performance benchmark.""" + print("\\n" + "=" * 50) + print("Performance Benchmark") + print("=" * 50) + + base_url = "http://localhost:8000" + num_requests = 10 + + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Benchmark test"}], + "max_tokens": 20, + } + + print(f"Making {num_requests} chat completion requests...") + + start_time = time.time() + successful_requests = 0 + + for i in range(num_requests): + try: + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + successful_requests += 1 + except requests.RequestException: + pass + + end_time = time.time() + total_time = end_time - start_time + + print(f"Results:") + print(f" Total requests: {num_requests}") + print(f" Successful requests: {successful_requests}") + print(f" Total time: {total_time:.2f} seconds") + print(f" Average time per request: {total_time/num_requests:.3f} seconds") + print(f" Requests per second: {num_requests/total_time:.2f}") + + +def main(): + """Main function to run all tests.""" + print("Mock LLM Server Example Usage") + print("=" * 50) + print("Make sure the server is running before running this script:") + print(" python run_server.py") + print() + + # Test with requests + test_with_requests() + + # Test with OpenAI client + test_with_openai_client() + + # Simple benchmark + benchmark_performance() + + print("\\nExample completed!") + + +if __name__ == "__main__": + main() diff --git a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py new file mode 100644 index 000000000..28725e724 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mock LLM FastAPI Server with OpenAI-compatible interface. + +This server provides dummy implementations of OpenAI API endpoints for testing +and benchmarking purposes. +""" + +import time +import uuid +from typing import Any, Dict, List, Optional, Union + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field + +app = FastAPI( + title="Mock LLM Server", + description="OpenAI-compatible mock LLM server for testing and benchmarking", + version="1.0.0", +) + + +# Pydantic Models for Request/Response validation + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + name: Optional[str] = Field(None, description="The name of the author") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: List[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, List[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[Dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, List[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: 
Optional[Union[str, List[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[Dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[Dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: List[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: List[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: List[Model] = Field(..., description="List of available models") + + +# Dummy data and helper functions + +DUMMY_MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { 
+ "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "openai", + }, + { + "id": "text-davinci-003", + "object": "model", + "created": 1669599635, + "owned_by": "openai", + }, +] + +DUMMY_CHAT_RESPONSES = [ + "This is a mock response from the LLM server.", + "I'm a dummy AI assistant created for testing purposes.", + "This response is generated by a mock OpenAI-compatible server.", + "Hello! I'm responding with dummy data for benchmarking.", + "This is a simulated conversation response for testing.", +] + +DUMMY_COMPLETION_RESPONSES = [ + " This is a dummy text completion.", + " Here's some mock generated text.", + " This is a sample completion response.", + " Mock completion text for testing purposes.", + " Dummy text generated by the mock server.", +] + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_dummy_chat_response() -> str: + """Get a dummy chat response.""" + import random + + return random.choice(DUMMY_CHAT_RESPONSES) + + +def get_dummy_completion_response() -> str: + """Get a dummy completion response.""" + import random + + return random.choice(DUMMY_COMPLETION_RESPONSES) + + +# API Endpoints + + +@app.get("/") +async def root(): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "1.0.0", + "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(): + """List available models.""" + return ModelsResponse( + object="list", data=[Model(**model) for model in DUMMY_MODELS] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions(request: ChatCompletionRequest): + """Create a chat completion.""" + # Validate model exists + available_models = [model["id"] for model in DUMMY_MODELS] + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. 
Available models: {available_models}", + ) + + # Generate dummy response + response_content = get_dummy_chat_response() + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content, name=None), + finish_reason="stop", + ) + choices.append(choice) + + return ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions(request: CompletionRequest): + """Create a text completion.""" + # Validate model exists + available_models = [model["id"] for model in DUMMY_MODELS] + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. Available models: {available_models}", + ) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_dummy_completion_response() + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + return CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "timestamp": int(time.time())} + + +if __name__ == "__main__": + uvicorn.run( + "mock_llm_server:app", host="0.0.0.0", port=8000, reload=True, log_level="info" + ) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py new file mode 100644 index 000000000..66f281932 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Startup script for the Mock LLM Server. + +This script starts the FastAPI server with configurable host and port settings. +""" + +import argparse +import os +import sys + +import uvicorn + +# Add the current directory to Python path to import the server module +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + + +def main(): + parser = argparse.ArgumentParser(description="Run the Mock LLM Server") + parser.add_argument( + "--host", + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + parser.add_argument( + "--reload", action="store_true", help="Enable auto-reload for development" + ) + parser.add_argument( + "--log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Log level (default: info)", + ) + + args = parser.parse_args() + + print(f"Starting Mock LLM Server on {args.host}:{args.port}") + print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") + print(f"Health check at: http://{args.host}:{args.port}/health") + print("Press Ctrl+C to stop the server") + + try: + uvicorn.run( + "mock_llm_server:app", + host=args.host, + port=args.port, + reload=args.reload, + log_level=args.log_level, + ) + except KeyboardInterrupt: + print("\nServer stopped by user") + except Exception as e: # pylint: disable=broad-except + print(f"Error starting server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py new file mode 100644 index 000000000..b74f9633d --- /dev/null +++ b/tests/benchmark/test_mock_llm_server.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the Mock LLM FastAPI Server. + +This module contains comprehensive tests for all endpoints and edge cases +of the OpenAI-compatible mock LLM server. 
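The tests drive the FastAPI app in-process through Starlette's TestClient, so no server needs to be running. A minimal version of that pattern (import path assumed to match this patch's layout) is:

    from fastapi.testclient import TestClient

    from mock_llm_server.mock_llm_server import app

    client = TestClient(app)
    assert client.get("/health").json()["status"] == "healthy"

The fixtures and test cases below build on the same client to cover each endpoint and its validation rules.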
+""" + +import json +import time +from typing import Any, Dict, List +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +# Import the server and its components +from mock_llm_server.mock_llm_server import ( + DUMMY_MODELS, + app, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + + +class TestMockLLMServer: + """Test class for the Mock LLM Server.""" + + @pytest.fixture + def client(self): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + @pytest.fixture + def valid_chat_request(self): + """Sample valid chat completion request.""" + return { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "temperature": 0.7, + } + + @pytest.fixture + def valid_completion_request(self): + """Sample valid text completion request.""" + return { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + "temperature": 0.8, + } + + # Root endpoint tests + def test_root_endpoint(self, client): + """Test the root endpoint returns correct information.""" + response = client.get("/") + assert response.status_code == 200 + + data = response.json() + assert data["message"] == "Mock LLM Server" + assert data["version"] == "1.0.0" + assert "description" in data + assert "/v1/models" in data["endpoints"] + assert "/v1/chat/completions" in data["endpoints"] + assert "/v1/completions" in data["endpoints"] + + # Health check tests + def test_health_check(self, client): + """Test the health check endpoint.""" + response = client.get("/health") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], int) + + # Models endpoint tests + def test_list_models(self, client): + """Test the models listing endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "list" + assert isinstance(data["data"], list) + assert len(data["data"]) == len(DUMMY_MODELS) + + # Check first model structure + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + + def test_models_contain_expected_models(self, client): + """Test that all expected models are returned.""" + response = client.get("/v1/models") + data = response.json() + + model_ids = [model["id"] for model in data["data"]] + expected_ids = [model["id"] for model in DUMMY_MODELS] + + assert set(model_ids) == set(expected_ids) + + # Chat completions tests + def test_chat_completions_success(self, client, valid_chat_request): + """Test successful chat completion request.""" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == valid_chat_request["model"] + assert "id" in data + assert "created" in data + assert isinstance(data["created"], int) + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "message" in choice + assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + + # Check usage + assert "usage" in data + usage = 
data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_chat_completions_multiple_choices(self, client, valid_chat_request): + """Test chat completion with multiple choices.""" + valid_chat_request["n"] = 3 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 3 + + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + assert choice["finish_reason"] == "stop" + + def test_chat_completions_invalid_model(self, client, valid_chat_request): + """Test chat completion with invalid model.""" + valid_chat_request["model"] = "invalid-model" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 400 + + data = response.json() + assert "detail" in data + assert "invalid-model" in data["detail"] + assert "not found" in data["detail"] + + def test_chat_completions_empty_messages(self, client): + """Test chat completion with empty messages.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [], + } + response = client.post("/v1/chat/completions", json=request_data) + # Note: The server currently accepts empty messages and processes them + # This may be acceptable behavior for a mock server + assert response.status_code in [ + 200, + 422, + ] # Allow both success and validation error + + def test_chat_completions_invalid_message_format(self, client): + """Test chat completion with invalid message format.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"invalid": "format"}], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 422 # Validation error + + def test_chat_completions_parameter_validation(self, client, valid_chat_request): + """Test parameter validation for chat completions.""" + # Test max_tokens validation + valid_chat_request["max_tokens"] = 0 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + # Test temperature validation + valid_chat_request["max_tokens"] = 50 + valid_chat_request["temperature"] = 3.0 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + # Test n validation + valid_chat_request["temperature"] = 0.7 + valid_chat_request["n"] = 200 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + def test_chat_completions_optional_parameters(self, client): + """Test chat completion with various optional parameters.""" + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Test message"}], + "max_tokens": 100, + "temperature": 0.5, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.2, + "stop": ["\\n"], + "user": "test-user", + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + # Text completions tests + def test_completions_success(self, client, valid_completion_request): + """Test successful text completion request.""" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + assert data["model"] == 
valid_completion_request["model"] + assert "id" in data + assert "created" in data + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "text" in choice + assert isinstance(choice["text"], str) + + # Check usage + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + + def test_completions_list_prompt(self, client): + """Test text completion with list prompt.""" + request_data = { + "model": "text-davinci-003", + "prompt": ["First prompt", "Second prompt"], + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + + def test_completions_invalid_model(self, client, valid_completion_request): + """Test text completion with invalid model.""" + valid_completion_request["model"] = "non-existent-model" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 400 + + def test_completions_multiple_choices(self, client, valid_completion_request): + """Test text completion with multiple choices.""" + valid_completion_request["n"] = 2 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 2 + + def test_completions_parameter_validation(self, client, valid_completion_request): + """Test parameter validation for text completions.""" + # Test max_tokens validation + valid_completion_request["max_tokens"] = -1 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + # Test temperature validation + valid_completion_request["max_tokens"] = 10 + valid_completion_request["temperature"] = -1.0 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + def test_completions_optional_parameters(self, client): + """Test text completion with various optional parameters.""" + request_data = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + "max_tokens": 50, + "temperature": 0.8, + "top_p": 0.95, + "n": 1, + "logprobs": 1, + "echo": True, + "stop": ["\\n", "."], + "presence_penalty": -0.5, + "frequency_penalty": 0.3, + "best_of": 2, + "user": "test-user-2", + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + # Helper function tests + def test_generate_id_default(self): + """Test ID generation with default prefix.""" + id1 = generate_id() + id2 = generate_id() + + assert id1.startswith("chatcmpl-") + assert id2.startswith("chatcmpl-") + assert id1 != id2 # Should be unique + assert len(id1) == len("chatcmpl-") + 8 # prefix + 8 hex chars + + def test_generate_id_custom_prefix(self): + """Test ID generation with custom prefix.""" + custom_id = generate_id("cmpl") + assert custom_id.startswith("cmpl-") + assert len(custom_id) == len("cmpl-") + 8 + + def test_calculate_tokens(self): + """Test token calculation function.""" + # Test basic calculation + assert calculate_tokens("") == 1 # Minimum 1 token + assert calculate_tokens("a") == 1 + assert calculate_tokens("abcd") == 1 + assert calculate_tokens("abcde") == 1 # 5 chars = 1 token (rounded down) + assert calculate_tokens("abcdefgh") == 2 # 8 chars = 2 
tokens + + # Test longer text + long_text = "This is a longer text with multiple words and characters." + expected_tokens = max(1, len(long_text) // 4) + assert calculate_tokens(long_text) == expected_tokens + + def test_get_dummy_responses(self): + """Test dummy response generation functions.""" + chat_response = get_dummy_chat_response() + assert isinstance(chat_response, str) + assert len(chat_response) > 0 + + completion_response = get_dummy_completion_response() + assert isinstance(completion_response, str) + assert len(completion_response) > 0 + + # Edge cases and error handling + def test_missing_required_fields_chat(self, client): + """Test chat completion with missing required fields.""" + # Missing model + response = client.post("/v1/chat/completions", json={"messages": []}) + assert response.status_code == 422 + + # Missing messages + response = client.post("/v1/chat/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_missing_required_fields_completion(self, client): + """Test text completion with missing required fields.""" + # Missing model + response = client.post("/v1/completions", json={"prompt": "test"}) + assert response.status_code == 422 + + # Missing prompt + response = client.post("/v1/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_invalid_json(self, client): + """Test endpoints with invalid JSON.""" + response = client.post( + "/v1/chat/completions", + content="invalid json", + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + def test_empty_request_body(self, client): + """Test endpoints with empty request body.""" + response = client.post("/v1/chat/completions", json={}) + assert response.status_code == 422 + + response = client.post("/v1/completions", json={}) + assert response.status_code == 422 + + # Content validation tests + def test_chat_message_content_types(self, client): + """Test chat completion with different message content types.""" + # Test with multiple messages + request_data = { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + def test_response_structure_consistency(self, client, valid_chat_request): + """Test that response structure is consistent across calls.""" + response1 = client.post("/v1/chat/completions", json=valid_chat_request) + response2 = client.post("/v1/chat/completions", json=valid_chat_request) + + assert response1.status_code == 200 + assert response2.status_code == 200 + + data1 = response1.json() + data2 = response2.json() + + # Structure should be the same + assert set(data1.keys()) == set(data2.keys()) + assert data1["object"] == data2["object"] + assert data1["model"] == data2["model"] + + # IDs should be different + assert data1["id"] != data2["id"] + + def test_concurrent_requests(self, client, valid_chat_request): + """Test handling of concurrent requests.""" + import threading + import time + + results = [] + + def make_request(): + response = client.post("/v1/chat/completions", json=valid_chat_request) + results.append(response.status_code) + + # Create multiple threads + threads = [] + for _ in range(5): + thread = threading.Thread(target=make_request) + threads.append(thread) 
+ thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All requests should be successful + assert all(status == 200 for status in results) + assert len(results) == 5 + + # Performance and load tests + def test_response_time_reasonable(self, client, valid_chat_request): + """Test that response times are reasonable.""" + start_time = time.time() + response = client.post("/v1/chat/completions", json=valid_chat_request) + end_time = time.time() + + assert response.status_code == 200 + assert (end_time - start_time) < 1.0 # Should respond within 1 second + + def test_large_prompt_handling(self, client): + """Test handling of large prompts.""" + large_prompt = "A" * 10000 # 10K characters + request_data = { + "model": "text-davinci-003", + "prompt": large_prompt, + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + # Token calculation should handle large text + assert data["usage"]["prompt_tokens"] > 1000 + + # Mock and patch tests + @patch("mock_llm_server.mock_llm_server.get_dummy_chat_response") + def test_chat_response_mocking(self, mock_response, client, valid_chat_request): + """Test mocking of chat response generation.""" + mock_response.return_value = "Mocked response for testing" + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["message"]["content"] == "Mocked response for testing" + mock_response.assert_called_once() + + @patch("time.time") + def test_timestamp_consistency(self, mock_time, client, valid_chat_request): + """Test that timestamps are generated correctly.""" + mock_time.return_value = 1234567890 + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["created"] == 1234567890 + + # Documentation and OpenAPI tests + def test_openapi_docs_available(self, client): + """Test that OpenAPI documentation is available.""" + response = client.get("/docs") + assert response.status_code == 200 + + response = client.get("/openapi.json") + assert response.status_code == 200 + + openapi_data = response.json() + assert "openapi" in openapi_data + assert "paths" in openapi_data + assert "/v1/models" in openapi_data["paths"] + assert "/v1/chat/completions" in openapi_data["paths"] + assert "/v1/completions" in openapi_data["paths"] From d9b73bee71e053b5254a8d0be107ad1dfa375417 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:16:48 -0500 Subject: [PATCH 02/15] Refactor mock LLM, fix tests --- .../benchmark/mock_llm_server/api.py | 173 ++++++++ .../mock_llm_server/mock_llm_server.py | 406 ------------------ .../benchmark/mock_llm_server/models.py | 191 ++++++++ .../mock_llm_server/response_data.py | 79 ++++ .../benchmark/mock_llm_server/run_server.py | 7 +- tests/benchmark/test_mock_llm_server.py | 45 +- 6 files changed, 484 insertions(+), 417 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/api.py delete mode 100644 nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/models.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/response_data.py diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py new file mode 
100644 index 000000000..bca45b1df --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import Union + +from fastapi import FastAPI, HTTPException + +from nemoguardrails.benchmark.mock_llm_server.models import ( + ChatCompletionChoice, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionChoice, + CompletionRequest, + CompletionResponse, + Message, + Model, + ModelsResponse, + Usage, +) +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_MODELS, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + + +def _validate_request_model( + request: Union[CompletionRequest, ChatCompletionRequest], +) -> None: + """Check the Completion or Chat Completion `model` field is in our supported model list""" + available_models = set([model["id"] for model in DUMMY_MODELS]) + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. Available models: {available_models}", + ) + + +app = FastAPI( + title="Mock LLM Server", + description="OpenAI-compatible mock LLM server for testing and benchmarking", + version="0.0.1", +) + + +@app.get("/") +async def root(): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "0.0.1", + "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(): + """List available models.""" + return ModelsResponse( + object="list", data=[Model(**model) for model in DUMMY_MODELS] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResponse: + """Create a chat completion.""" + # Validate model exists + _validate_request_model(request) + + # Generate dummy response + response_content = get_dummy_chat_response() + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content), + finish_reason="stop", + ) + choices.append(choice) + + response = ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + 
total_tokens=prompt_tokens + completion_tokens, + ), + ) + + return response + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions(request: CompletionRequest) -> CompletionResponse: + """Create a text completion.""" + + # Validate model exists + _validate_request_model(request) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_dummy_completion_response() + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + response = CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "timestamp": int(time.time())} diff --git a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py deleted file mode 100644 index 28725e724..000000000 --- a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py +++ /dev/null @@ -1,406 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Mock LLM FastAPI Server with OpenAI-compatible interface. - -This server provides dummy implementations of OpenAI API endpoints for testing -and benchmarking purposes. 
-""" - -import time -import uuid -from typing import Any, Dict, List, Optional, Union - -import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field - -app = FastAPI( - title="Mock LLM Server", - description="OpenAI-compatible mock LLM server for testing and benchmarking", - version="1.0.0", -) - - -# Pydantic Models for Request/Response validation - - -class Message(BaseModel): - """Chat message model.""" - - role: str = Field(..., description="The role of the message author") - content: str = Field(..., description="The content of the message") - name: Optional[str] = Field(None, description="The name of the author") - - -class ChatCompletionRequest(BaseModel): - """Chat completion request model.""" - - model: str = Field(..., description="ID of the model to use") - messages: List[Message] = Field( - ..., description="List of messages comprising the conversation" - ) - max_tokens: Optional[int] = Field( - None, description="Maximum number of tokens to generate", ge=1 - ) - temperature: Optional[float] = Field( - 1.0, description="Sampling temperature", ge=0.0, le=2.0 - ) - top_p: Optional[float] = Field( - 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 - ) - n: Optional[int] = Field( - 1, description="Number of completions to generate", ge=1, le=128 - ) - stream: Optional[bool] = Field( - False, description="Whether to stream back partial progress" - ) - stop: Optional[Union[str, List[str]]] = Field( - None, description="Sequences where the API will stop generating" - ) - presence_penalty: Optional[float] = Field( - 0.0, description="Presence penalty", ge=-2.0, le=2.0 - ) - frequency_penalty: Optional[float] = Field( - 0.0, description="Frequency penalty", ge=-2.0, le=2.0 - ) - logit_bias: Optional[Dict[str, float]] = Field( - None, description="Modify likelihood of specified tokens" - ) - user: Optional[str] = Field( - None, description="Unique identifier representing your end-user" - ) - - -class CompletionRequest(BaseModel): - """Text completion request model.""" - - model: str = Field(..., description="ID of the model to use") - prompt: Union[str, List[str]] = Field( - ..., description="The prompt(s) to generate completions for" - ) - max_tokens: Optional[int] = Field( - 16, description="Maximum number of tokens to generate", ge=1 - ) - temperature: Optional[float] = Field( - 1.0, description="Sampling temperature", ge=0.0, le=2.0 - ) - top_p: Optional[float] = Field( - 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 - ) - n: Optional[int] = Field( - 1, description="Number of completions to generate", ge=1, le=128 - ) - stream: Optional[bool] = Field( - False, description="Whether to stream back partial progress" - ) - logprobs: Optional[int] = Field( - None, description="Include log probabilities", ge=0, le=5 - ) - echo: Optional[bool] = Field( - False, description="Echo back the prompt in addition to completion" - ) - stop: Optional[Union[str, List[str]]] = Field( - None, description="Sequences where the API will stop generating" - ) - presence_penalty: Optional[float] = Field( - 0.0, description="Presence penalty", ge=-2.0, le=2.0 - ) - frequency_penalty: Optional[float] = Field( - 0.0, description="Frequency penalty", ge=-2.0, le=2.0 - ) - best_of: Optional[int] = Field( - 1, description="Number of completions to generate server-side", ge=1 - ) - logit_bias: Optional[Dict[str, float]] = Field( - None, description="Modify likelihood of specified tokens" - ) - user: Optional[str] = Field( - None, description="Unique 
identifier representing your end-user" - ) - - -class Usage(BaseModel): - """Token usage information.""" - - prompt_tokens: int = Field(..., description="Number of tokens in the prompt") - completion_tokens: int = Field( - ..., description="Number of tokens in the completion" - ) - total_tokens: int = Field(..., description="Total number of tokens used") - - -class ChatCompletionChoice(BaseModel): - """Chat completion choice.""" - - index: int = Field(..., description="The index of this choice") - message: Message = Field(..., description="The generated message") - finish_reason: str = Field( - ..., description="The reason the model stopped generating" - ) - - -class CompletionChoice(BaseModel): - """Text completion choice.""" - - text: str = Field(..., description="The generated text") - index: int = Field(..., description="The index of this choice") - logprobs: Optional[Dict[str, Any]] = Field( - None, description="Log probability information" - ) - finish_reason: str = Field( - ..., description="The reason the model stopped generating" - ) - - -class ChatCompletionResponse(BaseModel): - """Chat completion response.""" - - id: str = Field(..., description="Unique identifier for the completion") - object: str = Field("chat.completion", description="Object type") - created: int = Field( - ..., description="Unix timestamp when the completion was created" - ) - model: str = Field(..., description="The model used for completion") - choices: List[ChatCompletionChoice] = Field( - ..., description="List of completion choices" - ) - usage: Usage = Field(..., description="Token usage information") - - -class CompletionResponse(BaseModel): - """Text completion response.""" - - id: str = Field(..., description="Unique identifier for the completion") - object: str = Field("text_completion", description="Object type") - created: int = Field( - ..., description="Unix timestamp when the completion was created" - ) - model: str = Field(..., description="The model used for completion") - choices: List[CompletionChoice] = Field( - ..., description="List of completion choices" - ) - usage: Usage = Field(..., description="Token usage information") - - -class Model(BaseModel): - """Model information.""" - - id: str = Field(..., description="Model identifier") - object: str = Field("model", description="Object type") - created: int = Field(..., description="Unix timestamp when the model was created") - owned_by: str = Field(..., description="Organization that owns the model") - - -class ModelsResponse(BaseModel): - """Models list response.""" - - object: str = Field("list", description="Object type") - data: List[Model] = Field(..., description="List of available models") - - -# Dummy data and helper functions - -DUMMY_MODELS = [ - { - "id": "gpt-3.5-turbo", - "object": "model", - "created": 1677610602, - "owned_by": "openai", - }, - {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, - { - "id": "gpt-4-turbo", - "object": "model", - "created": 1712361441, - "owned_by": "openai", - }, - { - "id": "text-davinci-003", - "object": "model", - "created": 1669599635, - "owned_by": "openai", - }, -] - -DUMMY_CHAT_RESPONSES = [ - "This is a mock response from the LLM server.", - "I'm a dummy AI assistant created for testing purposes.", - "This response is generated by a mock OpenAI-compatible server.", - "Hello! 
I'm responding with dummy data for benchmarking.", - "This is a simulated conversation response for testing.", -] - -DUMMY_COMPLETION_RESPONSES = [ - " This is a dummy text completion.", - " Here's some mock generated text.", - " This is a sample completion response.", - " Mock completion text for testing purposes.", - " Dummy text generated by the mock server.", -] - - -def generate_id(prefix: str = "chatcmpl") -> str: - """Generate a unique ID for completions.""" - return f"{prefix}-{uuid.uuid4().hex[:8]}" - - -def calculate_tokens(text: str) -> int: - """Rough token calculation (approximately 4 characters per token).""" - return max(1, len(text) // 4) - - -def get_dummy_chat_response() -> str: - """Get a dummy chat response.""" - import random - - return random.choice(DUMMY_CHAT_RESPONSES) - - -def get_dummy_completion_response() -> str: - """Get a dummy completion response.""" - import random - - return random.choice(DUMMY_COMPLETION_RESPONSES) - - -# API Endpoints - - -@app.get("/") -async def root(): - """Root endpoint with basic server information.""" - return { - "message": "Mock LLM Server", - "version": "1.0.0", - "description": "OpenAI-compatible mock LLM server for testing and benchmarking", - "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], - } - - -@app.get("/v1/models", response_model=ModelsResponse) -async def list_models(): - """List available models.""" - return ModelsResponse( - object="list", data=[Model(**model) for model in DUMMY_MODELS] - ) - - -@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions(request: ChatCompletionRequest): - """Create a chat completion.""" - # Validate model exists - available_models = [model["id"] for model in DUMMY_MODELS] - if request.model not in available_models: - raise HTTPException( - status_code=400, - detail=f"Model '{request.model}' not found. Available models: {available_models}", - ) - - # Generate dummy response - response_content = get_dummy_chat_response() - - # Calculate token usage - prompt_text = " ".join([msg.content for msg in request.messages]) - prompt_tokens = calculate_tokens(prompt_text) - completion_tokens = calculate_tokens(response_content) - - # Create response - completion_id = generate_id("chatcmpl") - created_timestamp = int(time.time()) - - choices = [] - for i in range(request.n or 1): - choice = ChatCompletionChoice( - index=i, - message=Message(role="assistant", content=response_content, name=None), - finish_reason="stop", - ) - choices.append(choice) - - return ChatCompletionResponse( - id=completion_id, - object="chat.completion", - created=created_timestamp, - model=request.model, - choices=choices, - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - - -@app.post("/v1/completions", response_model=CompletionResponse) -async def completions(request: CompletionRequest): - """Create a text completion.""" - # Validate model exists - available_models = [model["id"] for model in DUMMY_MODELS] - if request.model not in available_models: - raise HTTPException( - status_code=400, - detail=f"Model '{request.model}' not found. 
Available models: {available_models}", - ) - - # Handle prompt (can be string or list) - if isinstance(request.prompt, list): - prompt_text = " ".join(request.prompt) - else: - prompt_text = request.prompt - - # Generate dummy response - response_text = get_dummy_completion_response() - - # Calculate token usage - prompt_tokens = calculate_tokens(prompt_text) - completion_tokens = calculate_tokens(response_text) - - # Create response - completion_id = generate_id("cmpl") - created_timestamp = int(time.time()) - - choices = [] - for i in range(request.n or 1): - choice = CompletionChoice( - text=response_text, index=i, logprobs=None, finish_reason="stop" - ) - choices.append(choice) - - return CompletionResponse( - id=completion_id, - object="text_completion", - created=created_timestamp, - model=request.model, - choices=choices, - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy", "timestamp": int(time.time())} - - -if __name__ == "__main__": - uvicorn.run( - "mock_llm_server:app", host="0.0.0.0", port=8000, reload=True, log_level="info" - ) diff --git a/nemoguardrails/benchmark/mock_llm_server/models.py b/nemoguardrails/benchmark/mock_llm_server/models.py new file mode 100644 index 000000000..8634c46a6 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/models.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
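Because the endpoints above are ordinary FastAPI handlers, they can be exercised in-process with Starlette's TestClient instead of a separately launched uvicorn process. A minimal sketch, assuming the module import path used by the test module in this patch (mock_llm_server.mock_llm_server) and one of the dummy model ids registered above:

    from fastapi.testclient import TestClient

    from mock_llm_server.mock_llm_server import app  # import path assumed from the test module

    client = TestClient(app)

    # Liveness probe defined above
    assert client.get("/health").json()["status"] == "healthy"

    # Text completion against one of the dummy models; the prompt is arbitrary
    response = client.post(
        "/v1/completions",
        json={"model": "text-davinci-003", "prompt": "Mock servers are useful because"},
    )
    payload = response.json()
    print(payload["choices"][0]["text"])
    print(payload["usage"])  # token counts from the ~4-characters-per-token heuristic
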
+ +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: list[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, list[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class 
ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: list[Model] = Field(..., description="List of available models") diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py new file mode 100644 index 000000000..7e3c7e760 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
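These request models double as the server's validation layer: payloads outside the declared ranges are rejected before a handler runs, and omitted fields pick up the OpenAI-style defaults. A short sketch of that behaviour in isolation, using plain constructor validation so it reads the same under Pydantic v1 and v2:

    from pydantic import ValidationError

    from nemoguardrails.benchmark.mock_llm_server.models import ChatCompletionRequest

    payload = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    request = ChatCompletionRequest(**payload)
    print(request.temperature, request.n, request.stream)  # 1.0 1 False -- the defaults

    try:
        ChatCompletionRequest(**{**payload, "temperature": 5.0})  # violates le=2.0
    except ValidationError as exc:
        print(exc.errors()[0]["loc"])  # ('temperature',)
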
+ + +import uuid + +DUMMY_MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "openai", + }, + { + "id": "text-davinci-003", + "object": "model", + "created": 1669599635, + "owned_by": "openai", + }, +] + +DUMMY_CHAT_RESPONSES = [ + "This is a mock response from the LLM server.", + "I'm a dummy AI assistant created for testing purposes.", + "This response is generated by a mock OpenAI-compatible server.", + "Hello! I'm responding with dummy data for benchmarking.", + "This is a simulated conversation response for testing.", +] + +DUMMY_COMPLETION_RESPONSES = [ + " This is a dummy text completion.", + " Here's some mock generated text.", + " This is a sample completion response.", + " Mock completion text for testing purposes.", + " Dummy text generated by the mock server.", +] + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_dummy_chat_response() -> str: + """Get a dummy chat response.""" + import random + + return random.choice(DUMMY_CHAT_RESPONSES) + + +def get_dummy_completion_response() -> str: + """Get a dummy completion response.""" + import random + + return random.choice(DUMMY_COMPLETION_RESPONSES) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index 66f281932..83e68c049 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -25,9 +25,10 @@ import sys import uvicorn +from api import app -# Add the current directory to Python path to import the server module -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +# # Add the current directory to Python path to import the server module +# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) def main(): @@ -62,7 +63,7 @@ def main(): try: uvicorn.run( - "mock_llm_server:app", + app=app, host=args.host, port=args.port, reload=args.reload, diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index b74f9633d..ad24e5303 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -28,16 +28,27 @@ import pytest from fastapi.testclient import TestClient -# Import the server and its components -from mock_llm_server.mock_llm_server import ( +from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_CHAT_RESPONSES, DUMMY_MODELS, - app, calculate_tokens, generate_id, get_dummy_chat_response, get_dummy_completion_response, ) +# +# # Import the server and its components +# from mock_llm_server.mock_llm_server import ( +# DUMMY_MODELS, +# app, +# calculate_tokens, +# generate_id, +# get_dummy_chat_response, +# get_dummy_completion_response, +# ) + class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -75,7 +86,7 @@ def test_root_endpoint(self, client): data = response.json() assert data["message"] == "Mock LLM Server" - assert data["version"] == "1.0.0" + assert data["version"] == "0.0.1" assert "description" in data 
assert "/v1/models" in data["endpoints"] assert "/v1/chat/completions" in data["endpoints"] @@ -491,16 +502,34 @@ def test_large_prompt_handling(self, client): assert data["usage"]["prompt_tokens"] > 1000 # Mock and patch tests - @patch("mock_llm_server.mock_llm_server.get_dummy_chat_response") - def test_chat_response_mocking(self, mock_response, client, valid_chat_request): + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_chat_response") + def test_chat_completion_response_mocking( + self, mock_response, client, valid_chat_request + ): """Test mocking of chat response generation.""" - mock_response.return_value = "Mocked response for testing" + expected_response = "Mocked response for testing chat completions" + mock_response.return_value = expected_response response = client.post("/v1/chat/completions", json=valid_chat_request) assert response.status_code == 200 data = response.json() - assert data["choices"][0]["message"]["content"] == "Mocked response for testing" + assert data["choices"][0]["message"]["content"] == expected_response + mock_response.assert_called_once() + + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_completion_response") + def test_completion_response_mocking( + self, mock_response, client, valid_completion_request + ): + """Test mocking of chat response generation.""" + expected_response = "Mocked response to check completion responses" + mock_response.return_value = expected_response + + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["text"] == expected_response mock_response.assert_called_once() @patch("time.time") From 9021b81ab121fe24112903d6676e5b7c981173d4 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:59:41 -0500 Subject: [PATCH 03/15] Added tests to load YAML config. 
Still debugging dependency-injection of this into endpoints --- .../benchmark/mock_llm_server/api.py | 12 +++- .../benchmark/mock_llm_server/config.py | 72 +++++++++++++++++++ ...llama-3.1-nemoguard-8b-content-safety.yaml | 12 ++++ .../benchmark/mock_llm_server/run_server.py | 13 +++- tests/benchmark/mock_model_config.yaml | 3 + tests/benchmark/test_mock_llm_server.py | 45 +++++++++--- 6 files changed, 140 insertions(+), 17 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/config.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml create mode 100644 tests/benchmark/mock_model_config.yaml diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index bca45b1df..ca92ab193 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -15,10 +15,11 @@ import time -from typing import Union +from typing import Annotated, Union -from fastapi import FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config from nemoguardrails.benchmark.mock_llm_server.models import ( ChatCompletionChoice, ChatCompletionRequest, @@ -59,14 +60,19 @@ def _validate_request_model( ) +ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] + + @app.get("/") -async def root(): +async def root(current_config: ModelConfigDep): """Root endpoint with basic server information.""" + print(current_config) return { "message": "Mock LLM Server", "version": "0.0.1", "description": "OpenAI-compatible mock LLM server for testing and benchmarking", "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + "model_configuration": current_config, } diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py new file mode 100644 index 000000000..0b2fa42e6 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import lru_cache +from typing import Any, Optional, Union + +import yaml +from openai._utils import lru_cache +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class AppModelConfig(BaseModel): + """Pydantic model to configure the Mock LLM Server.""" + + # Mandatory fields + model: str = Field(..., description="Model name served by mock server") + refusal_text: str = Field(..., description="Refusal response text") + + # Config with default values + refusal_probability: float = Field( + default=0.1, description="Probability of refusal (between 0 and 1)" + ) + # Latency sampled from a truncated-normal distribution. 
+ # Plain Normal distributions have infinite support, and can be negative + latency_min_seconds: float = Field( + default=0.1, description="Minimum latency in seconds" + ) + latency_max_seconds: float = Field( + default=5, description="Maximum latency in seconds" + ) + latency_mean_seconds: float = Field( + default=0.5, description="The average response time in seconds" + ) + latency_std_seconds: float = Field( + default=0.1, description="Standard deviation of response time" + ) + + +settings: Optional[AppModelConfig] = None + + +def load_config(yaml_file: str) -> None: + """Load the Model configuration from YAML file, store in global `settings` var""" + global settings + with open(yaml_file, "r") as f: + config_data = yaml.safe_load(f) + settings = AppModelConfig(**config_data) + + +@lru_cache +def get_config() -> AppModelConfig: + """FastAPI Dependency to inject model configuration""" + print(f"get_config called, settings = {settings}") + print(f"GET_CONFIG CALLED IN PROCESS ID: {os.getpid()}") + + if settings is None: + raise RuntimeError("No configuration loaded") + return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml new file mode 100644 index 000000000..2eebb1063 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -0,0 +1,12 @@ +model: "nvidia/llama-3.1-nemoguard-8b-content-safety" +refusal_probability: 0.01 +refusal_text: | + { + "User Safety": "unsafe", + "Response Safety": "unsafe", + "Safety Categories": "PII/Privacy" + } +latency_min_seconds: 0.1 +latency_max_seconds: 5 +latency_mean_seconds: 0.4 +latency_std_seconds: 0.1 diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index 83e68c049..d3732c97b 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -26,9 +26,7 @@ import uvicorn from api import app - -# # Add the current directory to Python path to import the server module -# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from config import get_config, load_config, settings def main(): @@ -54,11 +52,20 @@ def main(): help="Log level (default: info)", ) + parser.add_argument( + "--config-file", help="YAML file to configure model", required=True + ) + args = parser.parse_args() + # Load model configuration + load_config(args.config_file) + model_config = get_config() + print(f"Starting Mock LLM Server on {args.host}:{args.port}") print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") print(f"Health check at: http://{args.host}:{args.port}/health") + print(f"Model configuration: {model_config}") print("Press Ctrl+C to stop the server") try: diff --git a/tests/benchmark/mock_model_config.yaml b/tests/benchmark/mock_model_config.yaml new file mode 100644 index 000000000..384a988e5 --- /dev/null +++ b/tests/benchmark/mock_model_config.yaml @@ -0,0 +1,3 @@ +model: "mock_model" +refusal_probability: 0.01 +refusal_text: "I'm sorry, I can't help you with that request" diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index ad24e5303..0bc54423e 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -21,6 +21,7 @@ """ import json +import os import time from typing import 
Any, Dict, List from unittest.mock import patch @@ -29,6 +30,11 @@ from fastapi.testclient import TestClient from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.config import ( + AppModelConfig, + get_config, + load_config, +) from nemoguardrails.benchmark.mock_llm_server.response_data import ( DUMMY_CHAT_RESPONSES, DUMMY_MODELS, @@ -38,17 +44,6 @@ get_dummy_completion_response, ) -# -# # Import the server and its components -# from mock_llm_server.mock_llm_server import ( -# DUMMY_MODELS, -# app, -# calculate_tokens, -# generate_id, -# get_dummy_chat_response, -# get_dummy_completion_response, -# ) - class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -81,6 +76,17 @@ def valid_completion_request(self): # Root endpoint tests def test_root_endpoint(self, client): """Test the root endpoint returns correct information.""" + + mock_config = AppModelConfig( + model="mock_config_model_name", + refusal_text="I'm afraid I can't do that, Dave", + ) + + def override_get_config(): + return mock_config + + app.dependency_overrides[get_config] = override_get_config + response = client.get("/") assert response.status_code == 200 @@ -91,6 +97,8 @@ def test_root_endpoint(self, client): assert "/v1/models" in data["endpoints"] assert "/v1/chat/completions" in data["endpoints"] assert "/v1/completions" in data["endpoints"] + assert data["model_configuration"]["model"] == mock_config.model + assert data["model_configuration"]["refusal_text"] == mock_config.refusal_text # Health check tests def test_health_check(self, client): @@ -558,3 +566,18 @@ def test_openapi_docs_available(self, client): assert "/v1/models" in openapi_data["paths"] assert "/v1/chat/completions" in openapi_data["paths"] assert "/v1/completions" in openapi_data["paths"] + + def test_read_root_with_mock_config(self): + """Tests load_config method correctly populates the `settings` global variable""" + yaml_file = os.path.join(os.path.dirname(__file__), "mock_model_config.yaml") + + # Make sure settings is empty to start with, load and check it's populated + load_config(yaml_file) + config = get_config() + assert config is not None + + # Now check the contents against `mock_model_config.yaml` + assert isinstance(config, AppModelConfig) + assert config.model == "mock_model" + assert config.refusal_probability == 0.01 + assert config.refusal_text == "I'm sorry, I can't help you with that request" From 687e33bce7384bc67532fe206dcc266e1f3eca9f Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 17:08:05 -0500 Subject: [PATCH 04/15] Move FastAPI app import **after** the dependencies are loaded and cached --- nemoguardrails/benchmark/mock_llm_server/config.py | 2 -- nemoguardrails/benchmark/mock_llm_server/run_server.py | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 0b2fa42e6..9945cb98a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -18,7 +18,6 @@ from typing import Any, Optional, Union import yaml -from openai._utils import lru_cache from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -61,7 +60,6 @@ def load_config(yaml_file: str) -> None: settings = AppModelConfig(**config_data) -@lru_cache def get_config() -> AppModelConfig: """FastAPI Dependency to 
inject model configuration""" print(f"get_config called, settings = {settings}") diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index d3732c97b..0d05756d2 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -21,12 +21,11 @@ """ import argparse -import os import sys import uvicorn -from api import app -from config import get_config, load_config, settings + +from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config def main(): @@ -62,6 +61,9 @@ def main(): load_config(args.config_file) model_config = get_config() + # Import the app after configuration is loaded. This caches the values in the app Dependencies + from nemoguardrails.benchmark.mock_llm_server.api import app + print(f"Starting Mock LLM Server on {args.host}:{args.port}") print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") print(f"Health check at: http://{args.host}:{args.port}/health") From c0afd8d5096eb0c3ebfc8960b55c23c5e32c63e3 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 21:06:48 -0500 Subject: [PATCH 05/15] Remove debugging print statements --- nemoguardrails/benchmark/mock_llm_server/api.py | 1 - nemoguardrails/benchmark/mock_llm_server/config.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index ca92ab193..c34816ef2 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -66,7 +66,6 @@ def _validate_request_model( @app.get("/") async def root(current_config: ModelConfigDep): """Root endpoint with basic server information.""" - print(current_config) return { "message": "Mock LLM Server", "version": "0.0.1", diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 9945cb98a..0f1abe7bb 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -62,9 +62,6 @@ def load_config(yaml_file: str) -> None: def get_config() -> AppModelConfig: """FastAPI Dependency to inject model configuration""" - print(f"get_config called, settings = {settings}") - print(f"GET_CONFIG CALLED IN PROCESS ID: {os.getpid()}") - if settings is None: raise RuntimeError("No configuration loaded") return settings From e62f39421f7d60afb2eba05e0563beb03e38a6e8 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:48:31 -0500 Subject: [PATCH 06/15] Temporary checkin --- .../benchmark/mock_llm_server/api.py | 14 ++++-- .../mock_llm_server/response_data.py | 49 +++++++++++++++++-- tests/benchmark/test_mock_llm_server.py | 6 +++ 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index c34816ef2..a33b7505e 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -15,7 +15,7 @@ import time -from typing import Annotated, Union +from typing import Annotated, Optional, Union from fastapi import Depends, FastAPI, HTTPException @@ -84,13 +84,15 @@ async def list_models(): @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions(request: 
ChatCompletionRequest) -> ChatCompletionResponse: +async def chat_completions( + request: ChatCompletionRequest, config: ModelConfigDep +) -> ChatCompletionResponse: """Create a chat completion.""" # Validate model exists _validate_request_model(request) # Generate dummy response - response_content = get_dummy_chat_response() + response_content = get_dummy_chat_response(config) # Calculate token usage prompt_text = " ".join([msg.content for msg in request.messages]) @@ -127,7 +129,9 @@ async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResp @app.post("/v1/completions", response_model=CompletionResponse) -async def completions(request: CompletionRequest) -> CompletionResponse: +async def completions( + request: CompletionRequest, config: ModelConfigDep +) -> CompletionResponse: """Create a text completion.""" # Validate model exists @@ -140,7 +144,7 @@ async def completions(request: CompletionRequest) -> CompletionResponse: prompt_text = request.prompt # Generate dummy response - response_text = get_dummy_completion_response() + response_text = get_dummy_completion_response(config) # Calculate token usage prompt_tokens = calculate_tokens(prompt_text) diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index 7e3c7e760..abd1bc77c 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -14,7 +14,13 @@ # limitations under the License. +import random import uuid +from typing import Optional + +import numpy as np + +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config DUMMY_MODELS = [ { @@ -65,15 +71,50 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_dummy_chat_response() -> str: +def get_dummy_chat_response(config: AppModelConfig) -> str: """Get a dummy chat response.""" - import random + + if is_refusal(config): + return config.refusal_text return random.choice(DUMMY_CHAT_RESPONSES) -def get_dummy_completion_response() -> str: +def get_dummy_completion_response(config: AppModelConfig) -> str: """Get a dummy completion response.""" - import random + if is_refusal(config): + return config.refusal_text return random.choice(DUMMY_COMPLETION_RESPONSES) + + +def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> float: + """Sample latency for this request using the model's config + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + # Sample from the normal distribution using model config + latency_seconds = np.random.normal( + loc=config.latency_mean_seconds, scale=config.latency_std_seconds, size=1 + ) + + # Truncate distribution's support using min and max config values + latency_seconds = np.clip( + latency_seconds, + a_min=config.latency_min_seconds, + a_max=config.latency_max_seconds, + ) + return float(latency_seconds) + + +def is_refusal(config: AppModelConfig, seed: Optional[int] = None) -> bool: + """Check if the model should return a refusal + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + refusal = np.random.binomial(n=1, p=config.refusal_probability, size=1) + return bool(refusal[0]) diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index 0bc54423e..c53816de1 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ 
b/tests/benchmark/test_mock_llm_server.py @@ -581,3 +581,9 @@ def test_read_root_with_mock_config(self): assert config.model == "mock_model" assert config.refusal_probability == 0.01 assert config.refusal_text == "I'm sorry, I can't help you with that request" + + @patch("nemoguardrails.benchmark.mock_llm_server.config.settings", None) + def test_get_config_raises_exception(self): + """Check if we call `get_config()` without settings set we raise an exception""" + with pytest.raises(RuntimeError, match="No configuration loaded"): + get_config() From 6ddcacac53109a6b6cf3125addfa9cab04accfec Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 11:25:44 -0500 Subject: [PATCH 07/15] Add refusal probability and tests to check it --- .../mock_llm_server/response_data.py | 20 +++++---- tests/benchmark/test_mock_llm_server.py | 43 +++++++++++++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index abd1bc77c..38522583a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -53,11 +53,11 @@ ] DUMMY_COMPLETION_RESPONSES = [ - " This is a dummy text completion.", - " Here's some mock generated text.", - " This is a sample completion response.", - " Mock completion text for testing purposes.", - " Dummy text generated by the mock server.", + "This is a dummy text completion.", + "Here's some mock generated text.", + "This is a sample completion response.", + "Mock completion text for testing purposes.", + "Dummy text generated by the mock server.", ] @@ -71,18 +71,20 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_dummy_chat_response(config: AppModelConfig) -> str: +def get_dummy_chat_response(config: AppModelConfig, seed: Optional[int] = None) -> str: """Get a dummy chat response.""" - if is_refusal(config): + if is_refusal(config, seed): return config.refusal_text return random.choice(DUMMY_CHAT_RESPONSES) -def get_dummy_completion_response(config: AppModelConfig) -> str: +def get_dummy_completion_response( + config: AppModelConfig, seed: Optional[int] = None +) -> str: """Get a dummy completion response.""" - if is_refusal(config): + if is_refusal(config, seed): return config.refusal_text return random.choice(DUMMY_COMPLETION_RESPONSES) diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index c53816de1..552eb57e1 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -37,6 +37,7 @@ ) from nemoguardrails.benchmark.mock_llm_server.response_data import ( DUMMY_CHAT_RESPONSES, + DUMMY_COMPLETION_RESPONSES, DUMMY_MODELS, calculate_tokens, generate_id, @@ -44,6 +45,20 @@ get_dummy_completion_response, ) +RANDOM_SEED = 12345 +REFUSAL_TEXT = "I'm sorry Dave, I'm afraid I can't do that" +NO_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=0.0, +) + +ALL_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=1.0, +) + class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -375,15 +390,25 @@ def test_calculate_tokens(self): expected_tokens = max(1, len(long_text) // 4) assert calculate_tokens(long_text) == expected_tokens - def test_get_dummy_responses(self): - """Test dummy response 
generation functions.""" - chat_response = get_dummy_chat_response() - assert isinstance(chat_response, str) - assert len(chat_response) > 0 - - completion_response = get_dummy_completion_response() - assert isinstance(completion_response, str) - assert len(completion_response) > 0 + def test_get_dummy_completion_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_completion_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_chat_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_chat_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_completion_response_no_refusal(self): + """Test /completion response generation with P = 0.0 of refusal""" + response = get_dummy_completion_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_COMPLETION_RESPONSES) + + def test_get_dummy_chat_response_no_refusal(self): + """Test /chat/completion response with P = 0.0 of refusal.""" + response = get_dummy_chat_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_CHAT_RESPONSES) # Edge cases and error handling def test_missing_required_fields_chat(self, client): From 3b3f49a42d9bfb9f3bb36929e93545b30026ac13 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:55:48 -0500 Subject: [PATCH 08/15] Use YAML configs for Nemoguard and app LLMs --- .../benchmark/mock_llm_server/api.py | 36 ++- .../benchmark/mock_llm_server/config.py | 9 +- .../content_safety/config.yml | 21 ++ .../content_safety/prompts.yml | 257 ++++++++++++++++++ .../configs/meta-llama-3.3-70b-instruct.yaml | 8 + ...llama-3.1-nemoguard-8b-content-safety.yaml | 9 +- .../benchmark/mock_llm_server/models.py | 4 +- .../mock_llm_server/response_data.py | 63 +---- 8 files changed, 324 insertions(+), 83 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index a33b7505e..80dedc728 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -33,23 +33,23 @@ Usage, ) from nemoguardrails.benchmark.mock_llm_server.response_data import ( - DUMMY_MODELS, calculate_tokens, generate_id, - get_dummy_chat_response, - get_dummy_completion_response, + get_response, ) +ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] + def _validate_request_model( + config: ModelConfigDep, request: Union[CompletionRequest, ChatCompletionRequest], ) -> None: """Check the Completion or Chat Completion `model` field is in our supported model list""" - available_models = set([model["id"] for model in DUMMY_MODELS]) - if request.model not in available_models: + if request.model != config.model: raise HTTPException( status_code=400, - detail=f"Model '{request.model}' not found. Available models: {available_models}", + detail=f"Model '{request.model}' not found. 
Available models: {config.model}", ) @@ -60,27 +60,25 @@ def _validate_request_model( ) -ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] - - @app.get("/") -async def root(current_config: ModelConfigDep): +async def root(config: ModelConfigDep): """Root endpoint with basic server information.""" return { "message": "Mock LLM Server", "version": "0.0.1", - "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "description": f"OpenAI-compatible mock LLM server for model: {config.model}", "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], - "model_configuration": current_config, + "model_configuration": config, } @app.get("/v1/models", response_model=ModelsResponse) -async def list_models(): +async def list_models(config: ModelConfigDep): """List available models.""" - return ModelsResponse( - object="list", data=[Model(**model) for model in DUMMY_MODELS] + model = Model( + id=config.model, object="model", created=int(time.time()), owned_by="system" ) + return ModelsResponse(object="list", data=[model]) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) @@ -89,10 +87,10 @@ async def chat_completions( ) -> ChatCompletionResponse: """Create a chat completion.""" # Validate model exists - _validate_request_model(request) + _validate_request_model(config, request) # Generate dummy response - response_content = get_dummy_chat_response(config) + response_content = get_response(config) # Calculate token usage prompt_text = " ".join([msg.content for msg in request.messages]) @@ -135,7 +133,7 @@ async def completions( """Create a text completion.""" # Validate model exists - _validate_request_model(request) + _validate_request_model(config, request) # Handle prompt (can be string or list) if isinstance(request.prompt, list): @@ -144,7 +142,7 @@ async def completions( prompt_text = request.prompt # Generate dummy response - response_text = get_dummy_completion_response(config) + response_text = get_response(config) # Calculate token usage prompt_tokens = calculate_tokens(prompt_text) diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 0f1abe7bb..44a725d1b 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -27,12 +27,13 @@ class AppModelConfig(BaseModel): # Mandatory fields model: str = Field(..., description="Model name served by mock server") - refusal_text: str = Field(..., description="Refusal response text") + unsafe_probability: float = Field( + default=0.1, description="Probability of unsafe response (between 0 and 1)" + ) + unsafe_text: str = Field(..., description="Refusal response to unsafe prompt") + safe_text: str = Field(..., description="Safe response") # Config with default values - refusal_probability: float = Field( - default=0.1, description="Probability of refusal (between 0 and 1)" - ) # Latency sampled from a truncated-normal distribution. 
# Plain Normal distributions have infinite support, and can be negative latency_min_seconds: float = Field( diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml new file mode 100644 index 000000000..67df489d7 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml @@ -0,0 +1,21 @@ +models: + - type: main + engine: nim + model: meta/llama-3.3-70b-instruct + parameters: + base_url: http://localhost:8000 + + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + parameters: + base_url: http://localhost:8001 + + +rails: + input: + flows: + - content safety check input $model=content_safety + output: + flows: + - content safety check output $model=content_safety diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml new file mode 100644 index 000000000..1321a6461 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml @@ -0,0 +1,257 @@ +# These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. +prompts: + + - task: content_safety_check_input $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_prompt_safety + max_tokens: 50 + + - task: content_safety_check_output $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. 
+ S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + response: agent: {{ bot_response }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_response_safety + max_tokens: 50 + + - task: content_safety_check_input $model=llama_guard + content: | + [INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. + + O1: Violence and Hate. + Should not + - Help people plan or engage in violence. + - Encourage people to engage in violence. + - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. + - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. + Can + - Provide information on violence and discrimination. + - Discuss the topics of hate and violence, and can discuss historical events involving violence. + O2: Sexual Content. + Should not + - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. + Can + - Discuss topics related to sexuality and sexual education. + - Provide general information regarding sexual activities. + - Provide guidance on topics related to sexual health. + O3: Criminal Planning. + Should not + - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. + - Encourage people to engage in criminal activities or condone criminal behavior. + Can + - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. + - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. + O4: Guns and Illegal Weapons. + Should not + - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. + - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. + Can + - Discuss firearms and the arguments for and against firearm ownership. + O5: Regulated or Controlled Substances. 
+ Should not + - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. + - Assist or encourage people to create such substances when it is illegal to do so. + Can + - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). + - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. + O6: Self-Harm. + Should not + - Encourage people to harm themselves. + - Romanticize or condone self-harm or suicide. + - Provide information regarding the methods of suicide or self-harm. + - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. + Can + - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). + Should + - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. + O7: Offensive Language and Insults. + Should not + - Address direct insults + - Use or encourage the use of swear words, bad language, or vulgar expressions. + - Insult, offend, or demean individuals or groups based on any personal characteristics or choices. + - Encourage or support the use of offensive or derogatory language. + Can + - Discuss the impact and societal perceptions of offensive language and insults. + - Provide information on communication strategies that avoid the use of offensive language and promote respectful dialogue. + - Offer guidance on how to address or respond to offensive language or insults in a constructive manner. + + + + User: {{ user_input }} + + + Provide your safety assessment for User in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] + output_parser: is_content_safe + + - task: content_safety_check_output $model=llama_guard_2 + content: | + <|begin_of_text|>[INST] Task: Check if there is unsafe content in Agent messages in conversations according our safety policy with the below categories. + + + S1: Violent Crimes. + S2: Non-Violent Crimes. + S3: Sex Crimes. + S4: Child Exploitation. + S5: Specialized Advice. + S6: Privacy. + S7: Intellectual Property. + S8: Indiscriminate Weapons. + S9: Hate. + S10: Self-Harm. + S11: Sexual Content. + + + + + User: {{ user_input }} + + Agent: {{ bot_response }} + + + + Provide your safety assessment for ONLY THE LAST Agent in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] + output_parser: is_content_safe + max_tokens: 3 + + - task: content_safety_check_input $model=shieldgemma + content: | + You are a policy expert trying to help determine whether a user prompt is in + violation of the defined safety policies. 
+ + + Human Question: {{ user_input }} + + + Our safety principle is defined in the below: + + * "No Hate Speech": The prompt shall not contain or seek generation of content + that expresses, incites, or promotes hate based on race, gender, ethnicity, + religion, nationality, sexual orientation, disability status, or caste. + * "No Dangerous Content": The prompt shall not contain or seek generation of content that harming oneself + and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide). + * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content + that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, + content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted. + * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, + intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, + denial of tragic events, disparaging victims of violence). + === + + Does the human prompt/question violate the above principle? Your answer must start with + 'Yes' or 'No'. And then walk through step by step to be sure we answer + correctly. + output_parser: is_content_safe + max_tokens: 3 + + - task: content_safety_check_output $model=shieldgemma + content: | + You are a policy expert trying to help determine whether a chatbot response to the user is in + violation of the defined safety policies: + + + Human Question: {{ user_input }} + Chatbot Response: {{ bot_response }} + + + Our safety principle is defined in the below: + + * "No Dangerous Content": The chatbot shall not generate content that harming oneself and/or others + (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide). + * "No Harassment": The chatbot shall not generate content that is malicious, intimidating, bullying, or + abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence). + * "No Hate Speech": The chatbot shall not generate content that targeting identity and/or protected + attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the + bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status. + We do allow content that includes a reference to a protected individual or group in a positive manner. + * "No Sexually Explicit Information": The chatbot shall not generate content that contains references to sexual acts or + other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). + Medical or scientific terms related to human anatomy or sex education are permitted. + + === + + Does the chatbot response violate the above principle? Your answer must start with + 'Yes' or 'No'. And then walk through step by step to be sure we answer + correctly. 
+ output_parser: is_content_safe + max_tokens: 3 diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml new file mode 100644 index 000000000..32f987c20 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml @@ -0,0 +1,8 @@ +model: "meta/llama-3.3-70b-instruct" +unsafe_probability: 0.0 +unsafe_text: "I can't help with that. Is there anything else I can assist you with?" +safe_text : "I can provide information and help with a wide range of topics, from science and history to entertainment and culture. I can also help with language-related tasks, such as translation and text summarization. However, I can't assist with requests that involve harm or illegal activities." +latency_min_seconds: 0.3 +latency_max_seconds: 5 +latency_mean_seconds: 4 +latency_std_seconds: 0.1 diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml index 2eebb1063..fbded382a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -1,11 +1,16 @@ model: "nvidia/llama-3.1-nemoguard-8b-content-safety" -refusal_probability: 0.01 -refusal_text: | +unsafe_probability: 0.03 +unsafe_text: | { "User Safety": "unsafe", "Response Safety": "unsafe", "Safety Categories": "PII/Privacy" } +safe_text : | + { + "User Safety": "safe", + "Response Safety": "safe", + } latency_min_seconds: 0.1 latency_max_seconds: 5 latency_mean_seconds: 0.4 diff --git a/nemoguardrails/benchmark/mock_llm_server/models.py b/nemoguardrails/benchmark/mock_llm_server/models.py index 8634c46a6..aac72d6bb 100644 --- a/nemoguardrails/benchmark/mock_llm_server/models.py +++ b/nemoguardrails/benchmark/mock_llm_server/models.py @@ -146,7 +146,7 @@ class CompletionChoice(BaseModel): class ChatCompletionResponse(BaseModel): - """Chat completion response.""" + """Chat completion response - https://platform.openai.com/docs/api-reference/chat/object""" id: str = Field(..., description="Unique identifier for the completion") object: str = Field("chat.completion", description="Object type") @@ -161,7 +161,7 @@ class ChatCompletionResponse(BaseModel): class CompletionResponse(BaseModel): - """Text completion response.""" + """Text completion response. 
https://platform.openai.com/docs/api-reference/completions/object""" id: str = Field(..., description="Unique identifier for the completion") object: str = Field("text_completion", description="Object type") diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index 38522583a..b58bb768a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -22,44 +22,6 @@ from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config -DUMMY_MODELS = [ - { - "id": "gpt-3.5-turbo", - "object": "model", - "created": 1677610602, - "owned_by": "openai", - }, - {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, - { - "id": "gpt-4-turbo", - "object": "model", - "created": 1712361441, - "owned_by": "openai", - }, - { - "id": "text-davinci-003", - "object": "model", - "created": 1669599635, - "owned_by": "openai", - }, -] - -DUMMY_CHAT_RESPONSES = [ - "This is a mock response from the LLM server.", - "I'm a dummy AI assistant created for testing purposes.", - "This response is generated by a mock OpenAI-compatible server.", - "Hello! I'm responding with dummy data for benchmarking.", - "This is a simulated conversation response for testing.", -] - -DUMMY_COMPLETION_RESPONSES = [ - "This is a dummy text completion.", - "Here's some mock generated text.", - "This is a sample completion response.", - "Mock completion text for testing purposes.", - "Dummy text generated by the mock server.", -] - def generate_id(prefix: str = "chatcmpl") -> str: """Generate a unique ID for completions.""" @@ -71,23 +33,12 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_dummy_chat_response(config: AppModelConfig, seed: Optional[int] = None) -> str: - """Get a dummy chat response.""" - - if is_refusal(config, seed): - return config.refusal_text - - return random.choice(DUMMY_CHAT_RESPONSES) - - -def get_dummy_completion_response( - config: AppModelConfig, seed: Optional[int] = None -) -> str: - """Get a dummy completion response.""" - if is_refusal(config, seed): - return config.refusal_text +def get_response(config: AppModelConfig, seed: Optional[int] = None) -> str: + """Get a dummy /completion or /chat/completion response.""" - return random.choice(DUMMY_COMPLETION_RESPONSES) + if is_unsafe(config, seed): + return config.unsafe_text + return config.safe_text def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> float: @@ -111,12 +62,12 @@ def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> f return float(latency_seconds) -def is_refusal(config: AppModelConfig, seed: Optional[int] = None) -> bool: +def is_unsafe(config: AppModelConfig, seed: Optional[int] = None) -> bool: """Check if the model should return a refusal Very inefficient to generate each sample singly rather than in batch """ if seed: np.random.seed(seed) - refusal = np.random.binomial(n=1, p=config.refusal_probability, size=1) + refusal = np.random.binomial(n=1, p=config.unsafe_probability, size=1) return bool(refusal[0]) From f142c0f5c984a377e87ca3a2868c0403566d4621 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Thu, 2 Oct 2025 14:38:18 -0500 Subject: [PATCH 09/15] Add Mock configs for content-safety and App LLM --- .../configs/meta-llama-3.3-70b-instruct.yaml | 6 +++--- ...llama-3.1-nemoguard-8b-content-safety.yaml | 19 
+++++-------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml index 32f987c20..0ed9d3857 100644 --- a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml +++ b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml @@ -2,7 +2,7 @@ model: "meta/llama-3.3-70b-instruct" unsafe_probability: 0.0 unsafe_text: "I can't help with that. Is there anything else I can assist you with?" safe_text : "I can provide information and help with a wide range of topics, from science and history to entertainment and culture. I can also help with language-related tasks, such as translation and text summarization. However, I can't assist with requests that involve harm or illegal activities." -latency_min_seconds: 0.3 -latency_max_seconds: 5 +latency_min_seconds: 4 +latency_max_seconds: 4 latency_mean_seconds: 4 -latency_std_seconds: 0.1 +latency_std_seconds: 0. diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml index fbded382a..a64b3e141 100644 --- a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -1,17 +1,8 @@ model: "nvidia/llama-3.1-nemoguard-8b-content-safety" unsafe_probability: 0.03 -unsafe_text: | - { - "User Safety": "unsafe", - "Response Safety": "unsafe", - "Safety Categories": "PII/Privacy" - } -safe_text : | - { - "User Safety": "safe", - "Response Safety": "safe", - } -latency_min_seconds: 0.1 -latency_max_seconds: 5 +unsafe_text: "{\"User Safety\": \"unsafe\", \"Response Safety\": \"unsafe\", \"Safety Categories\": \"Violence, Criminal Planning/Confessions\"} " +safe_text: "{\"User Safety\": \"safe\", \"Response Safety\": \"safe\"}" +latency_min_seconds: 0.4 +latency_max_seconds: 0.4 latency_mean_seconds: 0.4 -latency_std_seconds: 0.1 +latency_std_seconds: 0.0 From a18b5145d5ce7a032d606d88da1950d31966ba89 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Thu, 2 Oct 2025 16:50:11 -0500 Subject: [PATCH 10/15] Add async sleep statements and logging to record request time --- nemoguardrails/benchmark/__init__.py | 14 ++++ .../benchmark/mock_llm_server/api.py | 65 +++++++++++++++++-- .../benchmark/mock_llm_server/run_server.py | 31 +++++++-- 3 files changed, 99 insertions(+), 11 deletions(-) create mode 100644 nemoguardrails/benchmark/__init__.py diff --git a/nemoguardrails/benchmark/__init__.py b/nemoguardrails/benchmark/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/benchmark/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index 80dedc728..4297c9e5c 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -14,10 +14,12 @@ # limitations under the License. +import asyncio +import logging import time from typing import Annotated, Optional, Union -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Request, Response from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config from nemoguardrails.benchmark.mock_llm_server.models import ( @@ -35,9 +37,28 @@ from nemoguardrails.benchmark.mock_llm_server.response_data import ( calculate_tokens, generate_id, + get_latency_seconds, get_response, ) +# Create a console logging handler +log = logging.getLogger(__name__) +log.setLevel(logging.INFO) # TODO Control this from the CLi args + +# Create a formatter to define the log message format +formatter = logging.Formatter( + "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) + +# Create a console handler to print logs to the console +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) # DEBUG and higher will go to the console +console_handler.setFormatter(formatter) + +# Add console handler to logs +log.addHandler(console_handler) + + ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] @@ -60,6 +81,24 @@ def _validate_request_model( ) +@app.middleware("http") +async def log_http_duration(request: Request, call_next): + """ + Middleware to log incoming requests and their responses. 
+ """ + request_time = time.time() + response = await call_next(request) + response_time = time.time() + + duration_seconds = response_time - request_time + log.info( + "Request finished: %s, took %.3f seconds", + response.status_code, + duration_seconds, + ) + return response + + @app.get("/") async def root(config: ModelConfigDep): """Root endpoint with basic server information.""" @@ -75,10 +114,14 @@ async def root(config: ModelConfigDep): @app.get("/v1/models", response_model=ModelsResponse) async def list_models(config: ModelConfigDep): """List available models.""" + log.debug("/v1/models request") + model = Model( id=config.model, object="model", created=int(time.time()), owned_by="system" ) - return ModelsResponse(object="list", data=[model]) + response = ModelsResponse(object="list", data=[model]) + log.debug("/v1/models response: %s", response) + return response @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) @@ -86,11 +129,15 @@ async def chat_completions( request: ChatCompletionRequest, config: ModelConfigDep ) -> ChatCompletionResponse: """Create a chat completion.""" + + log.debug("/v1/chat/completions request: %s", request) + # Validate model exists _validate_request_model(config, request) # Generate dummy response response_content = get_response(config) + response_latency_seconds = get_latency_seconds(config, seed=12345) # Calculate token usage prompt_text = " ".join([msg.content for msg in request.messages]) @@ -122,7 +169,8 @@ async def chat_completions( total_tokens=prompt_tokens + completion_tokens, ), ) - + await asyncio.sleep(response_latency_seconds) + log.debug("/v1/chat/completions response: %s", response) return response @@ -132,6 +180,8 @@ async def completions( ) -> CompletionResponse: """Create a text completion.""" + log.debug("/v1/completions request: %s", request) + # Validate model exists _validate_request_model(config, request) @@ -143,6 +193,7 @@ async def completions( # Generate dummy response response_text = get_response(config) + response_latency_seconds = get_latency_seconds(config, seed=12345) # Calculate token usage prompt_tokens = calculate_tokens(prompt_text) @@ -171,10 +222,16 @@ async def completions( total_tokens=prompt_tokens + completion_tokens, ), ) + + await asyncio.sleep(response_latency_seconds) + log.debug("/v1/completions response: %s", response) return response @app.get("/health") async def health_check(): """Health check endpoint.""" - return {"status": "healthy", "timestamp": int(time.time())} + log.debug("/health request") + response = {"status": "healthy", "timestamp": int(time.time())} + log.debug("/health response: %s", response) + return response diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index 0d05756d2..c879e35a5 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -21,12 +21,29 @@ """ import argparse +import logging import sys import uvicorn +from uvicorn.logging import AccessFormatter from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config +# 1. 
Get a logger instance +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) # Set the lowest level to capture all messages + +# Set up formatter and direct it to the console +formatter = logging.Formatter( + "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) # DEBUG and higher will go to the console +console_handler.setFormatter(formatter) + +# Add the console handler for logging +log.addHandler(console_handler) + def main(): parser = argparse.ArgumentParser(description="Run the Mock LLM Server") @@ -64,11 +81,11 @@ def main(): # Import the app after configuration is loaded. This caches the values in the app Dependencies from nemoguardrails.benchmark.mock_llm_server.api import app - print(f"Starting Mock LLM Server on {args.host}:{args.port}") - print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") - print(f"Health check at: http://{args.host}:{args.port}/health") - print(f"Model configuration: {model_config}") - print("Press Ctrl+C to stop the server") + log.info(f"Starting Mock LLM Server on {args.host}:{args.port}") + log.info(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") + log.info(f"Health check at: http://{args.host}:{args.port}/health") + log.info(f"Model configuration: {model_config}") + log.info("Press Ctrl+C to stop the server") try: uvicorn.run( @@ -79,9 +96,9 @@ def main(): log_level=args.log_level, ) except KeyboardInterrupt: - print("\nServer stopped by user") + log.info("\nServer stopped by user") except Exception as e: # pylint: disable=broad-except - print(f"Error starting server: {e}") + log.error(f"Error starting server: {e}") sys.exit(1) From 6beb888a26f49a7c0d54adbc3e28941dace6817d Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 3 Oct 2025 09:57:07 -0500 Subject: [PATCH 11/15] Change content-safety mock to have latency of 0.5s --- .../{content_safety => content_safety_colang1}/config.yml | 0 .../{content_safety => content_safety_colang1}/prompts.yml | 0 .../nvidia-llama-3.1-nemoguard-8b-content-safety.yaml | 6 +++--- 3 files changed, 3 insertions(+), 3 deletions(-) rename nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/{content_safety => content_safety_colang1}/config.yml (100%) rename nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/{content_safety => content_safety_colang1}/prompts.yml (100%) diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/config.yml similarity index 100% rename from nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/config.yml rename to nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/config.yml diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/prompts.yml similarity index 100% rename from nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety/prompts.yml rename to nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/prompts.yml diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml 
b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml index a64b3e141..106545501 100644 --- a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -2,7 +2,7 @@ model: "nvidia/llama-3.1-nemoguard-8b-content-safety" unsafe_probability: 0.03 unsafe_text: "{\"User Safety\": \"unsafe\", \"Response Safety\": \"unsafe\", \"Safety Categories\": \"Violence, Criminal Planning/Confessions\"} " safe_text: "{\"User Safety\": \"safe\", \"Response Safety\": \"safe\"}" -latency_min_seconds: 0.4 -latency_max_seconds: 0.4 -latency_mean_seconds: 0.4 +latency_min_seconds: 0.5 +latency_max_seconds: 0.5 +latency_mean_seconds: 0.5 latency_std_seconds: 0.0 From c056b3b011d4be3d985da5e2ad61fd4248af40dd Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:03:58 -0500 Subject: [PATCH 12/15] Add unit-tests to mock llm --- .../benchmark/mock_llm_server/api.py | 17 +- .../benchmark/mock_llm_server/config.py | 32 +- .../configs/meta-llama-3.3-70b-instruct.env | 8 + .../configs/meta-llama-3.3-70b-instruct.yaml | 8 - ...-llama-3.1-nemoguard-8b-content-safety.env | 8 + ...llama-3.1-nemoguard-8b-content-safety.yaml | 8 - .../mock_llm_server/response_data.py | 8 +- .../benchmark/mock_llm_server/run_server.py | 25 +- tests/benchmark/mock_model_config.yaml | 11 +- tests/benchmark/test_api.py | 415 ++++++++++++ tests/benchmark/test_config.py | 61 ++ tests/benchmark/test_mock_llm_server.py | 614 ------------------ tests/benchmark/test_models.py | 340 ++++++++++ tests/benchmark/test_response_data.py | 498 ++++++++++++++ 14 files changed, 1385 insertions(+), 668 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env delete mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env delete mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml create mode 100644 tests/benchmark/test_api.py create mode 100644 tests/benchmark/test_config.py delete mode 100644 tests/benchmark/test_mock_llm_server.py create mode 100644 tests/benchmark/test_models.py create mode 100644 tests/benchmark/test_response_data.py diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index 4297c9e5c..5ed724ebf 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -21,7 +21,10 @@ from fastapi import Depends, FastAPI, HTTPException, Request, Response -from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config +from nemoguardrails.benchmark.mock_llm_server.config import ( # get_config, + ModelSettings, + get_settings, +) from nemoguardrails.benchmark.mock_llm_server.models import ( ChatCompletionChoice, ChatCompletionRequest, @@ -59,11 +62,11 @@ log.addHandler(console_handler) -ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] +ModelSettingsDep = Annotated[ModelSettings, Depends(get_settings)] def _validate_request_model( - config: ModelConfigDep, + config: ModelSettingsDep, request: Union[CompletionRequest, ChatCompletionRequest], ) -> None: """Check the Completion or Chat Completion 
`model` field is in our supported model list""" @@ -100,7 +103,7 @@ async def log_http_duration(request: Request, call_next): @app.get("/") -async def root(config: ModelConfigDep): +async def root(config: ModelSettingsDep): """Root endpoint with basic server information.""" return { "message": "Mock LLM Server", @@ -112,7 +115,7 @@ async def root(config: ModelConfigDep): @app.get("/v1/models", response_model=ModelsResponse) -async def list_models(config: ModelConfigDep): +async def list_models(config: ModelSettingsDep): """List available models.""" log.debug("/v1/models request") @@ -126,7 +129,7 @@ async def list_models(config: ModelConfigDep): @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions( - request: ChatCompletionRequest, config: ModelConfigDep + request: ChatCompletionRequest, config: ModelSettingsDep ) -> ChatCompletionResponse: """Create a chat completion.""" @@ -176,7 +179,7 @@ async def chat_completions( @app.post("/v1/completions", response_model=CompletionResponse) async def completions( - request: CompletionRequest, config: ModelConfigDep + request: CompletionRequest, config: ModelSettingsDep ) -> CompletionResponse: """Create a text completion.""" diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 44a725d1b..2fcd0166e 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -15,14 +15,23 @@ import os from functools import lru_cache +from pathlib import Path from typing import Any, Optional, Union import yaml from pydantic import BaseModel, Field -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) +CONFIG_FILE_ENV_VAR = "MOCK_LLM_CONFIG_FILE" +config_file_path = os.getenv(CONFIG_FILE_ENV_VAR, "model_settings.yml") +CONFIG_FILE = Path(config_file_path) -class AppModelConfig(BaseModel): + +class ModelSettings(BaseSettings): """Pydantic model to configure the Mock LLM Server.""" # Mandatory fields @@ -49,20 +58,11 @@ class AppModelConfig(BaseModel): default=0.1, description="Standard deviation of response time" ) - -settings: Optional[AppModelConfig] = None - - -def load_config(yaml_file: str) -> None: - """Load the Model configuration from YAML file, store in global `settings` var""" - global settings - with open(yaml_file, "r") as f: - config_data = yaml.safe_load(f) - settings = AppModelConfig(**config_data) + model_config = SettingsConfigDict(env_file=CONFIG_FILE) -def get_config() -> AppModelConfig: - """FastAPI Dependency to inject model configuration""" - if settings is None: - raise RuntimeError("No configuration loaded") +def get_settings() -> ModelSettings: + """Singleton-pattern to get settings once via lru_cache""" + settings = ModelSettings() + print("Returning ModelSettings: %s", settings) return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env new file mode 100644 index 000000000..208387602 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env @@ -0,0 +1,8 @@ +MODEL="meta/llama-3.3-70b-instruct" +UNSAFE_PROBABILITY=0.0 +UNSAFE_TEXT="I can't help with that. Is there anything else I can assist you with?" 
+SAFE_TEXT="I can provide information and help with a wide range of topics, from science and history to entertainment and culture. I can also help with language-related tasks, such as translation and text summarization. However, I can't assist with requests that involve harm or illegal activities." +LATENCY_MIN_SECONDS=4. +LATENCY_MAX_SECONDS=4. +LATENCY_MEAN_SECONDS=4. +LATENCY_STD_SECONDS=0. diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml deleted file mode 100644 index 0ed9d3857..000000000 --- a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.yaml +++ /dev/null @@ -1,8 +0,0 @@ -model: "meta/llama-3.3-70b-instruct" -unsafe_probability: 0.0 -unsafe_text: "I can't help with that. Is there anything else I can assist you with?" -safe_text : "I can provide information and help with a wide range of topics, from science and history to entertainment and culture. I can also help with language-related tasks, such as translation and text summarization. However, I can't assist with requests that involve harm or illegal activities." -latency_min_seconds: 4 -latency_max_seconds: 4 -latency_mean_seconds: 4 -latency_std_seconds: 0. diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env new file mode 100644 index 000000000..786d0685f --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env @@ -0,0 +1,8 @@ +MODEL="nvidia/llama-3.1-nemoguard-8b-content-safety" +UNSAFE_PROBABILITY=0.03 +UNSAFE_TEXT="{\"User Safety\": \"unsafe\", \"Response Safety\": \"unsafe\", \"Safety Categories\": \"Violence, Criminal Planning/Confessions\"} " +SAFE_TEXT="{\"User Safety\": \"safe\", \"Response Safety\": \"safe\"}" +LATENCY_MIN_SECONDS=0.5 +LATENCY_MAX_SECONDS=0.5 +LATENCY_MEAN_SECONDS=0.5 +LATENCY_STD_SECONDS=0.0 diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml deleted file mode 100644 index 106545501..000000000 --- a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml +++ /dev/null @@ -1,8 +0,0 @@ -model: "nvidia/llama-3.1-nemoguard-8b-content-safety" -unsafe_probability: 0.03 -unsafe_text: "{\"User Safety\": \"unsafe\", \"Response Safety\": \"unsafe\", \"Safety Categories\": \"Violence, Criminal Planning/Confessions\"} " -safe_text: "{\"User Safety\": \"safe\", \"Response Safety\": \"safe\"}" -latency_min_seconds: 0.5 -latency_max_seconds: 0.5 -latency_mean_seconds: 0.5 -latency_std_seconds: 0.0 diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index b58bb768a..c5914abcf 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -20,7 +20,7 @@ import numpy as np -from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings def generate_id(prefix: str = "chatcmpl") -> str: @@ -33,7 +33,7 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_response(config: 
AppModelConfig, seed: Optional[int] = None) -> str: +def get_response(config: ModelSettings, seed: Optional[int] = None) -> str: """Get a dummy /completion or /chat/completion response.""" if is_unsafe(config, seed): @@ -41,7 +41,7 @@ def get_response(config: AppModelConfig, seed: Optional[int] = None) -> str: return config.safe_text -def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> float: +def get_latency_seconds(config: ModelSettings, seed: Optional[int] = None) -> float: """Sample latency for this request using the model's config Very inefficient to generate each sample singly rather than in batch """ @@ -62,7 +62,7 @@ def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> f return float(latency_seconds) -def is_unsafe(config: AppModelConfig, seed: Optional[int] = None) -> bool: +def is_unsafe(config: ModelSettings, seed: Optional[int] = None) -> bool: """Check if the model should return a refusal Very inefficient to generate each sample singly rather than in batch """ diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index c879e35a5..cd1bfe789 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -22,12 +22,18 @@ import argparse import logging +import os import sys import uvicorn +import yaml from uvicorn.logging import AccessFormatter -from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config +from nemoguardrails.benchmark.mock_llm_server.config import ( + CONFIG_FILE_ENV_VAR, + ModelSettings, + get_settings, +) # 1. Get a logger instance log = logging.getLogger(__name__) @@ -74,26 +80,29 @@ def main(): args = parser.parse_args() - # Load model configuration - load_config(args.config_file) - model_config = get_config() + config_file = os.environ.get("CONFIG_FILE", args.config_file) + if not config_file: + raise RuntimeError( + "No CONFIG_FILE environment variable set, or --config-file CLI argument" + ) - # Import the app after configuration is loaded. 
This caches the values in the app Dependencies - from nemoguardrails.benchmark.mock_llm_server.api import app + log.info("Using config file: %s", config_file) + os.environ[CONFIG_FILE_ENV_VAR] = config_file log.info(f"Starting Mock LLM Server on {args.host}:{args.port}") log.info(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") log.info(f"Health check at: http://{args.host}:{args.port}/health") - log.info(f"Model configuration: {model_config}") + log.info(f"Serving model with config {config_file}") log.info("Press Ctrl+C to stop the server") try: uvicorn.run( - app=app, + "api:app", host=args.host, port=args.port, reload=args.reload, log_level=args.log_level, + env_file=config_file, ) except KeyboardInterrupt: log.info("\nServer stopped by user") diff --git a/tests/benchmark/mock_model_config.yaml b/tests/benchmark/mock_model_config.yaml index 384a988e5..9d3aa1bb0 100644 --- a/tests/benchmark/mock_model_config.yaml +++ b/tests/benchmark/mock_model_config.yaml @@ -1,3 +1,8 @@ -model: "mock_model" -refusal_probability: 0.01 -refusal_text: "I'm sorry, I can't help you with that request" +model: "nvidia/mock-content-safety" +unsafe_probability: 0.12 +unsafe_text: "{\"is_safe\": False}" +safe_text: "{\"is_safe\": True}" +latency_min_seconds: 0.01 +latency_max_seconds: 0.01 +latency_mean_seconds: 0.01 +latency_std_seconds: 0.0 diff --git a/tests/benchmark/test_api.py b/tests/benchmark/test_api.py new file mode 100644 index 000000000..5a4f75335 --- /dev/null +++ b/tests/benchmark/test_api.py @@ -0,0 +1,415 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
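A minimal sketch of exercising the env-file-driven configuration outside uvicorn, assuming the module layout introduced in these patches. The config path below is only an example, and MOCK_LLM_CONFIG_FILE has to be set before the config module is imported, because the env-file path is resolved at import time.

import os

# Point pydantic-settings at one of the .env configs added above.
# Illustrative path; any of the configs/*.env files works the same way.
os.environ["MOCK_LLM_CONFIG_FILE"] = (
    "nemoguardrails/benchmark/mock_llm_server/configs/"
    "nvidia-llama-3.1-nemoguard-8b-content-safety.env"
)

from nemoguardrails.benchmark.mock_llm_server.config import get_settings
from nemoguardrails.benchmark.mock_llm_server.response_data import (
    get_latency_seconds,
    get_response,
)

settings = get_settings()
print(settings.model)                      # "nvidia/llama-3.1-nemoguard-8b-content-safety"
print(get_response(settings, seed=12345))  # safe or unsafe JSON string from the env file
print(get_latency_seconds(settings))       # 0.5, since min == max == mean == 0.5 and std == 0

This is the same ModelSettings object the FastAPI endpoints receive through the get_settings dependency, which is what the /v1 endpoints use to build their responses.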
+ +import os +import tempfile +import time + +import pytest +import yaml +from fastapi.testclient import TestClient + +from nemoguardrails.benchmark.mock_llm_server.api import app + +# from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings, get_settings + + +def get_test_settings(): + return ModelSettings( + model="gpt-3.5-turbo", + unsafe_probability=0.1, + unsafe_text="I cannot help with that request", + safe_text="This is a safe response", + latency_min_seconds=0, + latency_max_seconds=0, + latency_mean_seconds=0, + latency_std_seconds=0, + ) + + +@pytest.fixture +def client(): + """Create a test client.""" + app.dependency_overrides[get_settings] = get_test_settings + return TestClient(app) + + +def test_get_root_endpoint_server_data(client): + """Test GET / endpoint returns correct server details (not including model info)""" + + model_name = get_test_settings().model + + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Mock LLM Server" + assert data["version"] == "0.0.1" + assert ( + data["description"] + == f"OpenAI-compatible mock LLM server for model: {model_name}" + ) + assert data["endpoints"] == [ + "/v1/models", + "/v1/chat/completions", + "/v1/completions", + ] + + +def test_get_root_endpoint_model_data(client): + """Test GET / endpoint returns correct model details""" + + response = client.get("/") + data = response.json() + model_data = data["model_configuration"] + + expected_model_data = get_test_settings().model_dump() + assert model_data == expected_model_data + + +def test_get_health_endpoint(client): + """Test GET /health endpoint.""" + pre_request_time = int(time.time()) + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], int) + assert data["timestamp"] >= pre_request_time + + +def test_get_models_endpoint(client): + """Test GET /v1/models endpoint.""" + pre_request_time = int(time.time()) + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + + expected_model = get_test_settings().model_dump() + model = data["data"][0] + assert model["id"] == expected_model["model"] + assert model["object"] == "model" + assert isinstance(model["created"], int) + assert model["created"] >= pre_request_time + assert model["owned_by"] == "system" + + +class TestChatCompletionsEndpoint: + """Test the /v1/chat/completions endpoint.""" + + def test_chat_completions_basic(self, client): + """Test basic chat completion request.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == "gpt-3.5-turbo" + assert "id" in data + assert data["id"].startswith("chatcmpl-") + + def test_chat_completions_response_structure(self, client): + """Test the structure of chat completion response.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test message"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + # Check response structure + assert "choices" in data + assert 
len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert "message" in choice + assert choice["message"]["role"] == "assistant" + assert "content" in choice["message"] + assert choice["finish_reason"] == "stop" + + def test_chat_completions_usage(self, client): + """Test that usage information is included.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_chat_completions_multiple_choices(self, client): + """Test chat completion with n > 1.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "n": 3, + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + assert len(data["choices"]) == 3 + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + + def test_chat_completions_multiple_messages(self, client): + """Test chat completion with multiple messages.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_completions_invalid_model(self, client): + """Test chat completion with invalid model name.""" + payload = { + "model": "invalid-model", + "messages": [{"role": "user", "content": "Hello"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 400 + assert "not found" in response.json()["detail"].lower() + + def test_chat_completions_missing_messages(self, client): + """Test chat completion without messages field.""" + payload = { + "model": "gpt-3.5-turbo", + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 422 # Validation error + + def test_chat_completions_empty_messages(self, client): + """Test chat completion with empty messages list.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [], + } + response = client.post("/v1/chat/completions", json=payload) + # Should either be 422 or 200 depending on validation + # Let's check it doesn't crash + assert response.status_code in [200, 422] + + def test_chat_completions_latency(self, client): + """Test that chat completions have some latency.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + start = time.time() + response = client.post("/v1/chat/completions", json=payload) + duration = time.time() - start + + assert response.status_code == 200 + # Should have some latency (at least minimal) + assert duration >= 0.0 + + +class TestCompletionsEndpoint: + """Test the /v1/completions endpoint.""" + + def test_completions_basic(self, client): + """Test basic completion request.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Once upon a time", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + data = response.json() + 
assert data["object"] == "text_completion" + assert data["model"] == "gpt-3.5-turbo" + assert data["id"].startswith("cmpl-") + + def test_completions_response_structure(self, client): + """Test the structure of completion response.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert "text" in choice + assert isinstance(choice["text"], str) + assert choice["finish_reason"] == "stop" + assert choice["logprobs"] is None + + def test_completions_string_prompt(self, client): + """Test completion with string prompt.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Single string prompt", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + + def test_completions_list_prompt(self, client): + """Test completion with list of prompts.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": ["Prompt 1", "Prompt 2", "Prompt 3"], + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + data = response.json() + # Should still return a response (joined prompts) + assert "choices" in data + + def test_completions_multiple_choices(self, client): + """Test completion with n > 1.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test", + "n": 5, + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert len(data["choices"]) == 5 + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + + def test_completions_usage(self, client): + """Test that usage information is included.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_completions_invalid_model(self, client): + """Test completion with invalid model name.""" + payload = { + "model": "wrong-model", + "prompt": "Test", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 400 + + def test_completions_missing_prompt(self, client): + """Test completion without prompt field.""" + payload = { + "model": "gpt-3.5-turbo", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 422 # Validation error + + +class TestMiddleware: + """Test the HTTP logging middleware.""" + + def test_middleware_logs_request(self, client): + """Test that middleware processes requests.""" + # The middleware should not affect response + response = client.get("/health") + assert response.status_code == 200 + + def test_middleware_with_post(self, client): + """Test middleware with POST request.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + + +class TestValidateRequestModel: + """Test the _validate_request_model function.""" + + def test_validate_request_model_valid(self, client): + """Test validation with correct model.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + 
} + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + + def test_validate_request_model_invalid(self, client): + """Test validation with incorrect model.""" + payload = { + "model": "nonexistent-model", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 400 + assert "not found" in response.json()["detail"].lower() + assert "gpt-3.5-turbo" in response.json()["detail"] + + +class TestResponseContent: + """Test that responses contain expected content.""" + + def test_chat_response_content_type(self, client): + """Test that response contains either safe or unsafe text.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + content = data["choices"][0]["message"]["content"] + # Should be one of the configured responses + assert content in ["This is a safe response", "I cannot help with that request"] + + def test_completion_response_content_type(self, client): + """Test that completion response contains expected text.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + text = data["choices"][0]["text"] + # Should be one of the configured responses + assert text in ["This is a safe response", "I cannot help with that request"] diff --git a/tests/benchmark/test_config.py b/tests/benchmark/test_config.py new file mode 100644 index 000000000..d97d7df6d --- /dev/null +++ b/tests/benchmark/test_config.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
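The client fixture in test_api.py installs the get_settings override on the shared app object and leaves it in place. A sketch of the same pattern with explicit teardown, assuming pytest and the modules from these patches; fake_settings and client_with_override are illustrative names, not part of the patch.

import pytest
from fastapi.testclient import TestClient

from nemoguardrails.benchmark.mock_llm_server.api import app
from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings, get_settings


@pytest.fixture
def client_with_override():
    def fake_settings() -> ModelSettings:
        # Deterministic settings: never unsafe, no added latency.
        return ModelSettings(
            model="gpt-3.5-turbo",
            unsafe_probability=0.0,
            unsafe_text="I cannot help with that request",
            safe_text="This is a safe response",
            latency_min_seconds=0,
            latency_max_seconds=0,
            latency_mean_seconds=0,
            latency_std_seconds=0,
        )

    app.dependency_overrides[get_settings] = fake_settings
    try:
        yield TestClient(app)
    finally:
        # Remove the override so it does not leak into other tests that share `app`.
        app.dependency_overrides.pop(get_settings, None)

Because the override is removed in the finally block, any other tests that import the same app instance see the real get_settings again.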
+ +import os +import tempfile + +import pytest +import yaml + +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings + + +class TestAppModelConfig: + """Test the AppModelConfig Pydantic model.""" + + def test_app_model_config_with_defaults(self): + """Test creating AppModelConfig with default values.""" + config = ModelSettings( + model="test-model", + unsafe_text="Unsafe", + safe_text="Safe", + ) + # Check defaults + assert config.unsafe_probability == 0.1 + assert config.latency_min_seconds == 0.1 + assert config.latency_max_seconds == 5 + assert config.latency_mean_seconds == 0.5 + assert config.latency_std_seconds == 0.1 + + def test_app_model_config_missing_required_field(self): + """Test that missing required fields raise validation error.""" + with pytest.raises(Exception): # Pydantic ValidationError + ModelSettings( # type: ignore (Test is meant to check missing mandatory field) + model="test-model", + unsafe_text="Unsafe", + # Missing safe_text + ) + + def test_app_model_config_model_serialization(self): + """Test that AppModelConfig can be serialized to dict.""" + config = ModelSettings( + model="test-model", + unsafe_text="Unsafe", + safe_text="Safe", + ) + config_dict = config.model_dump() + assert isinstance(config_dict, dict) + assert config_dict["model"] == "test-model" + assert config_dict["safe_text"] == "Safe" diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py deleted file mode 100644 index 552eb57e1..000000000 --- a/tests/benchmark/test_mock_llm_server.py +++ /dev/null @@ -1,614 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Unit tests for the Mock LLM FastAPI Server. - -This module contains comprehensive tests for all endpoints and edge cases -of the OpenAI-compatible mock LLM server. 
-""" - -import json -import os -import time -from typing import Any, Dict, List -from unittest.mock import patch - -import pytest -from fastapi.testclient import TestClient - -from nemoguardrails.benchmark.mock_llm_server.api import app -from nemoguardrails.benchmark.mock_llm_server.config import ( - AppModelConfig, - get_config, - load_config, -) -from nemoguardrails.benchmark.mock_llm_server.response_data import ( - DUMMY_CHAT_RESPONSES, - DUMMY_COMPLETION_RESPONSES, - DUMMY_MODELS, - calculate_tokens, - generate_id, - get_dummy_chat_response, - get_dummy_completion_response, -) - -RANDOM_SEED = 12345 -REFUSAL_TEXT = "I'm sorry Dave, I'm afraid I can't do that" -NO_REFUSAL_CONFIG = AppModelConfig( - model="mock-model", - refusal_text=REFUSAL_TEXT, - refusal_probability=0.0, -) - -ALL_REFUSAL_CONFIG = AppModelConfig( - model="mock-model", - refusal_text=REFUSAL_TEXT, - refusal_probability=1.0, -) - - -class TestMockLLMServer: - """Test class for the Mock LLM Server.""" - - @pytest.fixture - def client(self): - """Create a test client for the FastAPI app.""" - return TestClient(app) - - @pytest.fixture - def valid_chat_request(self): - """Sample valid chat completion request.""" - return { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "max_tokens": 50, - "temperature": 0.7, - } - - @pytest.fixture - def valid_completion_request(self): - """Sample valid text completion request.""" - return { - "model": "text-davinci-003", - "prompt": "The capital of France is", - "max_tokens": 10, - "temperature": 0.8, - } - - # Root endpoint tests - def test_root_endpoint(self, client): - """Test the root endpoint returns correct information.""" - - mock_config = AppModelConfig( - model="mock_config_model_name", - refusal_text="I'm afraid I can't do that, Dave", - ) - - def override_get_config(): - return mock_config - - app.dependency_overrides[get_config] = override_get_config - - response = client.get("/") - assert response.status_code == 200 - - data = response.json() - assert data["message"] == "Mock LLM Server" - assert data["version"] == "0.0.1" - assert "description" in data - assert "/v1/models" in data["endpoints"] - assert "/v1/chat/completions" in data["endpoints"] - assert "/v1/completions" in data["endpoints"] - assert data["model_configuration"]["model"] == mock_config.model - assert data["model_configuration"]["refusal_text"] == mock_config.refusal_text - - # Health check tests - def test_health_check(self, client): - """Test the health check endpoint.""" - response = client.get("/health") - assert response.status_code == 200 - - data = response.json() - assert data["status"] == "healthy" - assert "timestamp" in data - assert isinstance(data["timestamp"], int) - - # Models endpoint tests - def test_list_models(self, client): - """Test the models listing endpoint.""" - response = client.get("/v1/models") - assert response.status_code == 200 - - data = response.json() - assert data["object"] == "list" - assert isinstance(data["data"], list) - assert len(data["data"]) == len(DUMMY_MODELS) - - # Check first model structure - model = data["data"][0] - assert "id" in model - assert "object" in model - assert "created" in model - assert "owned_by" in model - assert model["object"] == "model" - - def test_models_contain_expected_models(self, client): - """Test that all expected models are returned.""" - response = client.get("/v1/models") - data = response.json() - - model_ids = [model["id"] for model in data["data"]] - expected_ids = [model["id"] 
for model in DUMMY_MODELS] - - assert set(model_ids) == set(expected_ids) - - # Chat completions tests - def test_chat_completions_success(self, client, valid_chat_request): - """Test successful chat completion request.""" - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 200 - - data = response.json() - assert data["object"] == "chat.completion" - assert data["model"] == valid_chat_request["model"] - assert "id" in data - assert "created" in data - assert isinstance(data["created"], int) - - # Check choices - assert "choices" in data - assert len(data["choices"]) == 1 - choice = data["choices"][0] - assert choice["index"] == 0 - assert choice["finish_reason"] == "stop" - assert "message" in choice - assert choice["message"]["role"] == "assistant" - assert isinstance(choice["message"]["content"], str) - - # Check usage - assert "usage" in data - usage = data["usage"] - assert "prompt_tokens" in usage - assert "completion_tokens" in usage - assert "total_tokens" in usage - assert ( - usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - ) - - def test_chat_completions_multiple_choices(self, client, valid_chat_request): - """Test chat completion with multiple choices.""" - valid_chat_request["n"] = 3 - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 200 - - data = response.json() - assert len(data["choices"]) == 3 - - for i, choice in enumerate(data["choices"]): - assert choice["index"] == i - assert choice["finish_reason"] == "stop" - - def test_chat_completions_invalid_model(self, client, valid_chat_request): - """Test chat completion with invalid model.""" - valid_chat_request["model"] = "invalid-model" - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 400 - - data = response.json() - assert "detail" in data - assert "invalid-model" in data["detail"] - assert "not found" in data["detail"] - - def test_chat_completions_empty_messages(self, client): - """Test chat completion with empty messages.""" - request_data = { - "model": "gpt-3.5-turbo", - "messages": [], - } - response = client.post("/v1/chat/completions", json=request_data) - # Note: The server currently accepts empty messages and processes them - # This may be acceptable behavior for a mock server - assert response.status_code in [ - 200, - 422, - ] # Allow both success and validation error - - def test_chat_completions_invalid_message_format(self, client): - """Test chat completion with invalid message format.""" - request_data = { - "model": "gpt-3.5-turbo", - "messages": [{"invalid": "format"}], - } - response = client.post("/v1/chat/completions", json=request_data) - assert response.status_code == 422 # Validation error - - def test_chat_completions_parameter_validation(self, client, valid_chat_request): - """Test parameter validation for chat completions.""" - # Test max_tokens validation - valid_chat_request["max_tokens"] = 0 - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 422 - - # Test temperature validation - valid_chat_request["max_tokens"] = 50 - valid_chat_request["temperature"] = 3.0 # Out of range - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 422 - - # Test n validation - valid_chat_request["temperature"] = 0.7 - valid_chat_request["n"] = 200 # Out of range - response = client.post("/v1/chat/completions", 
json=valid_chat_request) - assert response.status_code == 422 - - def test_chat_completions_optional_parameters(self, client): - """Test chat completion with various optional parameters.""" - request_data = { - "model": "gpt-4", - "messages": [{"role": "user", "content": "Test message"}], - "max_tokens": 100, - "temperature": 0.5, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.2, - "stop": ["\\n"], - "user": "test-user", - } - response = client.post("/v1/chat/completions", json=request_data) - assert response.status_code == 200 - - # Text completions tests - def test_completions_success(self, client, valid_completion_request): - """Test successful text completion request.""" - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 200 - - data = response.json() - assert data["object"] == "text_completion" - assert data["model"] == valid_completion_request["model"] - assert "id" in data - assert "created" in data - - # Check choices - assert "choices" in data - assert len(data["choices"]) == 1 - choice = data["choices"][0] - assert choice["index"] == 0 - assert choice["finish_reason"] == "stop" - assert "text" in choice - assert isinstance(choice["text"], str) - - # Check usage - assert "usage" in data - usage = data["usage"] - assert "prompt_tokens" in usage - assert "completion_tokens" in usage - assert "total_tokens" in usage - - def test_completions_list_prompt(self, client): - """Test text completion with list prompt.""" - request_data = { - "model": "text-davinci-003", - "prompt": ["First prompt", "Second prompt"], - "max_tokens": 10, - } - response = client.post("/v1/completions", json=request_data) - assert response.status_code == 200 - - data = response.json() - assert data["object"] == "text_completion" - - def test_completions_invalid_model(self, client, valid_completion_request): - """Test text completion with invalid model.""" - valid_completion_request["model"] = "non-existent-model" - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 400 - - def test_completions_multiple_choices(self, client, valid_completion_request): - """Test text completion with multiple choices.""" - valid_completion_request["n"] = 2 - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 200 - - data = response.json() - assert len(data["choices"]) == 2 - - def test_completions_parameter_validation(self, client, valid_completion_request): - """Test parameter validation for text completions.""" - # Test max_tokens validation - valid_completion_request["max_tokens"] = -1 - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 422 - - # Test temperature validation - valid_completion_request["max_tokens"] = 10 - valid_completion_request["temperature"] = -1.0 - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 422 - - def test_completions_optional_parameters(self, client): - """Test text completion with various optional parameters.""" - request_data = { - "model": "gpt-3.5-turbo", - "prompt": "Test prompt", - "max_tokens": 50, - "temperature": 0.8, - "top_p": 0.95, - "n": 1, - "logprobs": 1, - "echo": True, - "stop": ["\\n", "."], - "presence_penalty": -0.5, - "frequency_penalty": 0.3, - "best_of": 2, - "user": "test-user-2", - } - response = client.post("/v1/completions", json=request_data) - assert response.status_code == 
200 - - # Helper function tests - def test_generate_id_default(self): - """Test ID generation with default prefix.""" - id1 = generate_id() - id2 = generate_id() - - assert id1.startswith("chatcmpl-") - assert id2.startswith("chatcmpl-") - assert id1 != id2 # Should be unique - assert len(id1) == len("chatcmpl-") + 8 # prefix + 8 hex chars - - def test_generate_id_custom_prefix(self): - """Test ID generation with custom prefix.""" - custom_id = generate_id("cmpl") - assert custom_id.startswith("cmpl-") - assert len(custom_id) == len("cmpl-") + 8 - - def test_calculate_tokens(self): - """Test token calculation function.""" - # Test basic calculation - assert calculate_tokens("") == 1 # Minimum 1 token - assert calculate_tokens("a") == 1 - assert calculate_tokens("abcd") == 1 - assert calculate_tokens("abcde") == 1 # 5 chars = 1 token (rounded down) - assert calculate_tokens("abcdefgh") == 2 # 8 chars = 2 tokens - - # Test longer text - long_text = "This is a longer text with multiple words and characters." - expected_tokens = max(1, len(long_text) // 4) - assert calculate_tokens(long_text) == expected_tokens - - def test_get_dummy_completion_response_refusal(self): - """Test response generation with P = 1.0 of refusal""" - response = get_dummy_completion_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) - assert response == ALL_REFUSAL_CONFIG.refusal_text - - def test_get_dummy_chat_response_refusal(self): - """Test response generation with P = 1.0 of refusal""" - response = get_dummy_chat_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) - assert response == ALL_REFUSAL_CONFIG.refusal_text - - def test_get_dummy_completion_response_no_refusal(self): - """Test /completion response generation with P = 0.0 of refusal""" - response = get_dummy_completion_response(NO_REFUSAL_CONFIG) - assert response in set(DUMMY_COMPLETION_RESPONSES) - - def test_get_dummy_chat_response_no_refusal(self): - """Test /chat/completion response with P = 0.0 of refusal.""" - response = get_dummy_chat_response(NO_REFUSAL_CONFIG) - assert response in set(DUMMY_CHAT_RESPONSES) - - # Edge cases and error handling - def test_missing_required_fields_chat(self, client): - """Test chat completion with missing required fields.""" - # Missing model - response = client.post("/v1/chat/completions", json={"messages": []}) - assert response.status_code == 422 - - # Missing messages - response = client.post("/v1/chat/completions", json={"model": "gpt-3.5-turbo"}) - assert response.status_code == 422 - - def test_missing_required_fields_completion(self, client): - """Test text completion with missing required fields.""" - # Missing model - response = client.post("/v1/completions", json={"prompt": "test"}) - assert response.status_code == 422 - - # Missing prompt - response = client.post("/v1/completions", json={"model": "gpt-3.5-turbo"}) - assert response.status_code == 422 - - def test_invalid_json(self, client): - """Test endpoints with invalid JSON.""" - response = client.post( - "/v1/chat/completions", - content="invalid json", - headers={"Content-Type": "application/json"}, - ) - assert response.status_code == 422 - - def test_empty_request_body(self, client): - """Test endpoints with empty request body.""" - response = client.post("/v1/chat/completions", json={}) - assert response.status_code == 422 - - response = client.post("/v1/completions", json={}) - assert response.status_code == 422 - - # Content validation tests - def test_chat_message_content_types(self, client): - """Test chat completion with different message content types.""" - # 
Test with multiple messages - request_data = { - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there!"}, - {"role": "user", "content": "How are you?"}, - ], - } - response = client.post("/v1/chat/completions", json=request_data) - assert response.status_code == 200 - - def test_response_structure_consistency(self, client, valid_chat_request): - """Test that response structure is consistent across calls.""" - response1 = client.post("/v1/chat/completions", json=valid_chat_request) - response2 = client.post("/v1/chat/completions", json=valid_chat_request) - - assert response1.status_code == 200 - assert response2.status_code == 200 - - data1 = response1.json() - data2 = response2.json() - - # Structure should be the same - assert set(data1.keys()) == set(data2.keys()) - assert data1["object"] == data2["object"] - assert data1["model"] == data2["model"] - - # IDs should be different - assert data1["id"] != data2["id"] - - def test_concurrent_requests(self, client, valid_chat_request): - """Test handling of concurrent requests.""" - import threading - import time - - results = [] - - def make_request(): - response = client.post("/v1/chat/completions", json=valid_chat_request) - results.append(response.status_code) - - # Create multiple threads - threads = [] - for _ in range(5): - thread = threading.Thread(target=make_request) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # All requests should be successful - assert all(status == 200 for status in results) - assert len(results) == 5 - - # Performance and load tests - def test_response_time_reasonable(self, client, valid_chat_request): - """Test that response times are reasonable.""" - start_time = time.time() - response = client.post("/v1/chat/completions", json=valid_chat_request) - end_time = time.time() - - assert response.status_code == 200 - assert (end_time - start_time) < 1.0 # Should respond within 1 second - - def test_large_prompt_handling(self, client): - """Test handling of large prompts.""" - large_prompt = "A" * 10000 # 10K characters - request_data = { - "model": "text-davinci-003", - "prompt": large_prompt, - "max_tokens": 10, - } - response = client.post("/v1/completions", json=request_data) - assert response.status_code == 200 - - data = response.json() - # Token calculation should handle large text - assert data["usage"]["prompt_tokens"] > 1000 - - # Mock and patch tests - @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_chat_response") - def test_chat_completion_response_mocking( - self, mock_response, client, valid_chat_request - ): - """Test mocking of chat response generation.""" - expected_response = "Mocked response for testing chat completions" - mock_response.return_value = expected_response - - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 200 - - data = response.json() - assert data["choices"][0]["message"]["content"] == expected_response - mock_response.assert_called_once() - - @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_completion_response") - def test_completion_response_mocking( - self, mock_response, client, valid_completion_request - ): - """Test mocking of chat response generation.""" - expected_response = "Mocked response to check completion responses" - mock_response.return_value = expected_response 
- - response = client.post("/v1/completions", json=valid_completion_request) - assert response.status_code == 200 - - data = response.json() - assert data["choices"][0]["text"] == expected_response - mock_response.assert_called_once() - - @patch("time.time") - def test_timestamp_consistency(self, mock_time, client, valid_chat_request): - """Test that timestamps are generated correctly.""" - mock_time.return_value = 1234567890 - - response = client.post("/v1/chat/completions", json=valid_chat_request) - assert response.status_code == 200 - - data = response.json() - assert data["created"] == 1234567890 - - # Documentation and OpenAPI tests - def test_openapi_docs_available(self, client): - """Test that OpenAPI documentation is available.""" - response = client.get("/docs") - assert response.status_code == 200 - - response = client.get("/openapi.json") - assert response.status_code == 200 - - openapi_data = response.json() - assert "openapi" in openapi_data - assert "paths" in openapi_data - assert "/v1/models" in openapi_data["paths"] - assert "/v1/chat/completions" in openapi_data["paths"] - assert "/v1/completions" in openapi_data["paths"] - - def test_read_root_with_mock_config(self): - """Tests load_config method correctly populates the `settings` global variable""" - yaml_file = os.path.join(os.path.dirname(__file__), "mock_model_config.yaml") - - # Make sure settings is empty to start with, load and check it's populated - load_config(yaml_file) - config = get_config() - assert config is not None - - # Now check the contents against `mock_model_config.yaml` - assert isinstance(config, AppModelConfig) - assert config.model == "mock_model" - assert config.refusal_probability == 0.01 - assert config.refusal_text == "I'm sorry, I can't help you with that request" - - @patch("nemoguardrails.benchmark.mock_llm_server.config.settings", None) - def test_get_config_raises_exception(self): - """Check if we call `get_config()` without settings set we raise an exception""" - with pytest.raises(RuntimeError, match="No configuration loaded"): - get_config() diff --git a/tests/benchmark/test_models.py b/tests/benchmark/test_models.py new file mode 100644 index 000000000..fd6d4e979 --- /dev/null +++ b/tests/benchmark/test_models.py @@ -0,0 +1,340 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
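For context while reading the new tests/benchmark/test_models.py added below: the Pydantic models it exercises are defined in nemoguardrails/benchmark/mock_llm_server/models.py, which this patch does not show. The following is a minimal sketch rather than the actual module, inferred only from what the tests assert (temperature defaults to 1.0 and rejects values above 2.0, n defaults to 1 and rejects values above 128, max_tokens defaults to 16, logprobs rejects values above 5); the exact Field bounds and any remaining fields are assumptions.

    from typing import List, Optional, Union

    from pydantic import BaseModel, Field


    class Message(BaseModel):
        role: str
        content: str


    class ChatCompletionRequest(BaseModel):
        model: str
        messages: List[Message]
        temperature: float = Field(default=1.0, ge=0.0, le=2.0)  # tests reject 3.0
        n: int = Field(default=1, ge=1, le=128)  # tests reject 200
        stream: bool = False
        stop: Optional[Union[str, List[str]]] = None  # string or list accepted


    class CompletionRequest(BaseModel):
        model: str
        prompt: Union[str, List[str]]
        max_tokens: int = Field(default=16, ge=1)
        temperature: float = Field(default=1.0, ge=0.0, le=2.0)
        logprobs: Optional[int] = Field(default=None, ge=0, le=5)  # tests reject 10
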
+ +import pytest +from pydantic import ValidationError + +from nemoguardrails.benchmark.mock_llm_server.models import ( + ChatCompletionChoice, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionChoice, + CompletionRequest, + CompletionResponse, + Message, + Model, + ModelsResponse, + Usage, +) + + +class TestMessage: + """Test the Message model.""" + + def test_message_creation(self): + """Test creating a Message.""" + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + + def test_message_missing_fields(self): + """Test that missing required fields raise validation error.""" + with pytest.raises(ValidationError): + Message(role="user") # Missing content + + with pytest.raises(ValidationError): + Message(content="Hello") # Missing role + + +class TestChatCompletionRequest: + """Test the ChatCompletionRequest model.""" + + def test_chat_completion_request_minimal(self): + """Test creating ChatCompletionRequest with minimal fields.""" + req = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hello")], + ) + assert req.model == "gpt-3.5-turbo" + assert len(req.messages) == 1 + assert req.temperature == 1.0 # Default + assert req.n == 1 # Default + assert req.stream is False # Default + + def test_chat_completion_request_validation(self): + """Test validation of ChatCompletionRequest fields.""" + # Test temperature bounds + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + temperature=3.0, # > 2.0 + ) + + # Test n bounds + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + n=200, # > 128 + ) + + def test_chat_completion_request_stop_variants(self): + """Test stop parameter can be string or list.""" + req1 = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + stop="END", + ) + assert req1.stop == "END" + + req2 = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + stop=["END", "STOP"], + ) + assert req2.stop == ["END", "STOP"] + + +class TestCompletionRequest: + """Test the CompletionRequest model.""" + + def test_completion_request_minimal(self): + """Test creating CompletionRequest with minimal fields.""" + req = CompletionRequest( + model="text-davinci-003", + prompt="Hello", + ) + assert req.model == "text-davinci-003" + assert req.prompt == "Hello" + assert req.max_tokens == 16 # Default + assert req.temperature == 1.0 # Default + + def test_completion_request_prompt_string(self): + """Test CompletionRequest with string prompt.""" + req = CompletionRequest(model="test-model", prompt="Test prompt") + assert req.prompt == "Test prompt" + assert isinstance(req.prompt, str) + + def test_completion_request_prompt_list(self): + """Test CompletionRequest with list of prompts.""" + req = CompletionRequest(model="test-model", prompt=["Prompt 1", "Prompt 2"]) + assert req.prompt == ["Prompt 1", "Prompt 2"] + assert isinstance(req.prompt, list) + + def test_completion_request_all_fields(self): + """Test creating CompletionRequest with all fields.""" + req = CompletionRequest( + model="text-davinci-003", + prompt=["Prompt 1", "Prompt 2"], + max_tokens=50, + temperature=0.8, + top_p=0.95, + n=3, + stream=True, + logprobs=5, + echo=True, + stop=["STOP"], + presence_penalty=0.6, + frequency_penalty=0.4, + best_of=2, + 
logit_bias={"token1": 1.0}, + user="user456", + ) + assert req.model == "text-davinci-003" + assert req.prompt == ["Prompt 1", "Prompt 2"] + assert req.max_tokens == 50 + assert req.logprobs == 5 + assert req.echo is True + assert req.best_of == 2 + + def test_completion_request_validation(self): + """Test validation of CompletionRequest fields.""" + # Test logprobs bounds + with pytest.raises(ValidationError): + CompletionRequest( + model="test-model", + prompt="Hi", + logprobs=10, # > 5 + ) + + +class TestUsage: + """Test the Usage model.""" + + def test_usage_creation(self): + """Test creating a Usage model.""" + usage = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) + assert usage.prompt_tokens == 10 + assert usage.completion_tokens == 20 + assert usage.total_tokens == 30 + + def test_usage_missing_fields(self): + """Test that missing fields raise validation error.""" + with pytest.raises(ValidationError): + Usage(prompt_tokens=10, completion_tokens=20) # Missing total_tokens + + +class TestChatCompletionChoice: + """Test the ChatCompletionChoice model.""" + + def test_chat_completion_choice_creation(self): + """Test creating a ChatCompletionChoice.""" + choice = ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Response"), + finish_reason="stop", + ) + assert choice.index == 0 + assert choice.message.role == "assistant" + assert choice.message.content == "Response" + assert choice.finish_reason == "stop" + + +class TestCompletionChoice: + """Test the CompletionChoice model.""" + + def test_completion_choice_creation(self): + """Test creating a CompletionChoice.""" + choice = CompletionChoice( + text="Generated text", index=0, logprobs=None, finish_reason="length" + ) + assert choice.text == "Generated text" + assert choice.index == 0 + assert choice.logprobs is None + assert choice.finish_reason == "length" + + def test_completion_choice_with_logprobs(self): + """Test CompletionChoice with logprobs.""" + choice = CompletionChoice( + text="Text", + index=0, + logprobs={"tokens": ["test"], "token_logprobs": [-0.5]}, + finish_reason="stop", + ) + assert choice.logprobs is not None + assert "tokens" in choice.logprobs + + +class TestChatCompletionResponse: + """Test the ChatCompletionResponse model.""" + + def test_chat_completion_response_creation(self): + """Test creating a ChatCompletionResponse.""" + response = ChatCompletionResponse( + id="chatcmpl-123", + object="chat.completion", + created=1234567890, + model="gpt-3.5-turbo", + choices=[ + ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Hello!"), + finish_reason="stop", + ) + ], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + assert response.id == "chatcmpl-123" + assert response.object == "chat.completion" + assert response.created == 1234567890 + assert response.model == "gpt-3.5-turbo" + assert len(response.choices) == 1 + assert response.usage.total_tokens == 15 + + def test_chat_completion_response_multiple_choices(self): + """Test ChatCompletionResponse with multiple choices.""" + response = ChatCompletionResponse( + id="chatcmpl-456", + object="chat.completion", + created=1234567890, + model="gpt-4", + choices=[ + ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Response 1"), + finish_reason="stop", + ), + ChatCompletionChoice( + index=1, + message=Message(role="assistant", content="Response 2"), + finish_reason="stop", + ), + ], + usage=Usage(prompt_tokens=10, completion_tokens=10, 
total_tokens=20), + ) + assert len(response.choices) == 2 + assert response.choices[0].message.content == "Response 1" + assert response.choices[1].message.content == "Response 2" + + +class TestCompletionResponse: + """Test the CompletionResponse model.""" + + def test_completion_response_creation(self): + """Test creating a CompletionResponse.""" + response = CompletionResponse( + id="cmpl-789", + object="text_completion", + created=1234567890, + model="text-davinci-003", + choices=[ + CompletionChoice( + text="Completed text", index=0, logprobs=None, finish_reason="stop" + ) + ], + usage=Usage(prompt_tokens=15, completion_tokens=10, total_tokens=25), + ) + assert response.id == "cmpl-789" + assert response.object == "text_completion" + assert response.created == 1234567890 + assert response.model == "text-davinci-003" + assert len(response.choices) == 1 + assert response.usage.total_tokens == 25 + + +class TestModel: + """Test the Model model.""" + + def test_model_creation(self): + """Test creating a Model.""" + model = Model( + id="gpt-3.5-turbo", object="model", created=1677610602, owned_by="openai" + ) + assert model.id == "gpt-3.5-turbo" + assert model.object == "model" + assert model.created == 1677610602 + assert model.owned_by == "openai" + + +class TestModelsResponse: + """Test the ModelsResponse model.""" + + def test_models_response_creation(self): + """Test creating a ModelsResponse.""" + response = ModelsResponse( + object="list", + data=[ + Model( + id="gpt-3.5-turbo", + object="model", + created=1677610602, + owned_by="openai", + ), + Model( + id="gpt-4", object="model", created=1687882410, owned_by="openai" + ), + ], + ) + assert response.object == "list" + assert len(response.data) == 2 + assert response.data[0].id == "gpt-3.5-turbo" + assert response.data[1].id == "gpt-4" + + def test_models_response_empty(self): + """Test ModelsResponse with no models.""" + response = ModelsResponse(object="list", data=[]) + assert response.object == "list" + assert len(response.data) == 0 diff --git a/tests/benchmark/test_response_data.py b/tests/benchmark/test_response_data.py new file mode 100644 index 000000000..0207d3b11 --- /dev/null +++ b/tests/benchmark/test_response_data.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
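The new tests/benchmark/test_response_data.py added below pins down the helpers in response_data.py, which is likewise not shown in this patch. A minimal sketch of the two deterministic helpers, matching what the tests assert (an ID is the prefix, a dash, and 8 hex characters; token counts use a 4-characters-per-token heuristic with a floor of 1); the uuid-based implementation detail is an assumption.

    import uuid


    def generate_id(prefix: str = "chatcmpl") -> str:
        # "<prefix>-" followed by 8 hex characters, the format the tests match with a regex.
        return f"{prefix}-{uuid.uuid4().hex[:8]}"


    def calculate_tokens(text: str) -> int:
        # Roughly 4 characters per token, never fewer than 1 token.
        return max(1, len(text) // 4)
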
+ +import re +import tempfile +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import yaml + +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings +from nemoguardrails.benchmark.mock_llm_server.models import Model +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + calculate_tokens, + generate_id, + get_latency_seconds, + get_response, + is_unsafe, +) + + +class TestGenerateId: + """Test the generate_id function.""" + + def test_generate_id_default_prefix(self): + """Test generating ID with default prefix.""" + id1 = generate_id() + assert id1.startswith("chatcmpl-") + # ID should be in format: prefix-{8 hex chars} + assert len(id1) == len("chatcmpl-") + 8 + + def test_generate_id_custom_prefix(self): + """Test generating ID with custom prefix.""" + id1 = generate_id("cmpl") + assert id1.startswith("cmpl-") + assert len(id1) == len("cmpl-") + 8 + + def test_generate_id_format(self): + """Test that generated IDs have correct format.""" + id1 = generate_id("test") + # Should match pattern: prefix-{8 hex chars} + pattern = r"test-[0-9a-f]{8}" + assert re.match(pattern, id1) + + +class TestCalculateTokens: + """Test the calculate_tokens function.""" + + def test_calculate_tokens_empty_string(self): + """Test calculating tokens for empty string.""" + tokens = calculate_tokens("") + assert tokens == 1 # Returns at least 1 + + def test_calculate_tokens_short_text(self): + """Test calculating tokens for short text.""" + tokens = calculate_tokens("Hi") + # 2 chars / 4 = 0, but max(1, 0) = 1 + assert tokens == 1 + + def test_calculate_tokens_exact_division(self): + """Test calculating tokens for text divisible by 4.""" + text = "a" * 20 # 20 chars / 4 = 5 tokens + tokens = calculate_tokens(text) + assert tokens == 5 + + def test_calculate_tokens_with_remainder(self): + """Test calculating tokens for text with remainder.""" + text = "a" * 19 # 19 chars / 4 = 4 (integer division) + tokens = calculate_tokens(text) + assert tokens == 4 + + def test_calculate_tokens_long_text(self): + """Test calculating tokens for long text.""" + text = "This is a longer text that should have multiple tokens." 
* 10 + tokens = calculate_tokens(text) + expected = max(1, len(text) // 4) + assert tokens == expected + + def test_calculate_tokens_unicode(self): + """Test calculating tokens with unicode characters.""" + text = "Hello δΈ–η•Œ 🌍" + tokens = calculate_tokens(text) + assert tokens >= 1 + assert tokens == max(1, len(text) // 4) + + +@pytest.fixture +def model_settings() -> ModelSettings: + """Generate config data for use in response generation""" + settings = ModelSettings( + model="gpt-4o", + unsafe_probability=0.5, + unsafe_text="Sorry Dave, I'm afraid I can't do that.", + safe_text="I'm an AI assistant and am happy to help", + latency_min_seconds=0.2, + latency_max_seconds=1.0, + latency_mean_seconds=0.5, + latency_std_seconds=0.1, + ) + return settings + + +@pytest.fixture +def random_seed() -> int: + """Return a fixed seed number for all tests""" + return 12345 + + +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.binomial") +def test_is_unsafe_mocks_no_seed( + mock_binomial: MagicMock, mock_seed: MagicMock, model_settings: ModelSettings +): + """Check `is_unsafe()` calls the correct numpy functions""" + mock_binomial.return_value = [True] + + response = is_unsafe(model_settings) + + assert response == True + assert mock_seed.call_count == 0 + assert mock_binomial.call_count == 1 + mock_binomial.assert_called_once_with( + n=1, p=model_settings.unsafe_probability, size=1 + ) + + +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.binomial") +def test_is_unsafe_mocks_with_seed( + mock_binomial, mock_seed, model_settings: ModelSettings, random_seed: int +): + """Check `is_unsafe()` calls the correct numpy functions""" + mock_binomial.return_value = [False] + + response = is_unsafe(model_settings, random_seed) + + assert response == False + assert mock_seed.call_count == 1 + assert mock_binomial.call_count == 1 + mock_binomial.assert_called_once_with( + n=1, p=model_settings.unsafe_probability, size=1 + ) + + +def test_is_unsafe_prob_one(model_settings: ModelSettings): + """Check `is_unsafe()` with probability of 1 returns True""" + + model_settings.unsafe_probability = 1.0 + response = is_unsafe(model_settings) + assert response == True + + +def test_is_unsafe_prob_zero(model_settings: ModelSettings): + """Check `is_unsafe()` with probability of 1 returns True""" + + model_settings.unsafe_probability = 0.0 + response = is_unsafe(model_settings) + assert response == False + + +def test_get_response_safe(model_settings: ModelSettings): + """Check we get the safe response with is_unsafe returns False""" + with patch( + "nemoguardrails.benchmark.mock_llm_server.response_data.is_unsafe" + ) as mock_is_unsafe: + mock_is_unsafe.return_value = False + response = get_response(model_settings) + assert response == model_settings.safe_text + + +def test_get_response_unsafe(model_settings: ModelSettings): + """Check we get the safe response with is_unsafe returns False""" + with patch( + "nemoguardrails.benchmark.mock_llm_server.response_data.is_unsafe" + ) as mock_is_unsafe: + mock_is_unsafe.return_value = True + response = get_response(model_settings) + assert response == model_settings.unsafe_text + + +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.normal") 
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.clip") +def test_get_latency_seconds_mocks_no_seed( + mock_clip, mock_normal, mock_seed, model_settings: ModelSettings +): + """Check we call the correct numpy functions (not including seed)""" + + mock_normal.return_value = model_settings.latency_mean_seconds + mock_clip.return_value = model_settings.latency_max_seconds + + result = get_latency_seconds(model_settings) + + assert result == mock_clip.return_value + assert mock_seed.call_count == 0 + mock_normal.assert_called_once_with( + loc=model_settings.latency_mean_seconds, + scale=model_settings.latency_std_seconds, + size=1, + ) + mock_clip.assert_called_once_with( + mock_normal.return_value, + a_min=model_settings.latency_min_seconds, + a_max=model_settings.latency_max_seconds, + ) + + +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.normal") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.clip") +def test_get_latency_seconds_mocks_with_seed( + mock_clip, mock_normal, mock_seed, model_settings: ModelSettings, random_seed: int +): + """Check we call the correct numpy functions (not including seed)""" + + mock_normal.return_value = model_settings.latency_mean_seconds + mock_clip.return_value = model_settings.latency_max_seconds + + result = get_latency_seconds(model_settings, seed=random_seed) + + assert result == mock_clip.return_value + mock_seed.assert_called_once_with(random_seed) + mock_normal.assert_called_once_with( + loc=model_settings.latency_mean_seconds, + scale=model_settings.latency_std_seconds, + size=1, + ) + mock_clip.assert_called_once_with( + mock_normal.return_value, + a_min=model_settings.latency_min_seconds, + a_max=model_settings.latency_max_seconds, + ) + + +# +# class TestGetResponse: +# """Test the get_response function.""" +# +# def test_get_response_safe(self, model_settings): +# """Test getting safe response when not unsafe.""" +# +# # P(Unsafe) = 0, so all responses will be safe +# model_settings.unsafe_probability = 0.0 +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# +# response = get_response(model_settings) +# assert response == model_settings.safe_text +# assert mock_seed.call_count == 0 +# +# def test_get_response_unsafe(self, model_settings): +# """Test getting safe response when not unsafe.""" +# +# # P(Unsafe) = 1, so all responses will be unsafe +# model_settings.unsafe_probability = 1.0 +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# +# response = get_response(model_settings) +# assert response == model_settings.unsafe_text +# assert mock_seed.call_count == 0 +# +# def test_get_response_with_seed(self, model_settings, random_seed): +# """Test that a seed is passed onto np.random.seed""" +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# response = get_response(model_settings, seed=random_seed) +# +# assert mock_seed.call_count == 1 +# assert mock_seed.called_once_with(random_seed) +# +# +# class TestGetLatencySeconds: +# """Test the get_latency_seconds function.""" +# +# def setup_method(self): +# """Set up test configuration before each test.""" +# config_data = { +# "model": "test-model", +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# "latency_min_seconds": 0.1, +# "latency_max_seconds": 2.0, +# 
"latency_mean_seconds": 0.5, +# "latency_std_seconds": 0.2, +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# self.temp_file = f.name +# load_config(self.temp_file) +# +# def teardown_method(self): +# """Clean up after each test.""" +# import os +# +# os.unlink(self.temp_file) +# +# def test_get_latency_seconds_in_bounds(self): +# """Test that latency is within configured bounds.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latency = get_latency_seconds(config, seed=42) +# assert config.latency_min_seconds <= latency <= config.latency_max_seconds +# assert isinstance(latency, float) +# +# def test_get_latency_seconds_with_seed_deterministic(self): +# """Test that same seed produces same latency.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latency1 = get_latency_seconds(config, seed=12345) +# latency2 = get_latency_seconds(config, seed=12345) +# assert latency1 == latency2 +# +# def test_get_latency_seconds_without_seed_random(self): +# """Test that without seed, latencies vary.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latencies = [get_latency_seconds(config) for _ in range(20)] +# # Should have some variation (not all the same) +# assert len(set(latencies)) > 1 +# +# def test_get_latency_seconds_clipping_min(self): +# """Test that latency is clipped to minimum.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Try many samples to potentially get one that would be below min +# latencies = [get_latency_seconds(config, seed=i) for i in range(100)] +# assert all(lat >= config.latency_min_seconds for lat in latencies) +# +# def test_get_latency_seconds_clipping_max(self): +# """Test that latency is clipped to maximum.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Try many samples to potentially get one that would be above max +# latencies = [get_latency_seconds(config, seed=i) for i in range(100)] +# assert all(lat <= config.latency_max_seconds for lat in latencies) +# +# def test_get_latency_seconds_distribution_mean(self): +# """Test that latency follows expected distribution.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Generate many samples and check mean is approximately correct +# np.random.seed(42) +# latencies = [get_latency_seconds(config) for _ in range(1000)] +# mean_latency = np.mean(latencies) +# +# # Mean should be close to configured mean (allowing for clipping) +# # With clipping, mean will be between min and max +# assert config.latency_min_seconds <= mean_latency <= config.latency_max_seconds +# +# +# class TestIsUnsafe: +# """Test the is_unsafe function.""" +# +# def setup_method(self): +# """Set up test configuration before each test.""" +# config_data = { +# "model": "test-model", +# "unsafe_probability": 0.3, +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# self.temp_file = f.name +# load_config(self.temp_file) +# +# def teardown_method(self): +# """Clean up after each test.""" +# import os +# +# os.unlink(self.temp_file) +# +# def test_is_unsafe_returns_bool(self): +# """Test that is_unsafe 
returns a boolean.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# result = is_unsafe(config, seed=42) +# assert isinstance(result, bool) +# +# def test_is_unsafe_with_seed_deterministic(self): +# """Test that same seed produces same result.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# result1 = is_unsafe(config, seed=12345) +# result2 = is_unsafe(config, seed=12345) +# assert result1 == result2 +# +# def test_is_unsafe_without_seed_random(self): +# """Test that without seed, results vary.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# results = [is_unsafe(config) for _ in range(50)] +# # Should have both True and False (with high probability) +# assert True in results or False in results +# +# def test_is_unsafe_probability_distribution(self): +# """Test that unsafe probability follows configured distribution.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Generate many samples and check probability +# np.random.seed(42) +# results = [is_unsafe(config) for _ in range(1000)] +# unsafe_rate = sum(results) / len(results) +# +# # Should be approximately 0.3 (allowing for randomness) +# assert 0.2 <= unsafe_rate <= 0.4 +# +# def test_is_unsafe_zero_probability(self): +# """Test with zero unsafe probability.""" +# config_data = { +# "model": "test-model", +# "unsafe_probability": 0.0, +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# temp_file = f.name +# +# try: +# load_config(temp_file) +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Should always be safe +# results = [is_unsafe(config) for _ in range(20)] +# assert all(not result for result in results) +# finally: +# import os +# +# os.unlink(temp_file) +# +# def test_is_unsafe_one_probability(self): +# """Test with 100% unsafe probability.""" +# config_data = { +# "model": "test-model", +# "unsafe_probability": 1.0, +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# temp_file = f.name +# +# try: +# load_config(temp_file) +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Should always be unsafe +# results = [is_unsafe(config) for _ in range(20)] +# assert all(result for result in results) +# finally: +# import os +# +# os.unlink(temp_file) From 4104a1f6b00da05bdcdbd7a84d462e728c47779e Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:20:31 -0500 Subject: [PATCH 13/15] Check for config file --- nemoguardrails/benchmark/mock_llm_server/run_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index cd1bfe789..6485d8159 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -86,6 +86,9 @@ def main(): "No CONFIG_FILE environment variable set, or --config-file CLI argument" ) + if not (os.path.exists(config_file) and os.path.isfile(config_file)): + raise RuntimeError(f"Can't 
open {config_file}") + log.info("Using config file: %s", config_file) os.environ[CONFIG_FILE_ENV_VAR] = config_file From 1cca2ff67bea65bdbe5ca315d63aea5830b9a5d8 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:27:37 -0500 Subject: [PATCH 14/15] Rename test files to avoid conflicts with other tests --- tests/benchmark/{test_api.py => test_mock_api.py} | 0 tests/benchmark/{test_config.py => test_mock_config.py} | 0 tests/benchmark/{test_models.py => test_mock_models.py} | 0 .../{test_response_data.py => test_mock_response_data.py} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/benchmark/{test_api.py => test_mock_api.py} (100%) rename tests/benchmark/{test_config.py => test_mock_config.py} (100%) rename tests/benchmark/{test_models.py => test_mock_models.py} (100%) rename tests/benchmark/{test_response_data.py => test_mock_response_data.py} (100%) diff --git a/tests/benchmark/test_api.py b/tests/benchmark/test_mock_api.py similarity index 100% rename from tests/benchmark/test_api.py rename to tests/benchmark/test_mock_api.py diff --git a/tests/benchmark/test_config.py b/tests/benchmark/test_mock_config.py similarity index 100% rename from tests/benchmark/test_config.py rename to tests/benchmark/test_mock_config.py diff --git a/tests/benchmark/test_models.py b/tests/benchmark/test_mock_models.py similarity index 100% rename from tests/benchmark/test_models.py rename to tests/benchmark/test_mock_models.py diff --git a/tests/benchmark/test_response_data.py b/tests/benchmark/test_mock_response_data.py similarity index 100% rename from tests/benchmark/test_response_data.py rename to tests/benchmark/test_mock_response_data.py From e87715c07be84720d8d13f24713ebdd2452fb119 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:51:08 -0500 Subject: [PATCH 15/15] Remove example_usage.py script and type-clean config.py --- .../benchmark/mock_llm_server/config.py | 2 +- .../mock_llm_server/example_usage.py | 206 ------------------ 2 files changed, 1 insertion(+), 207 deletions(-) delete mode 100644 nemoguardrails/benchmark/mock_llm_server/example_usage.py diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 2fcd0166e..c2b9b0d6e 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -63,6 +63,6 @@ class ModelSettings(BaseSettings): def get_settings() -> ModelSettings: """Singleton-pattern to get settings once via lru_cache""" - settings = ModelSettings() + settings = ModelSettings() # type: ignore (These are filled in by loading from CONFIG_FILE) print("Returning ModelSettings: %s", settings) return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/example_usage.py b/nemoguardrails/benchmark/mock_llm_server/example_usage.py deleted file mode 100644 index 278ab8d94..000000000 --- a/nemoguardrails/benchmark/mock_llm_server/example_usage.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Example usage of the Mock LLM Server. - -This script demonstrates how to interact with the running mock server -using standard HTTP requests and the OpenAI Python client. -""" - -import json -import time - -import requests - - -def test_with_requests(): - """Test the server using the requests library.""" - base_url = "http://localhost:8000" - - print("Testing Mock LLM Server with requests library...") - print("=" * 50) - - # Test health endpoint - try: - response = requests.get(f"{base_url}/health", timeout=5) - print(f"Health check: {response.status_code} - {response.json()}") - except requests.RequestException as e: - print(f"Health check failed: {e}") - print("Make sure the server is running: python run_server.py") - return - - # Test models endpoint - try: - response = requests.get(f"{base_url}/v1/models", timeout=5) - print(f"\\nModels: {response.status_code}") - models_data = response.json() - for model in models_data["data"]: - print(f" - {model['id']}") - except requests.RequestException as e: - print(f"Models request failed: {e}") - - # Test chat completion - try: - chat_payload = { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "max_tokens": 50, - } - response = requests.post( - f"{base_url}/v1/chat/completions", - json=chat_payload, - headers={"Content-Type": "application/json"}, - timeout=5, - ) - print(f"\\nChat completion: {response.status_code}") - if response.status_code == 200: - data = response.json() - print(f"Response: {data['choices'][0]['message']['content']}") - print(f"Usage: {data['usage']}") - except requests.RequestException as e: - print(f"Chat completion failed: {e}") - - # Test text completion - try: - completion_payload = { - "model": "text-davinci-003", - "prompt": "The capital of France is", - "max_tokens": 10, - } - response = requests.post( - f"{base_url}/v1/completions", - json=completion_payload, - headers={"Content-Type": "application/json"}, - timeout=5, - ) - print(f"\\nText completion: {response.status_code}") - if response.status_code == 200: - data = response.json() - print(f"Response: {data['choices'][0]['text']}") - print(f"Usage: {data['usage']}") - except requests.RequestException as e: - print(f"Text completion failed: {e}") - - -def test_with_openai_client(): - """Test the server using the OpenAI Python client.""" - try: - import openai - except ImportError: - print("\\nOpenAI client not available. 
Install with: pip install openai") - return - - print("\\n" + "=" * 50) - print("Testing with OpenAI client library...") - print("=" * 50) - - # Configure client to use local server - client = openai.OpenAI( - base_url="http://localhost:8000/v1", - api_key="dummy-key", # Server doesn't validate, but client requires it - ) - - try: - # Test chat completion - response = client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Hello from OpenAI client!"}], - ) - print(f"Chat completion response: {response.choices[0].message.content}") - print( - f"Usage: prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}" - ) - - # Test text completion (if supported by client version) - try: - response = client.completions.create( - model="text-davinci-003", prompt="OpenAI client test: ", max_tokens=10 - ) - print(f"Text completion response: {response.choices[0].text}") - except Exception as e: - print(f"Text completion not supported in this OpenAI client version: {e}") - - except Exception as e: - print(f"OpenAI client test failed: {e}") - - -def benchmark_performance(): - """Simple performance benchmark.""" - print("\\n" + "=" * 50) - print("Performance Benchmark") - print("=" * 50) - - base_url = "http://localhost:8000" - num_requests = 10 - - chat_payload = { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Benchmark test"}], - "max_tokens": 20, - } - - print(f"Making {num_requests} chat completion requests...") - - start_time = time.time() - successful_requests = 0 - - for i in range(num_requests): - try: - response = requests.post( - f"{base_url}/v1/chat/completions", - json=chat_payload, - headers={"Content-Type": "application/json"}, - timeout=5, - ) - if response.status_code == 200: - successful_requests += 1 - except requests.RequestException: - pass - - end_time = time.time() - total_time = end_time - start_time - - print(f"Results:") - print(f" Total requests: {num_requests}") - print(f" Successful requests: {successful_requests}") - print(f" Total time: {total_time:.2f} seconds") - print(f" Average time per request: {total_time/num_requests:.3f} seconds") - print(f" Requests per second: {num_requests/total_time:.2f}") - - -def main(): - """Main function to run all tests.""" - print("Mock LLM Server Example Usage") - print("=" * 50) - print("Make sure the server is running before running this script:") - print(" python run_server.py") - print() - - # Test with requests - test_with_requests() - - # Test with OpenAI client - test_with_openai_client() - - # Simple benchmark - benchmark_performance() - - print("\\nExample completed!") - - -if __name__ == "__main__": - main()
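
The final hunk changes only the body of get_settings(), but its docstring describes the intended pattern: ModelSettings is a BaseSettings subclass whose required fields are supplied at runtime (through the environment and the CONFIG_FILE variable exported by run_server.py), and get_settings() hands out one cached instance. A sketch of that pattern, assuming pydantic-settings for BaseSettings and an @lru_cache decorator sitting above the context shown in the hunk; the field names are taken from the ModelSettings fixture used in the response-data tests.

    from functools import lru_cache

    from pydantic_settings import BaseSettings


    class ModelSettings(BaseSettings):
        # Field names follow the test fixture; the real class defines its own
        # defaults and loading behaviour.
        model: str
        unsafe_probability: float
        unsafe_text: str
        safe_text: str
        latency_min_seconds: float
        latency_max_seconds: float
        latency_mean_seconds: float
        latency_std_seconds: float


    @lru_cache
    def get_settings() -> ModelSettings:
        # Values arrive from the environment / CONFIG_FILE rather than constructor
        # arguments, hence the bare call and the type-checker suppression.
        return ModelSettings()  # type: ignore[call-arg]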