|
2 | 2 | from dataclasses import dataclass |
3 | 3 | from dataclasses import is_dataclass |
4 | 4 | import json |
| 5 | +from typing import Any |
5 | 6 | from typing import Dict |
6 | 7 | from typing import List |
7 | 8 | from typing import Optional |
|
14 | 15 | from ddtrace.internal.logger import get_logger |
15 | 16 | from ddtrace.internal.utils.formats import format_trace_id |
16 | 17 | from ddtrace.llmobs._constants import CREWAI_APM_SPAN_NAME |
| 18 | +from ddtrace.llmobs._constants import DEFAULT_PROMPT_NAME |
17 | 19 | from ddtrace.llmobs._constants import GEMINI_APM_SPAN_NAME |
18 | 20 | from ddtrace.llmobs._constants import INTERNAL_CONTEXT_VARIABLE_KEYS |
19 | 21 | from ddtrace.llmobs._constants import INTERNAL_QUERY_VARIABLE_KEYS |
|
27 | 29 | from ddtrace.llmobs._constants import SESSION_ID |
28 | 30 | from ddtrace.llmobs._constants import SPAN_LINKS |
29 | 31 | from ddtrace.llmobs._constants import VERTEXAI_APM_SPAN_NAME |
| 32 | +from ddtrace.llmobs.utils import Message |
| 33 | +from ddtrace.llmobs.utils import Prompt |
30 | 34 | from ddtrace.trace import Span |
31 | 35 |
|
32 | 36 |
|
33 | 37 | log = get_logger(__name__) |
34 | 38 |
|
| 39 | +ValidatedPromptDict = Dict[str, Union[str, Dict[str, Any], List[str], List[Dict[str, str]], List[Message]]] |
35 | 40 |
|
36 | 41 | STANDARD_INTEGRATION_SPAN_NAMES = ( |
37 | 42 | CREWAI_APM_SPAN_NAME, |
|
42 | 47 | ) |
43 | 48 |
|
44 | 49 |
|
45 | | -def validate_prompt(prompt: dict) -> Dict[str, Union[str, dict, List[str]]]: |
46 | | - validated_prompt = {} # type: Dict[str, Union[str, dict, List[str]]] |
| 50 | +def _validate_prompt(prompt: Union[Dict[str, Any], Prompt], strict_validation: bool) -> ValidatedPromptDict: |
47 | 51 | if not isinstance(prompt, dict): |
48 | | - raise TypeError("Prompt must be a dictionary") |
| 52 | + raise TypeError(f"Prompt must be a dictionary, received {type(prompt).__name__}.") |
| 53 | + |
| 54 | + ml_app = config._llmobs_ml_app |
| 55 | + prompt_id = prompt.get("id") |
| 56 | + version = prompt.get("version") |
| 57 | + tags = prompt.get("tags") |
49 | 58 | variables = prompt.get("variables") |
50 | 59 | template = prompt.get("template") |
51 | | - version = prompt.get("version") |
52 | | - prompt_id = prompt.get("id") |
| 60 | + chat_template = prompt.get("chat_template") |
53 | 61 | ctx_variable_keys = prompt.get("rag_context_variables") |
54 | | - rag_query_variable_keys = prompt.get("rag_query_variables") |
55 | | - if variables is not None: |
| 62 | + query_variable_keys = prompt.get("rag_query_variables") |
| 63 | + |
| 64 | + if strict_validation: |
| 65 | + if prompt_id is None: |
| 66 | + raise ValueError("'id' must be provided") |
| 67 | + if template is None and chat_template is None: |
| 68 | + raise ValueError("One of 'template' or 'chat_template' must be provided to annotate a prompt.") |
| 69 | + |
| 70 | + if template and chat_template: |
| 71 | + raise ValueError("Only one of 'template' or 'chat_template' can be provided, not both.") |
| 72 | + |
| 73 | + final_prompt_id = prompt_id or f"{ml_app}_{DEFAULT_PROMPT_NAME}" |
| 74 | + final_ctx_variable_keys = ctx_variable_keys or ["context"] |
| 75 | + final_query_variable_keys = query_variable_keys or ["question"] |
| 76 | + |
| 77 | + if not isinstance(final_prompt_id, str): |
| 78 | + raise TypeError(f"prompt_id {final_prompt_id} must be a string, received {type(final_prompt_id).__name__}") |
| 79 | + |
| 80 | + if not (isinstance(final_ctx_variable_keys, list) and all(isinstance(i, str) for i in final_ctx_variable_keys)): |
| 81 | + raise TypeError(f"ctx_variables must be a list of strings, received {type(final_ctx_variable_keys).__name__}") |
| 82 | + |
| 83 | + if not (isinstance(final_query_variable_keys, list) and all(isinstance(i, str) for i in final_query_variable_keys)): |
| 84 | + raise TypeError( |
| 85 | + f"query_variables must be a list of strings, received {type(final_query_variable_keys).__name__}" |
| 86 | + ) |
| 87 | + |
| 88 | + if version and not isinstance(version, str): |
| 89 | + raise TypeError(f"version: {version} must be a string, received {type(version).__name__}") |
| 90 | + |
| 91 | + if tags: |
| 92 | + if not isinstance(tags, dict): |
| 93 | + raise TypeError( |
| 94 | + f"tags: {tags} must be a dictionary of string key-value pairs, received {type(tags).__name__}" |
| 95 | + ) |
| 96 | + if not all(isinstance(k, str) for k in tags): |
| 97 | + raise TypeError("Keys of 'tags' must all be strings.") |
| 98 | + if not all(isinstance(k, str) for k in tags.values()): |
| 99 | + raise TypeError("Values of 'tags' must all be strings.") |
| 100 | + |
| 101 | + if template and not isinstance(template, str): |
| 102 | + raise TypeError(f"template: {template} must be a string, received {type(template).__name__}") |
| 103 | + |
| 104 | + if chat_template: |
| 105 | + if not isinstance(chat_template, list): |
| 106 | + raise TypeError("chat_template must be a list of dictionaries with string-string key value pairs.") |
| 107 | + for ct in chat_template: |
| 108 | + if not (isinstance(ct, dict) and all(k in ct for k in ("role", "content"))): |
| 109 | + raise TypeError( |
| 110 | + "Each 'chat_template' entry should be a string-string dictionary with role and content keys." |
| 111 | + ) |
| 112 | + |
| 113 | + if variables: |
56 | 114 | if not isinstance(variables, dict): |
57 | | - raise TypeError("Prompt variables must be a dictionary.") |
58 | | - if not any(isinstance(k, str) or isinstance(v, str) for k, v in variables.items()): |
59 | | - raise TypeError("Prompt variable keys and values must be strings.") |
| 115 | + raise TypeError( |
| 116 | + f"variables: {variables} must be a dictionary with string keys, received {type(variables).__name__}" |
| 117 | + ) |
| 118 | + if not all(isinstance(k, str) for k in variables): |
| 119 | + raise TypeError("Keys of 'variables' must all be strings.") |
| 120 | + |
| 121 | + final_chat_template = [] |
| 122 | + if chat_template: |
| 123 | + for msg in chat_template: |
| 124 | + final_chat_template.append(Message(role=msg["role"], content=msg["content"])) |
| 125 | + |
| 126 | + validated_prompt: ValidatedPromptDict = {} |
| 127 | + if final_prompt_id: |
| 128 | + validated_prompt["id"] = final_prompt_id |
| 129 | + if version: |
| 130 | + validated_prompt["version"] = version |
| 131 | + if variables: |
60 | 132 | validated_prompt["variables"] = variables |
61 | | - if template is not None: |
62 | | - if not isinstance(template, str): |
63 | | - raise TypeError("Prompt template must be a string") |
| 133 | + if template: |
64 | 134 | validated_prompt["template"] = template |
65 | | - if version is not None: |
66 | | - if not isinstance(version, str): |
67 | | - raise TypeError("Prompt version must be a string.") |
68 | | - validated_prompt["version"] = version |
69 | | - if prompt_id is not None: |
70 | | - if not isinstance(prompt_id, str): |
71 | | - raise TypeError("Prompt id must be a string.") |
72 | | - validated_prompt["id"] = prompt_id |
73 | | - if ctx_variable_keys is not None: |
74 | | - if not isinstance(ctx_variable_keys, list): |
75 | | - raise TypeError("Prompt field `context_variable_keys` must be a list of strings.") |
76 | | - if not all(isinstance(k, str) for k in ctx_variable_keys): |
77 | | - raise TypeError("Prompt field `context_variable_keys` must be a list of strings.") |
78 | | - validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ctx_variable_keys |
79 | | - else: |
80 | | - validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ["context"] |
81 | | - if rag_query_variable_keys is not None: |
82 | | - if not isinstance(rag_query_variable_keys, list): |
83 | | - raise TypeError("Prompt field `rag_query_variables` must be a list of strings.") |
84 | | - if not all(isinstance(k, str) for k in rag_query_variable_keys): |
85 | | - raise TypeError("Prompt field `rag_query_variables` must be a list of strings.") |
86 | | - validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = rag_query_variable_keys |
87 | | - else: |
88 | | - validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = ["question"] |
| 135 | + if final_chat_template: |
| 136 | + validated_prompt["chat_template"] = final_chat_template |
| 137 | + if tags: |
| 138 | + validated_prompt["tags"] = tags |
| 139 | + if final_ctx_variable_keys: |
| 140 | + validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = final_ctx_variable_keys |
| 141 | + if final_query_variable_keys: |
| 142 | + validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = final_query_variable_keys |
| 143 | + |
89 | 144 | return validated_prompt |
90 | 145 |
|
91 | 146 |
|
|
0 commit comments