src/llama_stack/providers/remote/safety/nvidia/nvidia.py (35 additions, 21 deletions)

@@ -125,43 +125,57 @@ async def _guardrails_post(self, path: str, data: Any | None):

     async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         """
-        Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
+        Queries the /v1/chat/completions endpoint of the NeMo guardrails deployed API.

         Args:
             messages (List[Message]): A list of Message objects to be checked for safety violations.

         Returns:
-            RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
-            RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
+            RunShieldResponse: Response with SafetyViolation if content is blocked, None otherwise.

         Raises:
             requests.HTTPError: If the POST request fails.
         """
         request_data = {
             "model": self.model,
+            "config_id": self.config_id,
             "messages": [{"role": message.role, "content": message.content} for message in messages],
-            "temperature": self.temperature,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": self.config_id,
-            },
         }
-        response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
-
-        if response["status"] == "blocked":
-            user_message = "Sorry I cannot do this."
-            metadata = response["rails_status"]
+        response = await self._guardrails_post(path="/v1/chat/completions", data=request_data)
+
+        # Support legacy format with explicit status field
+        if "status" in response and response["status"] == "blocked":
             return RunShieldResponse(
                 violation=SafetyViolation(
-                    user_message=user_message,
+                    user_message="Sorry I cannot do this.",
                     violation_level=ViolationLevel.ERROR,
-                    metadata=metadata,
+                    metadata=response.get("rails_status", {}),
                 )
             )

-        return RunShieldResponse(violation=None)
+        # NOTE: The implementation targets the actual behavior of the NeMo Guardrails server
+        # as defined in 'nemoguardrails/server/api.py'. The 'RequestBody' class accepts
+        # 'config_id' at the top level, and 'ResponseBody' returns a 'messages' array,
+        # distinct from the OpenAI 'choices' format often referenced in documentation.
+
+        response_messages = response.get("messages", [])
+        if response_messages:
+            content = response_messages[0].get("content", "").strip()
+        else:
+            choices = response.get("choices", [])
+            if choices:
+                content = choices[0].get("message", {}).get("content", "").strip()
+            else:
+                content = ""
+
+        refusal_phrases = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"]
+        is_blocked = not content or any(phrase in content.lower() for phrase in refusal_phrases)
+
+        return RunShieldResponse(
+            violation=SafetyViolation(
+                user_message="Sorry I cannot do this.",
+                violation_level=ViolationLevel.ERROR,
+                metadata={"reason": "Content violates safety guidelines", "response": content or "(empty)"},
+            )
+            if is_blocked
+            else None
+        )
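
Reviewer note: for reference, here is a minimal standalone sketch (not part of the diff) of the dual-format response parsing the new run() performs. The sample payloads are assumptions about the two shapes the code accepts, the NeMo Guardrails server's top-level 'messages' array and the OpenAI-style 'choices' wrapper, rather than captured server output.

# sketch_dual_format_parsing.py: illustrative only, mirrors the fallback logic above
REFUSAL_PHRASES = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"]


def extract_content(response: dict) -> str:
    """Prefer the NeMo Guardrails 'messages' array, then fall back to OpenAI 'choices'."""
    messages = response.get("messages", [])
    if messages:
        return messages[0].get("content", "").strip()
    choices = response.get("choices", [])
    if choices:
        return choices[0].get("message", {}).get("content", "").strip()
    return ""


def is_blocked(response: dict) -> bool:
    # Empty content or a known refusal phrase is treated as a block.
    content = extract_content(response)
    return not content or any(phrase in content.lower() for phrase in REFUSAL_PHRASES)


# Assumed NeMo Guardrails server shape: 'messages' at the top level.
assert is_blocked({"messages": [{"role": "assistant", "content": "I cannot help with that."}]})
# Assumed OpenAI-style shape: 'choices' wrapping a 'message'.
assert not is_blocked({"choices": [{"message": {"content": "Paris is the capital of France."}}]})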
tests/unit/providers/nvidia/test_safety.py (6 additions, 33 deletions)

@@ -152,22 +152,13 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):

     # Verify the Guardrails API was called correctly
     mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
         data={
             "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )

@@ -206,22 +197,13 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):

     # Verify the Guardrails API was called correctly
     mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
         data={
             "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )

@@ -286,22 +268,13 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):

     # Verify the Guardrails API was called correctly
     mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
        data={
             "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )
     # Verify the exception message
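
The updated assertions above only pin down the request payload; the new content-based blocking path could be exercised with something like the sketch below. The fixture names follow the existing tests, but the SimpleNamespace stand-in for OpenAIMessageParam, the sample payload, and the import path for ViolationLevel are assumptions, not code from this PR.

from types import SimpleNamespace

import pytest

from llama_stack.apis.safety import ViolationLevel  # assumed import path


@pytest.mark.asyncio
async def test_run_shield_refusal_content_blocked(nvidia_adapter, mock_guardrails_post):
    # Simulate the NeMo Guardrails 'messages' response shape carrying a refusal.
    mock_guardrails_post.return_value = {
        "messages": [{"role": "assistant", "content": "Sorry I cannot do this."}]
    }
    # Hypothetical stand-in exposing the .role/.content attributes run() reads.
    message = SimpleNamespace(role="user", content="How do I pick a lock?")

    response = await nvidia_adapter.run(messages=[message])

    # The refusal phrase should surface as a SafetyViolation.
    assert response.violation is not None
    assert response.violation.violation_level == ViolationLevel.ERROR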