Skip to content

Commit 8fbf836

Browse files
yahya-moumanYun-KimKyle-Verhoog
authored
feat(LLMObs): update the prompt annotation (#12551)
This PR updates the prompt typed dict extending it with a couple fields : - name : name of the prompt - chat_template : list of role,content pairs where content is a string template of a prompt - tags : list of tags for the prompt run. It also adds the strict validation mode. Strict validation adds the following checks : - id is mandatory - either a template or a chat_template should be provided Co-authored-by: Yun Kim <[email protected]> Co-authored-by: kyle <[email protected]>
1 parent 9c89921 commit 8fbf836

File tree

8 files changed

+198
-68
lines changed

8 files changed

+198
-68
lines changed

ddtrace/llmobs/_constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@
8686

8787
SPAN_LINKS = "_ml_obs.span_links"
8888
NAME = "_ml_obs.name"
89+
90+
# Prompt constants
91+
DEFAULT_PROMPT_NAME = "unnamed-prompt"
92+
8993
DECORATOR = "_ml_obs.decorator"
9094
INTEGRATION = "_ml_obs.integration"
9195

ddtrace/llmobs/_integrations/langchain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
from ddtrace.llmobs._integrations.utils import update_proxy_workflow_input_output_value
4141
from ddtrace.llmobs._utils import _get_attr
4242
from ddtrace.llmobs._utils import _get_nearest_llmobs_ancestor
43+
from ddtrace.llmobs._utils import _validate_prompt
4344
from ddtrace.llmobs._utils import safe_json
44-
from ddtrace.llmobs._utils import validate_prompt
4545
from ddtrace.llmobs.utils import Document
4646
from ddtrace.trace import Span
4747

@@ -885,7 +885,7 @@ def llmobs_set_prompt_tag(self, instance, span: Span):
885885
if prompt_value_meta is not None:
886886
prompt = prompt_value_meta
887887
try:
888-
prompt = validate_prompt(prompt)
888+
prompt = _validate_prompt(prompt, strict_validation=True)
889889
span._set_ctx_item(INPUT_PROMPT, prompt)
890890
except Exception as e:
891891
log.debug("Failed to validate langchain prompt", e)

ddtrace/llmobs/_llmobs.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@
100100
from ddtrace.llmobs._utils import _get_session_id
101101
from ddtrace.llmobs._utils import _get_span_name
102102
from ddtrace.llmobs._utils import _is_evaluation_span
103+
from ddtrace.llmobs._utils import _validate_prompt
103104
from ddtrace.llmobs._utils import enforce_message_role
104105
from ddtrace.llmobs._utils import safe_json
105-
from ddtrace.llmobs._utils import validate_prompt
106106
from ddtrace.llmobs._writer import LLMObsEvalMetricWriter
107107
from ddtrace.llmobs._writer import LLMObsEvaluationMetricEvent
108108
from ddtrace.llmobs._writer import LLMObsExperimentsClient
@@ -112,6 +112,7 @@
112112
from ddtrace.llmobs.utils import Documents
113113
from ddtrace.llmobs.utils import ExportedLLMObsSpan
114114
from ddtrace.llmobs.utils import Messages
115+
from ddtrace.llmobs.utils import Prompt
115116
from ddtrace.llmobs.utils import extract_tool_definitions
116117
from ddtrace.propagation.http import HTTPPropagator
117118

@@ -841,7 +842,10 @@ def _tag_span_links(self, span, span_links):
841842

842843
@classmethod
843844
def annotation_context(
844-
cls, tags: Optional[Dict[str, Any]] = None, prompt: Optional[dict] = None, name: Optional[str] = None
845+
cls,
846+
tags: Optional[Dict[str, Any]] = None,
847+
prompt: Optional[Union[dict, Prompt]] = None,
848+
name: Optional[str] = None,
845849
) -> AnnotationContext:
846850
"""
847851
Sets specified attributes on all LLMObs spans created while the returned AnnotationContext is active.
@@ -850,10 +854,16 @@ def annotation_context(
850854
:param tags: Dictionary of JSON serializable key-value tag pairs to set or update on the LLMObs span
851855
regarding the span's context.
852856
:param prompt: A dictionary that represents the prompt used for an LLM call in the following form:
853-
`{"template": "...", "id": "...", "version": "...", "variables": {"variable_1": "...", ...}}`.
857+
`{
858+
"id": "...",
859+
"version": "...",
860+
"chat_template": [{"content": "...", "role": "..."}, ...],
861+
"variables": {"variable_1": "...", ...}}`.
862+
"tags": {"key1": "value1", "key2": "value2"},
863+
}`
854864
Can also be set using the `ddtrace.llmobs.utils.Prompt` constructor class.
855865
- This argument is only applicable to LLM spans.
856-
- The dictionary may contain two optional keys relevant to RAG applications:
866+
- The dictionary may contain optional keys relevant to Templates and RAG applications:
857867
`rag_context_variables` - a list of variable key names that contain ground
858868
truth context information
859869
`rag_query_variables` - a list of variable key names that contains query
@@ -1289,7 +1299,14 @@ def annotate(
12891299
:param Span span: Span to annotate. If no span is provided, the current active span will be used.
12901300
Must be an LLMObs-type span, i.e. generated by the LLMObs SDK.
12911301
:param prompt: A dictionary that represents the prompt used for an LLM call in the following form:
1292-
`{"template": "...", "id": "...", "version": "...", "variables": {"variable_1": "...", ...}}`.
1302+
`{
1303+
"id": "...",
1304+
"template": "...",
1305+
"chat_template": [{"content": "...", "role": "..."}, ...])
1306+
"version": "...",
1307+
"variables": {"variable_1": "...", ...},
1308+
tags": {"tag_1": "...", ...},
1309+
}`.
12931310
Can also be set using the `ddtrace.llmobs.utils.Prompt` constructor class.
12941311
- This argument is only applicable to LLM spans.
12951312
- The dictionary may contain two optional keys relevant to RAG applications:
@@ -1373,11 +1390,11 @@ def annotate(
13731390
span.name = _name
13741391
if prompt is not None:
13751392
try:
1376-
validated_prompt = validate_prompt(prompt)
1393+
validated_prompt = _validate_prompt(prompt, strict_validation=False)
13771394
cls._set_dict_attribute(span, INPUT_PROMPT, validated_prompt)
1378-
except TypeError:
1395+
except (ValueError, TypeError) as e:
13791396
error = "invalid_prompt"
1380-
log.warning("Failed to validate prompt with error: ", exc_info=True)
1397+
log.warning("Failed to validate prompt with error:", str(e), exc_info=True)
13811398
if not span_kind:
13821399
log.debug("Span kind not specified, skipping annotation for input/output data")
13831400
return

ddtrace/llmobs/_utils.py

Lines changed: 92 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33
from dataclasses import is_dataclass
44
import json
5+
from typing import Any
56
from typing import Dict
67
from typing import List
78
from typing import Optional
@@ -14,6 +15,7 @@
1415
from ddtrace.internal.logger import get_logger
1516
from ddtrace.internal.utils.formats import format_trace_id
1617
from ddtrace.llmobs._constants import CREWAI_APM_SPAN_NAME
18+
from ddtrace.llmobs._constants import DEFAULT_PROMPT_NAME
1719
from ddtrace.llmobs._constants import GEMINI_APM_SPAN_NAME
1820
from ddtrace.llmobs._constants import INTERNAL_CONTEXT_VARIABLE_KEYS
1921
from ddtrace.llmobs._constants import INTERNAL_QUERY_VARIABLE_KEYS
@@ -27,11 +29,14 @@
2729
from ddtrace.llmobs._constants import SESSION_ID
2830
from ddtrace.llmobs._constants import SPAN_LINKS
2931
from ddtrace.llmobs._constants import VERTEXAI_APM_SPAN_NAME
32+
from ddtrace.llmobs.utils import Message
33+
from ddtrace.llmobs.utils import Prompt
3034
from ddtrace.trace import Span
3135

3236

3337
log = get_logger(__name__)
3438

39+
ValidatedPromptDict = Dict[str, Union[str, Dict[str, Any], List[str], List[Dict[str, str]], List[Message]]]
3540

3641
STANDARD_INTEGRATION_SPAN_NAMES = (
3742
CREWAI_APM_SPAN_NAME,
@@ -42,50 +47,100 @@
4247
)
4348

4449

45-
def validate_prompt(prompt: dict) -> Dict[str, Union[str, dict, List[str]]]:
46-
validated_prompt = {} # type: Dict[str, Union[str, dict, List[str]]]
50+
def _validate_prompt(prompt: Union[Dict[str, Any], Prompt], strict_validation: bool) -> ValidatedPromptDict:
4751
if not isinstance(prompt, dict):
48-
raise TypeError("Prompt must be a dictionary")
52+
raise TypeError(f"Prompt must be a dictionary, received {type(prompt).__name__}.")
53+
54+
ml_app = config._llmobs_ml_app
55+
prompt_id = prompt.get("id")
56+
version = prompt.get("version")
57+
tags = prompt.get("tags")
4958
variables = prompt.get("variables")
5059
template = prompt.get("template")
51-
version = prompt.get("version")
52-
prompt_id = prompt.get("id")
60+
chat_template = prompt.get("chat_template")
5361
ctx_variable_keys = prompt.get("rag_context_variables")
54-
rag_query_variable_keys = prompt.get("rag_query_variables")
55-
if variables is not None:
62+
query_variable_keys = prompt.get("rag_query_variables")
63+
64+
if strict_validation:
65+
if prompt_id is None:
66+
raise ValueError("'id' must be provided")
67+
if template is None and chat_template is None:
68+
raise ValueError("One of 'template' or 'chat_template' must be provided to annotate a prompt.")
69+
70+
if template and chat_template:
71+
raise ValueError("Only one of 'template' or 'chat_template' can be provided, not both.")
72+
73+
final_prompt_id = prompt_id or f"{ml_app}_{DEFAULT_PROMPT_NAME}"
74+
final_ctx_variable_keys = ctx_variable_keys or ["context"]
75+
final_query_variable_keys = query_variable_keys or ["question"]
76+
77+
if not isinstance(final_prompt_id, str):
78+
raise TypeError(f"prompt_id {final_prompt_id} must be a string, received {type(final_prompt_id).__name__}")
79+
80+
if not (isinstance(final_ctx_variable_keys, list) and all(isinstance(i, str) for i in final_ctx_variable_keys)):
81+
raise TypeError(f"ctx_variables must be a list of strings, received {type(final_ctx_variable_keys).__name__}")
82+
83+
if not (isinstance(final_query_variable_keys, list) and all(isinstance(i, str) for i in final_query_variable_keys)):
84+
raise TypeError(
85+
f"query_variables must be a list of strings, received {type(final_query_variable_keys).__name__}"
86+
)
87+
88+
if version and not isinstance(version, str):
89+
raise TypeError(f"version: {version} must be a string, received {type(version).__name__}")
90+
91+
if tags:
92+
if not isinstance(tags, dict):
93+
raise TypeError(
94+
f"tags: {tags} must be a dictionary of string key-value pairs, received {type(tags).__name__}"
95+
)
96+
if not all(isinstance(k, str) for k in tags):
97+
raise TypeError("Keys of 'tags' must all be strings.")
98+
if not all(isinstance(k, str) for k in tags.values()):
99+
raise TypeError("Values of 'tags' must all be strings.")
100+
101+
if template and not isinstance(template, str):
102+
raise TypeError(f"template: {template} must be a string, received {type(template).__name__}")
103+
104+
if chat_template:
105+
if not isinstance(chat_template, list):
106+
raise TypeError("chat_template must be a list of dictionaries with string-string key value pairs.")
107+
for ct in chat_template:
108+
if not (isinstance(ct, dict) and all(k in ct for k in ("role", "content"))):
109+
raise TypeError(
110+
"Each 'chat_template' entry should be a string-string dictionary with role and content keys."
111+
)
112+
113+
if variables:
56114
if not isinstance(variables, dict):
57-
raise TypeError("Prompt variables must be a dictionary.")
58-
if not any(isinstance(k, str) or isinstance(v, str) for k, v in variables.items()):
59-
raise TypeError("Prompt variable keys and values must be strings.")
115+
raise TypeError(
116+
f"variables: {variables} must be a dictionary with string keys, received {type(variables).__name__}"
117+
)
118+
if not all(isinstance(k, str) for k in variables):
119+
raise TypeError("Keys of 'variables' must all be strings.")
120+
121+
final_chat_template = []
122+
if chat_template:
123+
for msg in chat_template:
124+
final_chat_template.append(Message(role=msg["role"], content=msg["content"]))
125+
126+
validated_prompt: ValidatedPromptDict = {}
127+
if final_prompt_id:
128+
validated_prompt["id"] = final_prompt_id
129+
if version:
130+
validated_prompt["version"] = version
131+
if variables:
60132
validated_prompt["variables"] = variables
61-
if template is not None:
62-
if not isinstance(template, str):
63-
raise TypeError("Prompt template must be a string")
133+
if template:
64134
validated_prompt["template"] = template
65-
if version is not None:
66-
if not isinstance(version, str):
67-
raise TypeError("Prompt version must be a string.")
68-
validated_prompt["version"] = version
69-
if prompt_id is not None:
70-
if not isinstance(prompt_id, str):
71-
raise TypeError("Prompt id must be a string.")
72-
validated_prompt["id"] = prompt_id
73-
if ctx_variable_keys is not None:
74-
if not isinstance(ctx_variable_keys, list):
75-
raise TypeError("Prompt field `context_variable_keys` must be a list of strings.")
76-
if not all(isinstance(k, str) for k in ctx_variable_keys):
77-
raise TypeError("Prompt field `context_variable_keys` must be a list of strings.")
78-
validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ctx_variable_keys
79-
else:
80-
validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ["context"]
81-
if rag_query_variable_keys is not None:
82-
if not isinstance(rag_query_variable_keys, list):
83-
raise TypeError("Prompt field `rag_query_variables` must be a list of strings.")
84-
if not all(isinstance(k, str) for k in rag_query_variable_keys):
85-
raise TypeError("Prompt field `rag_query_variables` must be a list of strings.")
86-
validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = rag_query_variable_keys
87-
else:
88-
validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = ["question"]
135+
if final_chat_template:
136+
validated_prompt["chat_template"] = final_chat_template
137+
if tags:
138+
validated_prompt["tags"] = tags
139+
if final_ctx_variable_keys:
140+
validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = final_ctx_variable_keys
141+
if final_query_variable_keys:
142+
validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = final_query_variable_keys
143+
89144
return validated_prompt
90145

91146

ddtrace/llmobs/utils.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,6 @@ def _extract_tool_result(tool_result: Dict[str, Any]) -> "ToolResult":
7575
{"content": str, "role": str, "tool_calls": List["ToolCall"], "tool_results": List["ToolResult"]},
7676
total=False,
7777
)
78-
Prompt = TypedDict(
79-
"Prompt",
80-
{
81-
"variables": Dict[str, str],
82-
"template": str,
83-
"id": str,
84-
"version": str,
85-
"rag_context_variables": List[
86-
str
87-
], # a list of variable key names that contain ground truth context information
88-
"rag_query_variables": List[str], # a list of variable key names that contains query information
89-
},
90-
total=False,
91-
)
9278
ToolCall = TypedDict(
9379
"ToolCall",
9480
{
@@ -163,6 +149,33 @@ def extract_tool_definitions(tool_definitions: List[Dict[str, Any]]) -> List[Too
163149
return validated_tool_definitions
164150

165151

152+
class Prompt(TypedDict, total=False):
153+
"""
154+
A Prompt object that contains the information needed to render a prompt.
155+
id: str - the id of the prompt set by the user. Should be unique per ml_app.
156+
version: str - user tag for the version of the prompt.
157+
variables: Dict[str, str] - a dictionary of variables that will be used to render the prompt
158+
chat_template: Optional[Union[List[Dict[str, str]], List[Message]]]
159+
- A list of dicts of (role,template)
160+
where role is the role of the prompt and template is the template string
161+
template: Optional[str]
162+
- It also accepts a string that represents the template for the prompt. Will default to "user" for a role
163+
tags: Optional[Dict[str, str]]
164+
- List of tags to add to the prompt run.
165+
rag_context_variables: List[str] - a list of variable key names that contain ground truth context information
166+
rag_query_variables: List[str] - a list of variable key names that contains query information
167+
"""
168+
169+
version: str
170+
id: str
171+
template: str
172+
chat_template: Union[List[Dict[str, str]], List[Message]]
173+
variables: Dict[str, str]
174+
tags: Dict[str, str]
175+
rag_context_variables: List[str]
176+
rag_query_variables: List[str]
177+
178+
166179
class Messages:
167180
def __init__(self, messages: Union[List[Dict[str, Any]], Dict[str, Any], str]):
168181
self.messages = []
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
LLM Observability: Extends the prompt structure to add ``tags`` and ``chat_template``.
5+
A new ``Prompt`` TypedDict class that would be used in annotation and annotation_context.

tests/llmobs/test_llmobs.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ddtrace.llmobs._constants import PARENT_ID_KEY
1313
from ddtrace.llmobs._constants import ROOT_PARENT_ID
1414
from ddtrace.llmobs._utils import _get_session_id
15+
from ddtrace.llmobs.utils import Prompt
1516
from tests.llmobs._utils import _expected_llmobs_llm_span_event
1617

1718

@@ -460,15 +461,42 @@ def test_structured_io_data(llmobs, llmobs_backend):
460461

461462
def test_structured_prompt_data(llmobs, llmobs_backend):
462463
with llmobs.llm() as span:
463-
llmobs.annotate(span, prompt={"template": "test {{value}}"})
464+
llmobs.annotate(span, input_data={"data": "test1"}, prompt={"template": "test {{value}}"})
465+
events = llmobs_backend.wait_for_num_events(num=1)
466+
assert len(events) == 1
467+
assert events[0][0]["spans"][0]["meta"]["input"]["prompt"] == {
468+
"id": "unnamed-ml-app_unnamed-prompt",
469+
"template": "test {{value}}",
470+
"_dd_context_variable_keys": ["context"],
471+
"_dd_query_variable_keys": ["question"],
472+
}
473+
474+
475+
def test_structured_prompt_data_v2(llmobs, llmobs_backend):
476+
prompt = Prompt(
477+
id="test",
478+
chat_template=[{"role": "user", "content": "test {{value}}"}],
479+
variables={"value": "test", "context": "test", "question": "test"},
480+
tags={"env": "prod", "llm": "openai"},
481+
rag_context_variables=["context"],
482+
rag_query_variables=["question"],
483+
)
484+
with llmobs.llm() as span:
485+
llmobs.annotate(
486+
span,
487+
prompt=prompt,
488+
)
464489
events = llmobs_backend.wait_for_num_events(num=1)
465490
assert len(events) == 1
466491
assert events[0][0]["spans"][0]["meta"]["input"] == {
467492
"prompt": {
468-
"template": "test {{value}}",
493+
"id": "test",
494+
"chat_template": [{"role": "user", "content": "test {{value}}"}],
495+
"variables": {"value": "test", "context": "test", "question": "test"},
496+
"tags": {"env": "prod", "llm": "openai"},
469497
"_dd_context_variable_keys": ["context"],
470498
"_dd_query_variable_keys": ["question"],
471-
},
499+
}
472500
}
473501

474502

0 commit comments

Comments
 (0)