Skip to content

Commit 4e9633f

Browse files
authored
feat: Make Safety API an optional dependency for meta-reference agents provider (#4169)
# What does this PR do? Change Safety API from required to optional dependency, following the established pattern used for other optional dependencies in Llama Stack. The provider now starts successfully without Safety API configured. Requests that explicitly include guardrails will receive a clear error message when Safety API is unavailable. This enables local development and testing without Safety API while maintaining clear error messages when guardrail features are requested. Closes #4165 Signed-off-by: Anik Bhattacharjee <[email protected]> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> 1. New unit tests added in `tests/unit/providers/agents/meta_reference/test_safety_optional.py` 2. Integration tests performed with the files in https://gist.github.com/anik120/c33cef497ec7085e1fe2164e0705b8d6 (i) test with `test_integration_no_safety_fail.yaml`: Config WITHOUT Safety API, should fail with helpful error since `require_safety_api` is `true` by default ``` $ uv run llama stack run test_integration_no_safety_fail.yaml 2>&1 | grep -B 5 -A 15 "ValueError.*Safety\|Safety API is required" File "/Users/anbhatta/go/src/github.com/llamastack/llama-stack/src/llama_stack/providers/inline/agents/meta_reference /__init__.py", line 27, in get_provider_impl raise ValueError( ...<9 lines>... ) ValueError: Safety API is required but not configured. To run without safety checks, explicitly set in your configuration: providers: agents: - provider_id: meta-reference provider_type: inline::meta-reference config: require_safety_api: false Warning: This disables all safety guardrails for this agents provider. 
``` (ii) test with `test_integration_no_safety_works.yaml` Config WITHOUT Safety API, **but** `require_safety_api=false` is explicitly set, should succeed ``` $ uv run llama stack run test_integration_no_safety_works.yaml INFO 2025-11-16 09:49:10,044 llama_stack.cli.stack.run:169 cli: Using run configuration: /Users/anbhatta/go/src/github.com/llamastack/llama-stack/test_integration_no_safety_works.yaml INFO 2025-11-16 09:49:10,052 llama_stack.cli.stack.run:228 cli: HTTPS enabled with certificates: Key: None Cert: None . . . INFO 2025-11-16 09:49:38,528 llama_stack.core.stack:495 core: starting registry refresh task INFO 2025-11-16 09:49:38,534 uvicorn.error:62 uncategorized: Application startup complete. INFO 2025-11-16 09:49:38,535 uvicorn.error:216 uncategorized: Uvicorn running on http://0.0.0.0:8321 (Press CTRL+C ``` Signed-off-by: Anik Bhattacharjee <[email protected]> Signed-off-by: Anik Bhattacharjee <[email protected]>
1 parent d5cd0ee commit 4e9633f

File tree

7 files changed

+227
-6
lines changed

7 files changed

+227
-6
lines changed

src/llama_stack/providers/inline/agents/meta_reference/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def get_provider_impl(
2323
config,
2424
deps[Api.inference],
2525
deps[Api.vector_io],
26-
deps[Api.safety],
26+
deps.get(Api.safety),
2727
deps[Api.tool_runtime],
2828
deps[Api.tool_groups],
2929
deps[Api.conversations],

src/llama_stack/providers/inline/agents/meta_reference/agents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
config: MetaReferenceAgentsImplConfig,
4242
inference_api: Inference,
4343
vector_io_api: VectorIO,
44-
safety_api: Safety,
44+
safety_api: Safety | None,
4545
tool_runtime_api: ToolRuntime,
4646
tool_groups_api: ToolGroups,
4747
conversations_api: Conversations,

src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
tool_runtime_api: ToolRuntime,
6868
responses_store: ResponsesStore,
6969
vector_io_api: VectorIO, # VectorIO
70-
safety_api: Safety,
70+
safety_api: Safety | None,
7171
conversations_api: Conversations,
7272
):
7373
self.inference_api = inference_api
@@ -273,6 +273,14 @@ async def create_openai_response(
273273

274274
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
275275

276+
# Validate that Safety API is available if guardrails are requested
277+
if guardrail_ids and self.safety_api is None:
278+
raise ValueError(
279+
"Cannot process guardrails: Safety API is not configured.\n\n"
280+
"To use guardrails, ensure the Safety API is configured in your stack, or remove "
281+
"the 'guardrails' parameter from your request."
282+
)
283+
276284
if conversation is not None:
277285
if previous_response_id is not None:
278286
raise ValueError(

src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
OpenAIResponseUsage,
6767
OpenAIResponseUsageInputTokensDetails,
6868
OpenAIResponseUsageOutputTokensDetails,
69+
Safety,
6970
WebSearchToolTypes,
7071
)
7172

@@ -111,7 +112,7 @@ def __init__(
111112
max_infer_iters: int,
112113
tool_executor, # Will be the tool execution logic from the main class
113114
instructions: str | None,
114-
safety_api,
115+
safety_api: Safety | None,
115116
guardrail_ids: list[str] | None = None,
116117
prompt: OpenAIResponsePrompt | None = None,
117118
parallel_tool_calls: bool | None = None,

src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,15 @@ def is_function_tool_call(
320320
return False
321321

322322

323-
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
323+
async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
324324
"""Run guardrails against messages and return violation message if blocked."""
325325
if not messages:
326326
return None
327327

328+
# If safety API is not available, skip guardrails
329+
if safety_api is None:
330+
return None
331+
328332
# Look up shields to get their provider_resource_id (actual model ID)
329333
model_ids = []
330334
# TODO: list_shields not in Safety interface but available at runtime via API routing

src/llama_stack/providers/registry/agents.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ def available_providers() -> list[ProviderSpec]:
3030
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
3131
api_dependencies=[
3232
Api.inference,
33-
Api.safety,
3433
Api.vector_io,
3534
Api.tool_runtime,
3635
Api.tool_groups,
3736
Api.conversations,
3837
],
38+
optional_api_dependencies=[
39+
Api.safety,
40+
],
3941
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
4042
),
4143
]
tests/unit/providers/agents/meta_reference/test_safety_optional.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
"""Tests for making Safety API optional in meta-reference agents provider.
8+
9+
This test suite validates the changes introduced to fix issue #4165, which
10+
allows running the meta-reference agents provider without the Safety API.
11+
Safety API is now an optional dependency, and errors are raised at request time
12+
when guardrails are explicitly requested without Safety API configured.
13+
"""
14+
15+
from unittest.mock import AsyncMock, MagicMock, patch
16+
17+
import pytest
18+
19+
from llama_stack.core.datatypes import Api
20+
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
21+
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
22+
from llama_stack.providers.inline.agents.meta_reference.config import (
23+
AgentPersistenceConfig,
24+
MetaReferenceAgentsImplConfig,
25+
)
26+
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
27+
run_guardrails,
28+
)
29+
30+
31+
@pytest.fixture
32+
def mock_persistence_config():
33+
"""Create a mock persistence configuration."""
34+
return AgentPersistenceConfig(
35+
agent_state=KVStoreReference(
36+
backend="kv_default",
37+
namespace="agents",
38+
),
39+
responses=ResponsesStoreReference(
40+
backend="sql_default",
41+
table_name="responses",
42+
),
43+
)
44+
45+
46+
@pytest.fixture
47+
def mock_deps():
48+
"""Create mock dependencies for the agents provider."""
49+
# Create mock APIs
50+
inference_api = AsyncMock()
51+
vector_io_api = AsyncMock()
52+
tool_runtime_api = AsyncMock()
53+
tool_groups_api = AsyncMock()
54+
conversations_api = AsyncMock()
55+
56+
return {
57+
Api.inference: inference_api,
58+
Api.vector_io: vector_io_api,
59+
Api.tool_runtime: tool_runtime_api,
60+
Api.tool_groups: tool_groups_api,
61+
Api.conversations: conversations_api,
62+
}
63+
64+
65+
class TestProviderInitialization:
66+
"""Test provider initialization with different safety API configurations."""
67+
68+
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
69+
"""Test successful initialization when Safety API is configured."""
70+
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
71+
72+
# Add safety API to deps
73+
safety_api = AsyncMock()
74+
mock_deps[Api.safety] = safety_api
75+
76+
# Mock the initialize method to avoid actual initialization
77+
with patch(
78+
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
79+
new_callable=AsyncMock,
80+
):
81+
# Should not raise any exception
82+
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
83+
assert provider is not None
84+
85+
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
86+
"""Test successful initialization when Safety API is not configured."""
87+
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
88+
89+
# Safety API is NOT in mock_deps - provider should still start
90+
# Mock the initialize method to avoid actual initialization
91+
with patch(
92+
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
93+
new_callable=AsyncMock,
94+
):
95+
# Should not raise any exception
96+
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
97+
assert provider is not None
98+
assert provider.safety_api is None
99+
100+
101+
class TestGuardrailsFunctionality:
102+
"""Test run_guardrails function with optional safety API."""
103+
104+
async def test_run_guardrails_with_none_safety_api(self):
105+
"""Test that run_guardrails returns None when safety_api is None."""
106+
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
107+
assert result is None
108+
109+
async def test_run_guardrails_with_empty_messages(self):
110+
"""Test that run_guardrails returns None for empty messages."""
111+
# Test with None safety API
112+
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
113+
assert result is None
114+
115+
# Test with mock safety API
116+
mock_safety_api = AsyncMock()
117+
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
118+
assert result is None
119+
120+
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
121+
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
122+
# Should not raise exception, just return None
123+
result = await run_guardrails(
124+
safety_api=None,
125+
messages="potentially harmful content",
126+
guardrail_ids=["llama-guard", "content-filter"],
127+
)
128+
assert result is None
129+
130+
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
131+
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
132+
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
133+
OpenAIResponsesImpl,
134+
)
135+
from llama_stack_api import ResponseGuardrailSpec
136+
137+
# Create OpenAIResponsesImpl with no safety API
138+
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
139+
impl = OpenAIResponsesImpl(
140+
inference_api=mock_deps[Api.inference],
141+
tool_groups_api=mock_deps[Api.tool_groups],
142+
tool_runtime_api=mock_deps[Api.tool_runtime],
143+
responses_store=MagicMock(),
144+
vector_io_api=mock_deps[Api.vector_io],
145+
safety_api=None, # No Safety API
146+
conversations_api=mock_deps[Api.conversations],
147+
)
148+
149+
# Test with string guardrail
150+
with pytest.raises(ValueError) as exc_info:
151+
await impl.create_openai_response(
152+
input="test input",
153+
model="test-model",
154+
guardrails=["llama-guard"],
155+
)
156+
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
157+
158+
# Test with ResponseGuardrailSpec
159+
with pytest.raises(ValueError) as exc_info:
160+
await impl.create_openai_response(
161+
input="test input",
162+
model="test-model",
163+
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
164+
)
165+
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
166+
167+
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
168+
self, mock_persistence_config, mock_deps
169+
):
170+
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
171+
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
172+
OpenAIResponsesImpl,
173+
)
174+
175+
# Create OpenAIResponsesImpl with no safety API
176+
with (
177+
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
178+
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
179+
):
180+
# Mock the streaming response to return a simple async generator
181+
async def mock_generator():
182+
yield MagicMock()
183+
184+
mock_stream.return_value = mock_generator()
185+
186+
impl = OpenAIResponsesImpl(
187+
inference_api=mock_deps[Api.inference],
188+
tool_groups_api=mock_deps[Api.tool_groups],
189+
tool_runtime_api=mock_deps[Api.tool_runtime],
190+
responses_store=MagicMock(),
191+
vector_io_api=mock_deps[Api.vector_io],
192+
safety_api=None, # No Safety API
193+
conversations_api=mock_deps[Api.conversations],
194+
)
195+
196+
# Should not raise when no guardrails requested
197+
# Note: This will still fail later in execution due to mocking, but should pass the validation
198+
try:
199+
await impl.create_openai_response(
200+
input="test input",
201+
model="test-model",
202+
guardrails=None, # No guardrails
203+
)
204+
except Exception as e:
205+
# Ensure the error is NOT about missing Safety API
206+
assert "Cannot process guardrails: Safety API is not configured" not in str(e)

0 commit comments

Comments
 (0)