Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(autofix): Add prompt caching #1650

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ chromadb==0.4.14
google-cloud-storage==2.*
google-cloud-aiplatform==1.*
google-cloud-secret-manager==2.*
anthropic[vertex]==0.34.2
anthropic[vertex]==0.42.0
langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e
watchdog
stumpy==1.13.0
Expand Down
13 changes: 6 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
aiohappyeyeballs==2.4.4
# via aiohttp
aiohttp==3.11.10
aiohttp==3.11.11
# via
# -r requirements-constraints.txt
# datasets
Expand All @@ -21,7 +21,7 @@ amqp==5.3.1
# via kombu
annotated-types==0.7.0
# via pydantic
anthropic==0.34.2
anthropic==0.42.0
# via -r requirements-constraints.txt
anyio==4.7.0
# via
Expand Down Expand Up @@ -200,7 +200,7 @@ google-auth==2.37.0
# google-cloud-secret-manager
# google-cloud-storage
# google-genai
google-cloud-aiplatform==1.74.0
google-cloud-aiplatform==1.75.0
# via -r requirements-constraints.txt
google-cloud-bigquery==3.27.0
# via google-cloud-aiplatform
Expand All @@ -220,7 +220,7 @@ google-crc32c==1.6.0
# via
# google-cloud-storage
# google-resumable-media
google-genai==0.2.2
google-genai==0.3.0
# via -r requirements-constraints.txt
google-resumable-media==2.7.2
# via
Expand Down Expand Up @@ -426,7 +426,7 @@ onnx==1.16.0
# via -r requirements-constraints.txt
onnxruntime==1.20.1
# via chromadb
openai==1.57.4
openai==1.58.1
# via -r requirements-constraints.txt
openapi-core==0.18.2
# via -r requirements-constraints.txt
Expand Down Expand Up @@ -505,7 +505,7 @@ proto-plus==1.25.0
# google-cloud-aiplatform
# google-cloud-resource-manager
# google-cloud-secret-manager
protobuf==5.29.1
protobuf==5.29.2
# via
# -r requirements-constraints.txt
# google-api-core
Expand Down Expand Up @@ -730,7 +730,6 @@ threadpoolctl==3.2.0
# scikit-learn
tokenizers==0.15.2
# via
# anthropic
# chromadb
# transformers
torch==2.2.0
Expand Down
36 changes: 30 additions & 6 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import anthropic
from anthropic import NOT_GIVEN
from anthropic.types import (
CacheControlEphemeralParam,
MessageParam,
TextBlockParam,
ToolParam,
Expand Down Expand Up @@ -415,7 +416,7 @@ def generate_text(
max_tokens: int | None = None,
timeout: float | None = None,
):
message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
message_dicts, tool_dicts, system_message = self._prep_message_and_tools(
messages=messages,
prompt=prompt,
system_prompt=system_prompt,
Expand All @@ -425,7 +426,7 @@ def generate_text(
anthropic_client = self.get_client()

completion = anthropic_client.messages.create(
system=system_prompt or NOT_GIVEN,
system=system_message or NOT_GIVEN,
model=self.model_name,
tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
messages=cast(Iterable[MessageParam], message_dicts),
Expand Down Expand Up @@ -538,16 +539,39 @@ def _prep_message_and_tools(
prompt: str | None = None,
system_prompt: str | None = None,
tools: list[FunctionTool] | None = None,
) -> tuple[list[MessageParam], list[ToolParam] | None, str | None]:
) -> tuple[list[MessageParam], list[ToolParam] | None, list[TextBlockParam] | None]:
message_dicts = [cls.to_message_param(message) for message in messages] if messages else []
if prompt:
message_dicts.append(cls.to_message_param(Message(role="user", content=prompt)))
# Set caching breakpoints for the last message and the 3rd last message
if len(message_dicts) > 0 and message_dicts[-1].content:
message_dicts[-1].content[0].cache_control = CacheControlEphemeralParam(
type="ephemeral"
)
if len(message_dicts) >= 3 and message_dicts[-3]:
message_dicts[-3].content[0].cache_control = CacheControlEphemeralParam(
type="ephemeral"
)

tool_dicts = (
[cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None
)
if tool_dicts:
tool_dicts[-1].cache_control = CacheControlEphemeralParam(type="ephemeral")

system_message = (
[
TextBlockParam(
type="text",
text=system_prompt,
cache_control=CacheControlEphemeralParam(type="ephemeral"),
)
]
if system_prompt
else []
)

return message_dicts, tool_dicts, system_prompt
return message_dicts, tool_dicts, system_message

@observe(as_type="generation", name="Anthropic Stream")
def generate_text_stream(
Expand All @@ -561,7 +585,7 @@ def generate_text_stream(
max_tokens: int | None = None,
timeout: float | None = None,
) -> Iterator[str | ToolCall | Usage]:
message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
message_dicts, tool_dicts, system_message = self._prep_message_and_tools(
messages=messages,
prompt=prompt,
system_prompt=system_prompt,
Expand All @@ -571,7 +595,7 @@ def generate_text_stream(
anthropic_client = self.get_client()

stream = anthropic_client.messages.create(
system=system_prompt or NOT_GIVEN,
system=system_message or NOT_GIVEN,
model=self.model_name,
tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
messages=cast(Iterable[MessageParam], message_dicts),
Expand Down
2 changes: 1 addition & 1 deletion tests/automation/agent/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_anthropic_prep_message_and_tools():
assert "description" in tool_dicts[0]
assert "input_schema" in tool_dicts[0]

assert returned_system_prompt == system_prompt
assert returned_system_prompt[0]["text"] == system_prompt


@pytest.mark.vcr()
Expand Down
Loading